#include "sem/Sema.h"

#include <any>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"

namespace {

// SysY `int` is a 32-bit two's-complement integer; the constant folder below
// depends on that width.
static_assert(sizeof(int) == sizeof(std::uint32_t),
              "constant folding assumes a 32-bit int");

// Wrapping 32-bit arithmetic used by constant folding. The operation is done
// on uint32_t, where wrap-around is well defined, and converted back to int
// (modular on mainstream compilers, guaranteed since C++20). This avoids the
// undefined behaviour of signed overflow.
static int WrapAdd(int a, int b) {
  return static_cast<int>(static_cast<std::uint32_t>(a) +
                          static_cast<std::uint32_t>(b));
}

static int WrapSub(int a, int b) {
  return static_cast<int>(static_cast<std::uint32_t>(a) -
                          static_cast<std::uint32_t>(b));
}

static int WrapMul(int a, int b) {
  return static_cast<int>(static_cast<std::uint32_t>(a) *
                          static_cast<std::uint32_t>(b));
}

static int WrapNeg(int a) {
  return static_cast<int>(0u - static_cast<std::uint32_t>(a));
}

// Extracts the int carried by a constant-evaluation result. Anything else
// (e.g. the float produced for a FLOAT_CONST literal, or an empty std::any)
// is reported as a semantic error instead of escaping as std::bad_any_cast.
static int ConstValueToInt(const std::any& value) {
  if (const auto* as_int = std::any_cast<int>(&value)) return *as_int;
  throw std::runtime_error(FormatError("sema", "constExp 仅支持整数"));
}

// Maps a `bType` node to its base type (Int or Float).
// Throws std::runtime_error if the node is missing or unrecognised.
static BaseTypeKind BaseTypeFromBType(SysYParser::BTypeContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("sema", "缺少 bType"));
  }
  if (ctx->INT()) return BaseTypeKind::Int;
  if (ctx->FLOAT()) return BaseTypeKind::Float;
  throw std::runtime_error(FormatError("sema", "未知基础类型"));
}

// Maps a `funcType` node to the function's return base type
// (Void, Int or Float). Throws std::runtime_error if missing/unrecognised.
static BaseTypeKind BaseTypeFromFuncType(SysYParser::FuncTypeContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("sema", "缺少 funcType"));
  }
  if (ctx->VOID()) return BaseTypeKind::Void;
  if (ctx->INT()) return BaseTypeKind::Int;
  if (ctx->FLOAT()) return BaseTypeKind::Float;
  throw std::runtime_error(FormatError("sema", "未知函数返回类型"));
}

// Folds a SysY constant expression to an int.
//
// Supported: integer literals (decimal / octal / hex), scalar names whose
// symbol-table entry is const and carries a folded value, parentheses,
// unary + and -, and binary + - * / %. Arithmetic wraps modulo 2^32.
// Rejected with std::runtime_error: float operands, `!`, function calls,
// non-constant names, and division or remainder by zero.
class ConstEvalVisitor final : public SysYBaseVisitor {
 public:
  explicit ConstEvalVisitor(const SymbolTable& table) : table_(table) {}

  std::any visitConstExp(SysYParser::ConstExpContext* ctx) override {
    return visitAddExp(ctx->addExp());
  }

