You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/sem/Sema.cpp

201 lines
6.0 KiB

#include "sem/Sema.h"
#include <any>
#include <stdexcept>
#include <string>
#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"
namespace {
// Returns the identifier text of an lvalue node. An lvalue that carries no
// ID token is rejected with a "sema" diagnostic (std::runtime_error).
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
  auto* const id_node = lvalue.ID();
  if (id_node == nullptr) {
    throw std::runtime_error(FormatError("sema", "非法左值"));
  }
  return id_node->getText();
}
class SemaVisitor final : public SysYBaseVisitor {
public:
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少编译单元"));
}
auto* func = ctx->funcDef();
if (!func || !func->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
if (!func->ID() || func->ID()->getText() != "main") {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
func->accept(this);
if (!seen_return_) {
throw std::runtime_error(
FormatError("sema", "main 函数必须包含 return 语句"));
}
return {};
}
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
if (!ctx || !ctx->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
if (!ctx->funcType() || !ctx->funcType()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持 int main"));
}
const auto& items = ctx->blockStmt()->blockItem();
if (items.empty()) {
throw std::runtime_error(
FormatError("sema", "main 函数不能为空,且必须以 return 结束"));
}
ctx->blockStmt()->accept(this);
return {};
}
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少语句块"));
}
const auto& items = ctx->blockItem();
for (size_t i = 0; i < items.size(); ++i) {
auto* item = items[i];
if (!item) {
continue;
}
if (seen_return_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
current_item_index_ = i;
total_items_ = items.size();
item->accept(this);
}
return {};
}
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
if (ctx->decl()) {
ctx->decl()->accept(this);
return {};
}
if (ctx->stmt()) {
ctx->stmt()->accept(this);
return {};
}
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
std::any visitDecl(SysYParser::DeclContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
}
if (!ctx->btype() || !ctx->btype()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明"));
}
auto* var_def = ctx->varDef();
if (!var_def || !var_def->lValue()) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
}
const std::string name = GetLValueName(*var_def->lValue());
if (table_.Contains(name)) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
}
if (auto* init = var_def->initValue()) {
if (!init->exp()) {
throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化"));
}
init->exp()->accept(this);
}
table_.Add(name, var_def);
return {};
}
std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx || !ctx->returnStmt()) {
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
ctx->returnStmt()->accept(this);
return {};
}
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "return 缺少表达式"));
}
ctx->exp()->accept(this);
seen_return_ = true;
if (current_item_index_ + 1 != total_items_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
return {};
}
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "非法括号表达式"));
}
ctx->exp()->accept(this);
return {};
}
std::any visitVarExp(SysYParser::VarExpContext* ctx) override {
if (!ctx || !ctx->var()) {
throw std::runtime_error(FormatError("sema", "非法变量表达式"));
}
ctx->var()->accept(this);
return {};
}
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量"));
}
return {};
}
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式"));
}
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitVar(SysYParser::VarContext* ctx) override {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
}
const std::string name = ctx->ID()->getText();
auto* decl = table_.Lookup(name);
if (!decl) {
throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
}
sema_.BindVarUse(ctx, decl);
return {};
}
SemanticContext TakeSemanticContext() { return std::move(sema_); }
private:
SymbolTable table_;
SemanticContext sema_;
bool seen_return_ = false;
size_t current_item_index_ = 0;
size_t total_items_ = 0;
};
} // namespace
// Runs semantic analysis over a parsed compilation unit and returns the
// SemanticContext the visitor populated (variable-use bindings recorded via
// BindVarUse). Semantic errors propagate as std::runtime_error.
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
  SemaVisitor checker;
  comp_unit.accept(&checker);
  return checker.TakeSemanticContext();
}