  std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
    auto muls = ctx->mulExp();
    if (muls.empty()) return 0;
    int value = ConstValueToInt(muls[0]->accept(this));
    for (std::size_t i = 1; i < muls.size(); ++i) {
      const int rhs = ConstValueToInt(muls[i]->accept(this));
      // Assumes the flat rule `addExp: mulExp ((ADD|SUB) mulExp)*`, so the
      // operator between operands i-1 and i is child 2*i-1.
      auto* node = ctx->children.at(2 * i - 1);
      const std::string text = node ? node->getText() : std::string("+");
      if (text == "+") {
        value = WrapAdd(value, rhs);
      } else if (text == "-") {
        value = WrapSub(value, rhs);
      } else {
        throw std::runtime_error(FormatError("sema", "非法加法运算符"));
      }
    }
    return value;
  }

  std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
    auto unaries = ctx->unaryExp();
    if (unaries.empty()) return 0;
    int value = ConstValueToInt(unaries[0]->accept(this));
    for (std::size_t i = 1; i < unaries.size(); ++i) {
      const int rhs = ConstValueToInt(unaries[i]->accept(this));
      // Same flat-rule assumption as visitAddExp: operator is child 2*i-1.
      auto* node = ctx->children.at(2 * i - 1);
      const std::string text = node ? node->getText() : std::string("*");
      if (text == "*") {
        value = WrapMul(value, rhs);
      } else if (text == "/" || text == "%") {
        if (rhs == 0) {
          // Dividing by zero is UB (typically SIGFPE); report it instead.
          throw std::runtime_error(FormatError("sema", "constExp 除数为零"));
        }
        // INT_MIN / -1 overflows in C++; handle the -1 divisor explicitly
        // with the wrapped result (and INT_MIN % -1 == 0).
        if (text == "/") {
          value = (rhs == -1) ? WrapNeg(value) : value / rhs;
        } else {
          value = (rhs == -1) ? 0 : value % rhs;
        }
      } else {
        throw std::runtime_error(FormatError("sema", "非法乘法运算符"));
      }
    }
    return value;
  }

  std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
    if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
    if (ctx->unaryOp() && ctx->unaryExp()) {
      const int val = ConstValueToInt(ctx->unaryExp()->accept(this));
      const std::string op = ctx->unaryOp()->getText();
      if (op == "+") return val;
      if (op == "-") return WrapNeg(val);
      throw std::runtime_error(FormatError("sema", "constExp 不支持 !"));
    }
    // The remaining unaryExp alternative is a function call.
    throw std::runtime_error(FormatError("sema", "constExp 不支持函数调用"));
  }

  std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
    if (ctx->exp()) return ctx->exp()->accept(this);
    if (ctx->lVal()) return ctx->lVal()->accept(this);
    if (ctx->number()) return ctx->number()->accept(this);
    return 0;
  }

  std::any visitNumber(SysYParser::NumberContext* ctx) override {
    if (ctx->INT_CONST()) {
      const std::string text = ctx->getText();
      std::size_t idx = 0;
      long long val = 0;
      try {
        // Base 0 auto-detects decimal, octal (leading 0) and hex (0x).
        val = std::stoll(text, &idx, 0);
      } catch (const std::logic_error&) {
        // std::invalid_argument / std::out_of_range from stoll.
        throw std::runtime_error(FormatError("sema", "非法整数常量"));
      }
      if (idx != text.size()) {
        throw std::runtime_error(FormatError("sema", "非法整数常量"));
      }
      // Literals above INT_MAX (e.g. 2147483648 in `-2147483648`) are
      // truncated to 32 bits so that a following negation yields INT_MIN.
      return static_cast<int>(static_cast<std::uint32_t>(val));
    }
    if (ctx->FLOAT_CONST()) {
      // Returned as float; an enclosing integer fold rejects it through
      // ConstValueToInt with a diagnostic.
      return static_cast<float>(std::stof(ctx->getText()));
    }
    throw std::runtime_error(FormatError("sema", "constExp 仅支持整数"));
  }

  std::any visitLVal(SysYParser::LValContext* ctx) override {
    if (!ctx || !ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "constExp 非法变量"));
    }
    const auto* entry = table_.Lookup(ctx->ID()->getText());
    if (!entry || !entry->is_const || !entry->const_value.has_value()) {
      throw std::runtime_error(FormatError("sema", "constExp 使用了非常量"));
    }
    return entry->const_value.value();
  }

 private:
  const SymbolTable& table_;
};

// Walks a SysY parse tree, checking scoping and semantic rules and recording
// declarations, variable uses and call targets in a SemanticContext.
//
// The checks cover duplicate definitions, undefined names, break/continue
// outside loops, return-value agreement with the function type, assignment to
// constants, negative array dimensions, and a missing `main`. Every violation
// is reported by throwing std::runtime_error.
class SemaVisitor final : public SysYBaseVisitor {
 public:
  // Pre-registers all functions (so calls may precede definitions, including
  // recursion), then checks global declarations, then function bodies.
  // NOTE(review): all global decls are processed before any function body, so
  // a function can see globals declared after it in the source.
  std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
    if (!ctx) {
      throw std::runtime_error(FormatError("sema", "缺少编译单元"));
    }
    for (auto* func : ctx->funcDef()) {
      if (!func || !func->ID()) continue;
      const std::string name = func->ID()->getText();
      if (!func_table_.emplace(name, func).second) {
        throw std::runtime_error(FormatError("sema", "重复定义函数: " + name));
      }
    }
    for (auto* decl : ctx->decl()) {
      if (decl) decl->accept(this);
    }
    for (auto* func : ctx->funcDef()) {
      if (func) func->accept(this);
    }
    if (func_table_.find("main") == func_table_.end()) {
      throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
    }
    return {};
  }

  // Records the function signature, then checks the body with the parameters
  // in their own scope. A non-void function must contain at least one
  // `return` (anywhere in its body).
  std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
    if (!ctx || !ctx->block()) {
      throw std::runtime_error(FormatError("sema", "函数体为空"));
    }
    if (!ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "缺少函数名"));
    }
    FuncTypeDesc fty;
    fty.ret.base = BaseTypeFromFuncType(ctx->funcType());
    if (ctx->funcFParams()) {
      for (auto* param : ctx->funcFParams()->funcFParam()) {
        fty.params.push_back(BuildParamType(param));
      }
    }
    sema_.RegisterFunc(ctx, fty);
    current_ret_ = fty.ret.base;
    seen_return_ = false;
    table_.EnterScope();
    if (ctx->funcFParams()) {
      for (auto* param : ctx->funcFParams()->funcFParam()) {
        RegisterParam(param);
      }
    }
    // visitBlock opens a further scope for the body's own declarations.
    ctx->block()->accept(this);
    table_.ExitScope();
    if (current_ret_ != BaseTypeKind::Void && !seen_return_) {
      throw std::runtime_error(FormatError("sema", "非 void 函数缺少 return"));
    }
    return {};
  }

  // Each block introduces a new lexical scope.
  std::any visitBlock(SysYParser::BlockContext* ctx) override {
    if (!ctx) return {};
    table_.EnterScope();
    for (auto* item : ctx->blockItem()) {
      if (item) item->accept(this);
    }
    table_.ExitScope();
    return {};
  }

  std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
    if (!ctx) return {};
    if (ctx->decl()) return ctx->decl()->accept(this);
    if (ctx->stmt()) return ctx->stmt()->accept(this);
    return {};
  }

  std::any visitDecl(SysYParser::DeclContext* ctx) override {
    if (!ctx) return {};
    if (auto* c = ctx->constDecl()) return c->accept(this);
    if (auto* v = ctx->varDecl()) return v->accept(this);
    return {};
  }

  std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
    if (!ctx || !ctx->bType()) return {};
    const BaseTypeKind base = BaseTypeFromBType(ctx->bType());
    for (auto* def : ctx->constDef()) {
      RegisterConst(def, base);
    }
    return {};
  }

  std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override {
    if (!ctx || !ctx->bType()) return {};
    const BaseTypeKind base = BaseTypeFromBType(ctx->bType());
    for (auto* def : ctx->varDef()) {
      RegisterVar(def, base);
    }
    return {};
  }

  std::any visitStmt(SysYParser::StmtContext* ctx) override {
    if (!ctx) return {};
    if (ctx->lVal() && ctx->ASSIGN()) {
      // Resolving the lvalue first reports undefined names; after that the
      // target is known to exist and can be checked for constness.
      ctx->lVal()->accept(this);
      RejectConstAssignment(ctx->lVal());
      if (ctx->exp()) ctx->exp()->accept(this);
      return {};
    }
    if (ctx->block()) return ctx->block()->accept(this);
    if (ctx->IF()) {
      if (ctx->cond()) ctx->cond()->accept(this);
      if (ctx->stmt(0)) ctx->stmt(0)->accept(this);
      if (ctx->stmt(1)) ctx->stmt(1)->accept(this);
      return {};
    }
    if (ctx->WHILE()) {
      // loop_depth_ lets break/continue inside the body be validated.
      loop_depth_++;
      if (ctx->cond()) ctx->cond()->accept(this);
      if (ctx->stmt(0)) ctx->stmt(0)->accept(this);
      loop_depth_--;
      return {};
    }
    if (ctx->BREAK()) {
      if (loop_depth_ == 0) {
        throw std::runtime_error(FormatError("sema", "break 不在循环内"));
      }
      return {};
    }
    if (ctx->CONTINUE()) {
      if (loop_depth_ == 0) {
        throw std::runtime_error(FormatError("sema", "continue 不在循环内"));
      }
      return {};
    }
    if (ctx->RETURN()) {
      if (ctx->exp()) ctx->exp()->accept(this);
      if (current_ret_ == BaseTypeKind::Void && ctx->exp()) {
        throw std::runtime_error(FormatError("sema", "void 函数不能返回值"));
      }
      if (current_ret_ != BaseTypeKind::Void && !ctx->exp()) {
        throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值"));
      }
      seen_return_ = true;
      return {};
    }
    // Expression statement (possibly empty).
    if (ctx->exp()) ctx->exp()->accept(this);
    return {};
  }

  std::any visitExp(SysYParser::ExpContext* ctx) override {
    if (ctx->addExp()) return ctx->addExp()->accept(this);
    return {};
  }

  std::any visitCond(SysYParser::CondContext* ctx) override {
    if (ctx->lOrExp()) return ctx->lOrExp()->accept(this);
    return {};
  }

  std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override {
    for (auto* e : ctx->lAndExp()) e->accept(this);
    return {};
  }

  std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override {
    for (auto* e : ctx->eqExp()) e->accept(this);
    return {};
  }

  std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
    for (auto* e : ctx->relExp()) e->accept(this);
    return {};
  }

  std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
    for (auto* e : ctx->addExp()) e->accept(this);
    return {};
  }

  std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
    for (auto* mul : ctx->mulExp()) mul->accept(this);
    return {};
  }

  std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
    for (auto* unary : ctx->unaryExp()) unary->accept(this);
    return {};
  }

  // Resolves call targets: user functions are bound in the semantic context,
  // runtime-library builtins are accepted without binding, anything else is
  // an error. Arguments are then checked as ordinary expressions.
  std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
    if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
    if (ctx->ID() && ctx->LPAREN()) {
      const std::string name = ctx->ID()->getText();
      auto it = func_table_.find(name);
      if (it == func_table_.end()) {
        if (builtin_funcs_.find(name) == builtin_funcs_.end()) {
          throw std::runtime_error(FormatError("sema", "未定义的函数: " + name));
        }
      } else {
        sema_.BindFuncCall(ctx, it->second);
      }
      if (ctx->funcRParams()) ctx->funcRParams()->accept(this);
      return {};
    }
    if (ctx->unaryExp()) return ctx->unaryExp()->accept(this);
    return {};
  }

  std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override {
    for (auto* e : ctx->exp()) e->accept(this);
    return {};
  }

  std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
    if (ctx->exp()) return ctx->exp()->accept(this);
    if (ctx->lVal()) return ctx->lVal()->accept(this);
    if (ctx->number()) return ctx->number()->accept(this);
    return {};
  }

  std::any visitNumber(SysYParser::NumberContext* ctx) override {
    if (!ctx->INT_CONST() && !ctx->FLOAT_CONST()) {
      throw std::runtime_error(FormatError("sema", "非法常量"));
    }
    return {};
  }

  // Binds a name use to its declaration (var, const or param) and checks any
  // index expressions.
  std::any visitLVal(SysYParser::LValContext* ctx) override {
    if (!ctx || !ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "非法变量引用"));
    }
    const std::string name = ctx->ID()->getText();
    const SymbolEntry* entry = table_.Lookup(name);
    if (!entry) {
      throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
    }
    BoundDecl bound;
    if (entry->kind == SymbolKind::Var) {
      bound.kind = BoundDecl::Kind::Var;
      bound.var_decl = entry->var_decl;
    } else if (entry->kind == SymbolKind::Const) {
      bound.kind = BoundDecl::Kind::Const;
      bound.const_decl = entry->const_decl;
    } else {
      bound.kind = BoundDecl::Kind::Param;
      bound.param_decl = entry->param_decl;
    }
    sema_.BindVarUse(ctx, bound);
    for (auto* exp : ctx->exp()) {
      if (exp) exp->accept(this);
    }
    return {};
  }

  // Moves the collected semantic information out; call once, after visiting.
  SemanticContext TakeSemanticContext() { return std::move(sema_); }

 private:
  // Builds a parameter's type. An array parameter gets a leading -1 for its
  // omitted first dimension followed by the folded remaining dimensions.
  TypeDesc BuildParamType(SysYParser::FuncFParamContext* ctx) {
    if (!ctx || !ctx->bType()) {
      throw std::runtime_error(FormatError("sema", "非法参数"));
    }
    TypeDesc ty;
    ty.base = BaseTypeFromBType(ctx->bType());
    if (!ctx->LBRACK().empty()) {
      ty.dims.push_back(-1);
      for (auto* exp : ctx->exp()) {
        ty.dims.push_back(EvalArrayDim(exp));
      }
    }
    return ty;
  }

  void RegisterParam(SysYParser::FuncFParamContext* ctx) {
    if (!ctx || !ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "参数缺少名称"));
    }
    const std::string name = ctx->ID()->getText();
    if (table_.ContainsInCurrentScope(name)) {
      throw std::runtime_error(FormatError("sema", "重复定义参数: " + name));
    }
    TypeDesc ty = BuildParamType(ctx);
    SymbolEntry entry;
    entry.kind = SymbolKind::Param;
    entry.param_decl = ctx;
    entry.is_const = false;
    entry.type = ty;
    table_.Add(name, entry);
    sema_.RegisterParam(ctx, ty);
  }

  // Declares a variable in the current scope. The name is added before the
  // initializer is checked, so the initializer sees the new variable (C rule).
  void RegisterVar(SysYParser::VarDefContext* ctx, BaseTypeKind base) {
    if (!ctx || !ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "变量声明缺少名称"));
    }
    const std::string name = ctx->ID()->getText();
    if (table_.ContainsInCurrentScope(name)) {
      throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
    }
    TypeDesc ty;
    ty.base = base;
    for (auto* dim : ctx->constExp()) {
      ty.dims.push_back(EvalArrayDim(dim));
    }
    SymbolEntry entry;
    entry.kind = SymbolKind::Var;
    entry.var_decl = ctx;
    entry.is_const = false;
    entry.type = ty;
    table_.Add(name, entry);
    sema_.RegisterVarDecl(ctx, ty);
    if (auto* init = ctx->initVal()) {
      init->accept(this);
    }
  }

  // Declares a constant. Only scalar int constants get a folded
  // const_value, which is what later constant expressions can reference.
  void RegisterConst(SysYParser::ConstDefContext* ctx, BaseTypeKind base) {
    if (!ctx || !ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "常量声明缺少名称"));
    }
    const std::string name = ctx->ID()->getText();
    if (table_.ContainsInCurrentScope(name)) {
      throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
    }
    TypeDesc ty;
    ty.base = base;
    ty.is_const = true;
    for (auto* dim : ctx->constExp()) {
      ty.dims.push_back(EvalArrayDim(dim));
    }
    SymbolEntry entry;
    entry.kind = SymbolKind::Const;
    entry.const_decl = ctx;
    entry.is_const = true;
    entry.type = ty;
    if (ctx->constInitVal() && ty.dims.empty() && ty.base == BaseTypeKind::Int) {
      if (auto* exp = ctx->constInitVal()->constExp()) {
        entry.const_value = EvalConstExp(exp);
      }
    }
    table_.Add(name, entry);
    sema_.RegisterConstDecl(ctx, ty);
    if (auto* init = ctx->constInitVal()) {
      init->accept(this);
    }
  }

  // Rejects assignment whose target names a constant (scalar or array
  // element). Expects the lvalue to have been resolved already.
  void RejectConstAssignment(SysYParser::LValContext* lval) {
    const std::string name = lval->ID()->getText();
    const SymbolEntry* entry = table_.Lookup(name);
    if (entry && entry->is_const) {
      throw std::runtime_error(FormatError("sema", "不能给常量赋值: " + name));
    }
  }

  // Folds an array dimension and rejects negative sizes (SysY requires
  // dimensions to evaluate to non-negative integers).
  template <typename ExpCtx>
  int EvalArrayDim(ExpCtx* ctx) {
    const int dim = EvalConstExp(ctx);
    if (dim < 0) {
      throw std::runtime_error(FormatError("sema", "数组维度不能为负数"));
    }
    return dim;
  }

  int EvalConstExp(SysYParser::ConstExpContext* ctx) {
    if (!ctx) {
      throw std::runtime_error(FormatError("sema", "非法常量表达式"));
    }
    ConstEvalVisitor visitor(table_);
    return ConstValueToInt(ctx->accept(&visitor));
  }

  int EvalConstExp(SysYParser::ExpContext* ctx) {
    if (!ctx || !ctx->addExp()) {
      throw std::runtime_error(FormatError("sema", "非法常量表达式"));
    }
    ConstEvalVisitor visitor(table_);
    return ConstValueToInt(ctx->addExp()->accept(&visitor));
  }

  SymbolTable table_;
  SemanticContext sema_;
  // User-defined functions by name, filled before any body is checked.
  std::unordered_map<std::string, SysYParser::FuncDefContext*> func_table_;
  // SysY runtime-library functions callable without a definition.
  const std::unordered_set<std::string> builtin_funcs_ = {
      "getint",   "getch",     "getarray",  "putint",
      "putch",    "putarray",  "getfloat",  "getfarray",
      "putfloat", "putfarray", "starttime", "stoptime"};
  BaseTypeKind current_ret_ = BaseTypeKind::Void;
  bool seen_return_ = false;
  int loop_depth_ = 0;
};

}  // namespace

// Runs semantic analysis over a parsed compilation unit and returns the
// collected SemanticContext. Throws std::runtime_error on the first semantic
// error found.
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
  SemaVisitor visitor;
  comp_unit.accept(&visitor);
  return visitor.TakeSemanticContext();
}