diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 9ac057b..64e2595 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -1,30 +1,69 @@ // 基于语法树的语义检查与名称绑定。 #pragma once +#include #include +#include #include "SysYParser.h" +enum class SemanticType { + Void, + Int, + Float, +}; + +struct ScalarConstant { + SemanticType type = SemanticType::Int; + double number = 0.0; +}; + +struct ObjectBinding { + enum class DeclKind { + Var, + Const, + Param, + }; + + std::string name; + SemanticType type = SemanticType::Int; + DeclKind decl_kind = DeclKind::Var; + bool is_array_param = false; + std::vector dimensions; + const SysYParser::VarDefContext* var_def = nullptr; + const SysYParser::ConstDefContext* const_def = nullptr; + const SysYParser::FuncFParamContext* func_param = nullptr; + bool has_const_value = false; + ScalarConstant const_value; +}; + +struct FunctionBinding { + std::string name; + SemanticType return_type = SemanticType::Int; + std::vector params; + const SysYParser::FuncDefContext* func_def = nullptr; + bool is_builtin = false; +}; + class SemanticContext { public: - void BindVarUse(SysYParser::VarContext* use, - SysYParser::VarDefContext* decl) { - var_uses_[use] = decl; - } + void BindObjectUse(const SysYParser::LValContext* use, ObjectBinding binding); + const ObjectBinding* ResolveObjectUse( + const SysYParser::LValContext* use) const; + + void BindFunctionCall(const SysYParser::UnaryExpContext* call, + FunctionBinding binding); + const FunctionBinding* ResolveFunctionCall( + const SysYParser::UnaryExpContext* call) const; - SysYParser::VarDefContext* ResolveVarUse( - const SysYParser::VarContext* use) const { - auto it = var_uses_.find(use); - return it == var_uses_.end() ? nullptr : it->second; - } + void RegisterFunction(FunctionBinding binding); + const FunctionBinding* ResolveFunction(const std::string& name) const; private: - std::unordered_map - var_uses_; + std::unordered_map object_uses_; + std::unordered_map + function_calls_; + std::unordered_map functions_; }; -// 目前仅检查: -// - 变量先声明后使用 -// - 局部变量不允许重复定义 SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); diff --git a/include/sem/SymbolTable.h b/include/sem/SymbolTable.h index c9396dd..201112c 100644 --- a/include/sem/SymbolTable.h +++ b/include/sem/SymbolTable.h @@ -1,17 +1,25 @@ -// 极简符号表:记录局部变量定义点。 +// 维护对象符号的多层作用域。 #pragma once #include +#include #include +#include -#include "SysYParser.h" +#include "sem/Sema.h" class SymbolTable { public: - void Add(const std::string& name, SysYParser::VarDefContext* decl); - bool Contains(const std::string& name) const; - SysYParser::VarDefContext* Lookup(const std::string& name) const; + SymbolTable(); + + void EnterScope(); + void ExitScope(); + + bool Add(const ObjectBinding& symbol); + bool ContainsInCurrentScope(std::string_view name) const; + const ObjectBinding* Lookup(std::string_view name) const; + size_t Depth() const; private: - std::unordered_map table_; + std::vector> scopes_; }; diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 745374c..f0b49e5 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -1,8 +1,13 @@ #include "sem/Sema.h" #include +#include +#include #include #include +#include +#include +#include #include "SysYBaseVisitor.h" #include "sem/SymbolTable.h" @@ -10,74 +15,258 @@ namespace { -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("sema", "非法左值")); +constexpr int kUnknownArrayDim = -1; + +struct ExprInfo { + SemanticType type = SemanticType::Int; + bool is_lvalue = false; + bool is_const_object = false; + std::vector dimensions; + bool has_const_value = false; + ScalarConstant const_value; + + bool IsScalar() const { return dimensions.empty() && type != SemanticType::Void; } + bool IsArray() const { return !dimensions.empty(); } +}; + +SemanticType ParseBType(SysYParser::BTypeContext& ctx) { + if (ctx.INT()) { + return SemanticType::Int; + } + if (ctx.FLOAT()) { + return SemanticType::Float; + } + throw std::runtime_error(FormatError("sema", "未知基础类型")); +} + +SemanticType ParseFuncType(SysYParser::FuncTypeContext& ctx) { + if (ctx.VOID()) { + return SemanticType::Void; + } + if (ctx.INT()) { + return SemanticType::Int; + } + if (ctx.FLOAT()) { + return SemanticType::Float; + } + throw std::runtime_error(FormatError("sema", "未知函数返回类型")); +} + +int ConvertToInt(const ScalarConstant& value) { + return static_cast(value.number); +} + +double ConvertToFloat(const ScalarConstant& value) { return value.number; } + +bool IsNumericType(SemanticType type) { + return type == SemanticType::Int || type == SemanticType::Float; +} + +bool CanImplicitlyConvert(SemanticType from, SemanticType to) { + if (from == to) { + return true; + } + if (!IsNumericType(from) || !IsNumericType(to)) { + return false; + } + return true; +} + +ScalarConstant CastConstant(const ScalarConstant& value, SemanticType to) { + if (!CanImplicitlyConvert(value.type, to)) { + throw std::runtime_error(FormatError("sema", "非法常量类型转换")); + } + ScalarConstant result; + result.type = to; + result.number = to == SemanticType::Int ? static_cast(ConvertToInt(value)) + : ConvertToFloat(value); + return result; +} + +bool IsTrue(const ScalarConstant& value) { + if (value.type == SemanticType::Float) { + return value.number != 0.0; + } + return ConvertToInt(value) != 0; +} + +ScalarConstant MakeInt(int value) { + return ScalarConstant{SemanticType::Int, static_cast(value)}; +} + +ScalarConstant MakeFloat(double value) { + return ScalarConstant{SemanticType::Float, value}; +} + +const antlr4::Token* StartToken(const antlr4::ParserRuleContext* ctx) { + return ctx ? ctx->getStart() : nullptr; +} + +[[noreturn]] void ThrowSemaError(const antlr4::ParserRuleContext* ctx, + std::string_view msg) { + if (const auto* tok = StartToken(ctx)) { + throw std::runtime_error( + FormatErrorAt("sema", tok->getLine(), tok->getCharPositionInLine(), msg)); + } + throw std::runtime_error(FormatError("sema", msg)); +} + +int ParseIntLiteral(SysYParser::IntConstContext& ctx) { + return std::stoi(ctx.getText(), nullptr, 0); +} + +double ParseFloatLiteral(SysYParser::FloatConstContext& ctx) { + const std::string text = ctx.getText(); + char* end = nullptr; + const double value = std::strtod(text.c_str(), &end); + if (end == nullptr || *end != '\0') { + throw std::runtime_error(FormatError("sema", "非法浮点字面量: " + text)); } - return lvalue.ID()->getText(); + return value; +} + +FunctionBinding MakeBuiltinFunction(std::string name, SemanticType return_type, + std::vector params) { + FunctionBinding fn; + fn.name = std::move(name); + fn.return_type = return_type; + fn.params = std::move(params); + fn.is_builtin = true; + return fn; +} + +ObjectBinding MakeParam(std::string name, SemanticType type, + std::vector dimensions = {}, + bool is_array_param = false) { + ObjectBinding param; + param.name = std::move(name); + param.type = type; + param.decl_kind = ObjectBinding::DeclKind::Param; + param.dimensions = std::move(dimensions); + param.is_array_param = is_array_param; + return param; } class SemaVisitor final : public SysYBaseVisitor { public: + SemaVisitor() { RegisterBuiltins(); } + std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少编译单元")); + ThrowSemaError(ctx, "缺少编译单元"); + } + + CollectFunctions(*ctx); + for (auto* item : ctx->topLevelItem()) { + if (!item) { + continue; + } + item->accept(this); + } + + const FunctionBinding* main = sema_.ResolveFunction("main"); + if (!main || main->is_builtin) { + ThrowSemaError(ctx, "缺少 main 函数定义"); } - auto* func = ctx->funcDef(); - if (!func || !func->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + if (main->return_type != SemanticType::Int || !main->params.empty()) { + ThrowSemaError(main->func_def, "main 函数必须是无参 int main()"); + } + return {}; + } + + std::any visitTopLevelItem(SysYParser::TopLevelItemContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "缺少顶层定义"); } - if (!func->ID() || func->ID()->getText() != "main") { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + if (ctx->decl()) { + ctx->decl()->accept(this); + return {}; + } + if (ctx->funcDef()) { + ctx->funcDef()->accept(this); + return {}; + } + ThrowSemaError(ctx, "暂不支持的顶层定义"); + } + + std::any visitDecl(SysYParser::DeclContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "缺少声明"); + } + if (ctx->constDecl()) { + ctx->constDecl()->accept(this); + return {}; + } + if (ctx->varDecl()) { + ctx->varDecl()->accept(this); + return {}; + } + ThrowSemaError(ctx, "非法声明"); + } + + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override { + if (!ctx || !ctx->bType()) { + ThrowSemaError(ctx, "非法常量声明"); + } + const SemanticType type = ParseBType(*ctx->bType()); + for (auto* def : ctx->constDef()) { + DeclareConst(*def, type); + } + return {}; + } + + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override { + if (!ctx || !ctx->bType()) { + ThrowSemaError(ctx, "非法变量声明"); } - func->accept(this); - if (!seen_return_) { - throw std::runtime_error( - FormatError("sema", "main 函数必须包含 return 语句")); + const SemanticType type = ParseBType(*ctx->bType()); + for (auto* def : ctx->varDef()) { + DeclareVar(*def, type); } return {}; } std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { - if (!ctx || !ctx->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + if (!ctx || !ctx->ID() || !ctx->funcType() || !ctx->block()) { + ThrowSemaError(ctx, "非法函数定义"); } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); + + const FunctionBinding* binding = sema_.ResolveFunction(ctx->ID()->getText()); + if (!binding) { + ThrowSemaError(ctx, "函数未完成预收集: " + ctx->ID()->getText()); } - const auto& items = ctx->blockStmt()->blockItem(); - if (items.empty()) { - throw std::runtime_error( - FormatError("sema", "main 函数不能为空,且必须以 return 结束")); + + const FunctionBinding* prev = current_function_; + current_function_ = binding; + symbols_.EnterScope(); + for (const auto& param : binding->params) { + if (!symbols_.Add(param)) { + ThrowSemaError(ctx, "函数形参重复定义: " + param.name); + } } - ctx->blockStmt()->accept(this); + ctx->block()->accept(this); + symbols_.ExitScope(); + current_function_ = prev; return {}; } - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { + std::any visitBlock(SysYParser::BlockContext* ctx) override { if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少语句块")); + ThrowSemaError(ctx, "缺少语句块"); } - const auto& items = ctx->blockItem(); - for (size_t i = 0; i < items.size(); ++i) { - auto* item = items[i]; - if (!item) { - continue; - } - if (seen_return_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); + symbols_.EnterScope(); + for (auto* item : ctx->blockItem()) { + if (item) { + item->accept(this); } - current_item_index_ = i; - total_items_ = items.size(); - item->accept(this); } + symbols_.ExitScope(); return {}; } std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { if (!ctx) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + ThrowSemaError(ctx, "缺少块内语句"); } if (ctx->decl()) { ctx->decl()->accept(this); @@ -87,112 +276,766 @@ class SemaVisitor final : public SysYBaseVisitor { ctx->stmt()->accept(this); return {}; } - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + ThrowSemaError(ctx, "非法块内语句"); } - std::any visitDecl(SysYParser::DeclContext* ctx) override { + std::any visitStmt(SysYParser::StmtContext* ctx) override { if (!ctx) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); + ThrowSemaError(ctx, "缺少语句"); } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); + + if (ctx->BREAK()) { + if (loop_depth_ == 0) { + ThrowSemaError(ctx, "break 只能出现在循环内部"); + } + return {}; + } + if (ctx->CONTINUE()) { + if (loop_depth_ == 0) { + ThrowSemaError(ctx, "continue 只能出现在循环内部"); + } + return {}; } - auto* var_def = ctx->varDef(); - if (!var_def || !var_def->lValue()) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); + if (ctx->RETURN()) { + CheckReturn(*ctx); + return {}; } - const std::string name = GetLValueName(*var_def->lValue()); - if (table_.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + if (ctx->WHILE()) { + RequireScalar(ctx->cond(), EvalCond(*ctx->cond()), "while 条件必须是标量表达式"); + ++loop_depth_; + ctx->stmt(0)->accept(this); + --loop_depth_; + return {}; } - if (auto* init = var_def->initValue()) { - if (!init->exp()) { - throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化")); + if (ctx->IF()) { + RequireScalar(ctx->cond(), EvalCond(*ctx->cond()), "if 条件必须是标量表达式"); + ctx->stmt(0)->accept(this); + if (ctx->stmt().size() > 1 && ctx->stmt(1)) { + ctx->stmt(1)->accept(this); } - init->exp()->accept(this); + return {}; + } + if (ctx->block()) { + ctx->block()->accept(this); + return {}; + } + if (ctx->lVal() && ctx->ASSIGN()) { + CheckAssignment(*ctx); + return {}; + } + if (ctx->exp()) { + EvalExpr(*ctx->exp()); + return {}; } - table_.Add(name, var_def); return {}; } - std::any visitStmt(SysYParser::StmtContext* ctx) override { - if (!ctx || !ctx->returnStmt()) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + std::any visitExp(SysYParser::ExpContext* ctx) override { + if (!ctx || !ctx->addExp()) { + ThrowSemaError(ctx, "非法表达式"); } - ctx->returnStmt()->accept(this); - return {}; + return EvalExpr(*ctx->addExp()); } - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "return 缺少表达式")); + std::any visitCond(SysYParser::CondContext* ctx) override { + if (!ctx || !ctx->lOrExp()) { + ThrowSemaError(ctx, "非法条件表达式"); } - ctx->exp()->accept(this); - seen_return_ = true; - if (current_item_index_ + 1 != total_items_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); + return EvalExpr(*ctx->lOrExp()); + } + + std::any visitLVal(SysYParser::LValContext* ctx) override { + if (!ctx || !ctx->ID()) { + ThrowSemaError(ctx, "非法左值"); } - return {}; + return AnalyzeLVal(*ctx); } - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "非法括号表达式")); + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法基础表达式"); } - ctx->exp()->accept(this); - return {}; + if (ctx->exp()) { + return EvalExpr(*ctx->exp()); + } + if (ctx->lVal()) { + return AnalyzeLVal(*ctx->lVal()); + } + if (ctx->number()) { + return EvalExpr(*ctx->number()); + } + ThrowSemaError(ctx, "非法基础表达式"); } - std::any visitVarExp(SysYParser::VarExpContext* ctx) override { - if (!ctx || !ctx->var()) { - throw std::runtime_error(FormatError("sema", "非法变量表达式")); + std::any visitNumber(SysYParser::NumberContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法数字字面量"); } - ctx->var()->accept(this); - return {}; + if (ctx->intConst()) { + return EvalExpr(*ctx->intConst()); + } + if (ctx->floatConst()) { + return EvalExpr(*ctx->floatConst()); + } + ThrowSemaError(ctx, "非法数字字面量"); } - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); + std::any visitIntConst(SysYParser::IntConstContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法整数字面量"); } - return {}; + ExprInfo expr; + expr.type = SemanticType::Int; + expr.has_const_value = true; + expr.const_value = MakeInt(ParseIntLiteral(*ctx)); + return expr; } - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); + std::any visitFloatConst(SysYParser::FloatConstContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法浮点字面量"); } - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; + ExprInfo expr; + expr.type = SemanticType::Float; + expr.has_const_value = true; + expr.const_value = MakeFloat(ParseFloatLiteral(*ctx)); + return expr; } - std::any visitVar(SysYParser::VarContext* ctx) override { - if (!ctx || !ctx->ID()) { - throw std::runtime_error(FormatError("sema", "非法变量引用")); + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法一元表达式"); } - const std::string name = ctx->ID()->getText(); - auto* decl = table_.Lookup(name); - if (!decl) { - throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); + if (ctx->primaryExp()) { + return EvalExpr(*ctx->primaryExp()); } - sema_.BindVarUse(ctx, decl); - return {}; + if (ctx->ID()) { + return AnalyzeCall(*ctx); + } + if (ctx->addUnaryOp() && ctx->unaryExp()) { + ExprInfo operand = EvalExpr(*ctx->unaryExp()); + RequireScalar(ctx->unaryExp(), operand, "一元运算要求标量操作数"); + ExprInfo result; + result.type = operand.type; + if (ctx->addUnaryOp()->SUB() && operand.has_const_value) { + result.has_const_value = true; + result.const_value = operand.const_value; + result.const_value.number = -result.const_value.number; + } else if (ctx->addUnaryOp()->ADD() && operand.has_const_value) { + result.has_const_value = true; + result.const_value = operand.const_value; + } + return result; + } + ThrowSemaError(ctx, "非法一元表达式"); + } + + std::any visitMulExp(SysYParser::MulExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法乘法表达式"); + } + if (ctx->unaryExp()) { + return EvalExpr(*ctx->unaryExp()); + } + ExprInfo lhs = EvalExpr(*ctx->mulExp()); + ExprInfo rhs = EvalExpr(*ctx->unaryExp()); + return EvalArithmetic(*ctx, lhs, rhs, ctx->MUL() ? '*' : (ctx->DIV() ? '/' : '%')); + } + + std::any visitAddExp(SysYParser::AddExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法加法表达式"); + } + if (ctx->mulExp()) { + return EvalExpr(*ctx->mulExp()); + } + ExprInfo lhs = EvalExpr(*ctx->addExp()); + ExprInfo rhs = EvalExpr(*ctx->mulExp()); + return EvalArithmetic(*ctx, lhs, rhs, ctx->ADD() ? '+' : '-'); + } + + std::any visitRelExp(SysYParser::RelExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法关系表达式"); + } + if (ctx->addExp()) { + return EvalExpr(*ctx->addExp()); + } + ExprInfo lhs = EvalExpr(*ctx->relExp()); + ExprInfo rhs = EvalExpr(*ctx->addExp()); + return EvalCompare(*ctx, lhs, rhs); + } + + std::any visitEqExp(SysYParser::EqExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法相等表达式"); + } + if (ctx->relExp()) { + return EvalExpr(*ctx->relExp()); + } + ExprInfo lhs = EvalExpr(*ctx->eqExp()); + ExprInfo rhs = EvalExpr(*ctx->relExp()); + return EvalCompare(*ctx, lhs, rhs); + } + + std::any visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法条件一元表达式"); + } + if (ctx->eqExp()) { + return EvalExpr(*ctx->eqExp()); + } + ExprInfo operand = EvalExpr(*ctx->condUnaryExp()); + RequireScalar(ctx->condUnaryExp(), operand, "逻辑非要求标量操作数"); + ExprInfo result; + result.type = SemanticType::Int; + if (operand.has_const_value) { + result.has_const_value = true; + result.const_value = MakeInt(IsTrue(operand.const_value) ? 0 : 1); + } + return result; + } + + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法逻辑与表达式"); + } + if (ctx->condUnaryExp()) { + return EvalExpr(*ctx->condUnaryExp()); + } + ExprInfo lhs = EvalExpr(*ctx->lAndExp()); + ExprInfo rhs = EvalExpr(*ctx->condUnaryExp()); + return EvalLogical(*ctx, lhs, rhs, true); + } + + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法逻辑或表达式"); + } + if (ctx->lAndExp()) { + return EvalExpr(*ctx->lAndExp()); + } + ExprInfo lhs = EvalExpr(*ctx->lOrExp()); + ExprInfo rhs = EvalExpr(*ctx->lAndExp()); + return EvalLogical(*ctx, lhs, rhs, false); + } + + std::any visitConstExp(SysYParser::ConstExpContext* ctx) override { + if (!ctx || !ctx->addExp()) { + ThrowSemaError(ctx, "非法常量表达式"); + } + ExprInfo expr = EvalExpr(*ctx->addExp()); + if (!expr.IsScalar() || !expr.has_const_value) { + ThrowSemaError(ctx, "要求编译期常量表达式"); + } + return expr; } SemanticContext TakeSemanticContext() { return std::move(sema_); } private: - SymbolTable table_; + ExprInfo EvalExpr(antlr4::tree::ParseTree& node) { + return std::any_cast(node.accept(this)); + } + + ExprInfo EvalCond(SysYParser::CondContext& cond) { return EvalExpr(cond); } + + ExprInfo AnalyzeLVal(SysYParser::LValContext& ctx) { + const std::string name = ctx.ID()->getText(); + const ObjectBinding* symbol = symbols_.Lookup(name); + if (!symbol) { + ThrowSemaError(&ctx, "使用了未声明的标识符: " + name); + } + + sema_.BindObjectUse(&ctx, *symbol); + + if (ctx.exp().size() > symbol->dimensions.size()) { + ThrowSemaError(&ctx, "数组下标过多: " + name); + } + + for (auto* exp : ctx.exp()) { + ExprInfo index = EvalExpr(*exp); + RequireScalar(exp, index, "数组下标必须是标量表达式"); + } + + ExprInfo result; + result.type = symbol->type; + result.is_const_object = symbol->decl_kind == ObjectBinding::DeclKind::Const; + result.is_lvalue = ctx.exp().size() == symbol->dimensions.size(); + result.dimensions.assign(symbol->dimensions.begin() + ctx.exp().size(), + symbol->dimensions.end()); + if (result.dimensions.empty() && symbol->has_const_value) { + result.has_const_value = true; + result.const_value = symbol->const_value; + } + return result; + } + + ExprInfo AnalyzeCall(SysYParser::UnaryExpContext& ctx) { + const std::string name = ctx.ID()->getText(); + if (const ObjectBinding* object = symbols_.Lookup(name)) { + ThrowSemaError(&ctx, "标识符不是函数: " + object->name); + } + + const FunctionBinding* fn = sema_.ResolveFunction(name); + if (!fn) { + ThrowSemaError(&ctx, "调用了未定义的函数: " + name); + } + + std::vector args; + if (ctx.funcRParams()) { + for (auto* exp : ctx.funcRParams()->exp()) { + args.push_back(EvalExpr(*exp)); + } + } + if (args.size() != fn->params.size()) { + ThrowSemaError(&ctx, "函数参数个数不匹配: " + name); + } + for (size_t i = 0; i < args.size(); ++i) { + CheckArgument(ctx, fn->params[i], args[i], i); + } + + sema_.BindFunctionCall(&ctx, *fn); + + ExprInfo result; + result.type = fn->return_type; + return result; + } + + void CheckArgument(const antlr4::ParserRuleContext& call_site, + const ObjectBinding& param, const ExprInfo& arg, + size_t index) { + if (param.dimensions.empty()) { + if (!arg.IsScalar()) { + ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) + + " 个参数需要标量实参"); + } + if (!CanImplicitlyConvert(arg.type, param.type)) { + ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) + + " 个参数类型不匹配"); + } + return; + } + + if (!arg.IsArray()) { + ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) + + " 个参数需要数组实参"); + } + if (arg.type != param.type || arg.dimensions.size() != param.dimensions.size()) { + ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) + + " 个数组参数类型不匹配"); + } + for (size_t dim = 1; dim < param.dimensions.size(); ++dim) { + if (param.dimensions[dim] != kUnknownArrayDim && + arg.dimensions[dim] != param.dimensions[dim]) { + ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) + + " 个数组参数维度不匹配"); + } + } + } + + void CheckAssignment(SysYParser::StmtContext& ctx) { + ExprInfo lhs = AnalyzeLVal(*ctx.lVal()); + if (!lhs.IsScalar() || !lhs.is_lvalue) { + ThrowSemaError(&ctx, "赋值语句左侧必须是可写标量左值"); + } + if (lhs.is_const_object) { + ThrowSemaError(&ctx, "不能给 const 对象赋值"); + } + ExprInfo rhs = EvalExpr(*ctx.exp()); + RequireScalar(ctx.exp(), rhs, "赋值语句右侧必须是标量表达式"); + if (!CanImplicitlyConvert(rhs.type, lhs.type)) { + ThrowSemaError(&ctx, "赋值语句两侧类型不兼容"); + } + } + + void CheckReturn(SysYParser::StmtContext& ctx) { + if (!current_function_) { + ThrowSemaError(&ctx, "return 语句不在函数内部"); + } + if (current_function_->return_type == SemanticType::Void) { + if (ctx.exp()) { + ThrowSemaError(&ctx, "void 函数不能返回值"); + } + return; + } + if (!ctx.exp()) { + ThrowSemaError(&ctx, "非 void 函数必须返回值"); + } + ExprInfo expr = EvalExpr(*ctx.exp()); + RequireScalar(ctx.exp(), expr, "return 表达式必须是标量"); + if (!CanImplicitlyConvert(expr.type, current_function_->return_type)) { + ThrowSemaError(&ctx, "return 表达式类型与函数返回类型不匹配"); + } + } + + void DeclareConst(SysYParser::ConstDefContext& ctx, SemanticType type) { + ObjectBinding symbol; + symbol.name = ctx.ID()->getText(); + symbol.type = type; + symbol.decl_kind = ObjectBinding::DeclKind::Const; + symbol.const_def = &ctx; + symbol.dimensions = EvalArrayDims(ctx.constIndex(), true); + + if (symbols_.ContainsInCurrentScope(symbol.name)) { + ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name); + } + if (symbols_.Depth() == 1 && sema_.ResolveFunction(symbol.name)) { + ThrowSemaError(&ctx, "全局对象与函数重名: " + symbol.name); + } + + if (!ctx.constInitVal()) { + ThrowSemaError(&ctx, "const 对象缺少初始化"); + } + if (symbol.dimensions.empty()) { + symbol.const_value = ValidateConstInitScalar(*ctx.constInitVal(), type); + symbol.has_const_value = true; + } else { + ValidateConstInitAggregate(*ctx.constInitVal(), type); + } + + if (!symbols_.Add(symbol)) { + ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name); + } + } + + void DeclareVar(SysYParser::VarDefContext& ctx, SemanticType type) { + ObjectBinding symbol; + symbol.name = ctx.ID()->getText(); + symbol.type = type; + symbol.decl_kind = ObjectBinding::DeclKind::Var; + symbol.var_def = &ctx; + symbol.dimensions = EvalArrayDims(ctx.constIndex(), true); + + if (symbols_.ContainsInCurrentScope(symbol.name)) { + ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name); + } + if (symbols_.Depth() == 1 && sema_.ResolveFunction(symbol.name)) { + ThrowSemaError(&ctx, "全局对象与函数重名: " + symbol.name); + } + + if (!symbols_.Add(symbol)) { + ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name); + } + + if (!ctx.initVal()) { + return; + } + if (symbol.dimensions.empty()) { + ValidateVarInitScalar(*ctx.initVal(), type, symbols_.Depth() == 1); + } else { + ValidateVarInitAggregate(*ctx.initVal(), type, symbols_.Depth() == 1); + } + } + + std::vector EvalArrayDims( + const std::vector& indices, + bool require_positive) { + std::vector dims; + dims.reserve(indices.size()); + for (auto* index : indices) { + if (!index || !index->constExp()) { + ThrowSemaError(index, "数组维度缺少常量表达式"); + } + ExprInfo expr = EvalExpr(*index->constExp()); + if (!expr.IsScalar() || !expr.has_const_value) { + ThrowSemaError(index, "数组维度必须是整型常量表达式"); + } + const int dim = ConvertToInt(CastConstant(expr.const_value, SemanticType::Int)); + if (require_positive && dim <= 0) { + ThrowSemaError(index, "数组维度必须为正整数"); + } + dims.push_back(dim); + } + return dims; + } + + ScalarConstant ValidateConstInitScalar(SysYParser::ConstInitValContext& init, + SemanticType target_type) { + if (!init.constExp()) { + ThrowSemaError(&init, "标量 const 初始化必须是常量表达式"); + } + ExprInfo expr = EvalExpr(*init.constExp()); + if (!expr.IsScalar() || !expr.has_const_value) { + ThrowSemaError(&init, "标量 const 初始化必须是常量表达式"); + } + return CastConstant(expr.const_value, target_type); + } + + void ValidateConstInitAggregate(SysYParser::ConstInitValContext& init, + SemanticType target_type) { + if (init.constExp()) { + ExprInfo expr = EvalExpr(*init.constExp()); + if (!expr.IsScalar() || !expr.has_const_value) { + ThrowSemaError(&init, "数组 const 初始化要求常量表达式"); + } + CastConstant(expr.const_value, target_type); + return; + } + for (auto* nested : init.constInitVal()) { + if (nested) { + ValidateConstInitAggregate(*nested, target_type); + } + } + } + + void ValidateVarInitScalar(SysYParser::InitValContext& init, + SemanticType target_type, bool require_constant) { + if (!init.exp()) { + ThrowSemaError(&init, "标量初始化非法"); + } + ExprInfo expr = EvalExpr(*init.exp()); + RequireScalar(&init, expr, "标量初始化要求标量表达式"); + if (!CanImplicitlyConvert(expr.type, target_type)) { + ThrowSemaError(&init, "初始化表达式类型不兼容"); + } + if (require_constant && !expr.has_const_value) { + ThrowSemaError(&init, "全局变量初始化要求编译期常量"); + } + } + + void ValidateVarInitAggregate(SysYParser::InitValContext& init, + SemanticType target_type, bool require_constant) { + if (init.exp()) { + ExprInfo expr = EvalExpr(*init.exp()); + RequireScalar(&init, expr, "数组初始化元素必须是标量表达式"); + if (!CanImplicitlyConvert(expr.type, target_type)) { + ThrowSemaError(&init, "数组初始化元素类型不兼容"); + } + if (require_constant && !expr.has_const_value) { + ThrowSemaError(&init, "全局数组初始化要求编译期常量"); + } + return; + } + for (auto* nested : init.initVal()) { + if (nested) { + ValidateVarInitAggregate(*nested, target_type, require_constant); + } + } + } + + ExprInfo EvalArithmetic(const antlr4::ParserRuleContext& ctx, const ExprInfo& lhs, + const ExprInfo& rhs, char op) { + RequireScalar(&ctx, lhs, "算术运算要求标量操作数"); + RequireScalar(&ctx, rhs, "算术运算要求标量操作数"); + ExprInfo result; + result.type = lhs.type == SemanticType::Float || rhs.type == SemanticType::Float + ? SemanticType::Float + : SemanticType::Int; + if (!lhs.has_const_value || !rhs.has_const_value) { + return result; + } + + result.has_const_value = true; + const ScalarConstant lc = CastConstant(lhs.const_value, result.type); + const ScalarConstant rc = CastConstant(rhs.const_value, result.type); + if (result.type == SemanticType::Float) { + double value = 0.0; + if (op == '+') value = lc.number + rc.number; + if (op == '-') value = lc.number - rc.number; + if (op == '*') value = lc.number * rc.number; + if (op == '/') value = lc.number / rc.number; + if (op == '%') { + ThrowSemaError(&ctx, "浮点数不支持取模运算"); + } + result.const_value = MakeFloat(value); + return result; + } + + const int li = ConvertToInt(lc); + const int ri = ConvertToInt(rc); + int value = 0; + if (op == '+') value = li + ri; + if (op == '-') value = li - ri; + if (op == '*') value = li * ri; + if (op == '/') value = li / ri; + if (op == '%') value = li % ri; + result.const_value = MakeInt(value); + return result; + } + + ExprInfo EvalCompare(antlr4::ParserRuleContext& ctx, const ExprInfo& lhs, + const ExprInfo& rhs) { + RequireScalar(&ctx, lhs, "比较运算要求标量操作数"); + RequireScalar(&ctx, rhs, "比较运算要求标量操作数"); + ExprInfo result; + result.type = SemanticType::Int; + if (!lhs.has_const_value || !rhs.has_const_value) { + return result; + } + + const SemanticType promoted = + lhs.type == SemanticType::Float || rhs.type == SemanticType::Float + ? SemanticType::Float + : SemanticType::Int; + const ScalarConstant lc = CastConstant(lhs.const_value, promoted); + const ScalarConstant rc = CastConstant(rhs.const_value, promoted); + bool value = false; + if (auto* rel = dynamic_cast(&ctx)) { + if (rel->LT()) value = lc.number < rc.number; + if (rel->GT()) value = lc.number > rc.number; + if (rel->LE()) value = lc.number <= rc.number; + if (rel->GE()) value = lc.number >= rc.number; + } else if (auto* eq = dynamic_cast(&ctx)) { + if (eq->EQ()) value = lc.number == rc.number; + if (eq->NE()) value = lc.number != rc.number; + } + result.has_const_value = true; + result.const_value = MakeInt(value ? 1 : 0); + return result; + } + + ExprInfo EvalLogical(const antlr4::ParserRuleContext& ctx, const ExprInfo& lhs, + const ExprInfo& rhs, bool is_and) { + RequireScalar(&ctx, lhs, "逻辑运算要求标量操作数"); + RequireScalar(&ctx, rhs, "逻辑运算要求标量操作数"); + ExprInfo result; + result.type = SemanticType::Int; + if (!lhs.has_const_value || !rhs.has_const_value) { + return result; + } + const bool value = + is_and ? (IsTrue(lhs.const_value) && IsTrue(rhs.const_value)) + : (IsTrue(lhs.const_value) || IsTrue(rhs.const_value)); + result.has_const_value = true; + result.const_value = MakeInt(value ? 1 : 0); + return result; + } + + void RequireScalar(const antlr4::ParserRuleContext* ctx, const ExprInfo& expr, + std::string_view message) { + if (!expr.IsScalar()) { + ThrowSemaError(ctx, message); + } + } + + void CollectFunctions(SysYParser::CompUnitContext& ctx) { + for (auto* item : ctx.topLevelItem()) { + if (!item || !item->funcDef()) { + continue; + } + FunctionBinding fn = BuildFunctionSignature(*item->funcDef()); + if (sema_.ResolveFunction(fn.name)) { + ThrowSemaError(item->funcDef(), "重复定义函数: " + fn.name); + } + if (symbols_.ContainsInCurrentScope(fn.name)) { + ThrowSemaError(item->funcDef(), "函数与全局对象重名: " + fn.name); + } + sema_.RegisterFunction(std::move(fn)); + } + } + + FunctionBinding BuildFunctionSignature(SysYParser::FuncDefContext& ctx) { + FunctionBinding fn; + fn.name = ctx.ID()->getText(); + fn.return_type = ParseFuncType(*ctx.funcType()); + fn.func_def = &ctx; + if (ctx.funcFParams()) { + for (auto* param : ctx.funcFParams()->funcFParam()) { + fn.params.push_back(BuildParamBinding(*param)); + } + } + return fn; + } + + ObjectBinding BuildParamBinding(SysYParser::FuncFParamContext& ctx) { + if (!ctx.ID() || !ctx.bType()) { + ThrowSemaError(&ctx, "非法函数形参"); + } + ObjectBinding param; + param.name = ctx.ID()->getText(); + param.type = ParseBType(*ctx.bType()); + param.decl_kind = ObjectBinding::DeclKind::Param; + param.func_param = &ctx; + if (!ctx.LBRACK().empty()) { + param.is_array_param = true; + param.dimensions.push_back(kUnknownArrayDim); + for (auto* exp : ctx.exp()) { + ExprInfo dim = EvalExpr(*exp); + if (!dim.IsScalar() || !dim.has_const_value) { + ThrowSemaError(&ctx, "数组形参维度必须是整型常量表达式"); + } + const int value = ConvertToInt(CastConstant(dim.const_value, SemanticType::Int)); + if (value <= 0) { + ThrowSemaError(&ctx, "数组形参维度必须为正整数"); + } + param.dimensions.push_back(value); + } + } + return param; + } + + void RegisterBuiltins() { + sema_.RegisterFunction(MakeBuiltinFunction("getint", SemanticType::Int, {})); + sema_.RegisterFunction(MakeBuiltinFunction("getch", SemanticType::Int, {})); + sema_.RegisterFunction( + MakeBuiltinFunction("getfloat", SemanticType::Float, {})); + sema_.RegisterFunction(MakeBuiltinFunction( + "getarray", SemanticType::Int, + {MakeParam("a", SemanticType::Int, {kUnknownArrayDim}, true)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "getfarray", SemanticType::Int, + {MakeParam("a", SemanticType::Float, {kUnknownArrayDim}, true)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "putint", SemanticType::Void, {MakeParam("x", SemanticType::Int)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "putch", SemanticType::Void, {MakeParam("x", SemanticType::Int)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "putfloat", SemanticType::Void, {MakeParam("x", SemanticType::Float)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "putarray", SemanticType::Void, + {MakeParam("n", SemanticType::Int), + MakeParam("a", SemanticType::Int, {kUnknownArrayDim}, true)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "putfarray", SemanticType::Void, + {MakeParam("n", SemanticType::Int), + MakeParam("a", SemanticType::Float, {kUnknownArrayDim}, true)})); + sema_.RegisterFunction( + MakeBuiltinFunction("starttime", SemanticType::Void, {})); + sema_.RegisterFunction( + MakeBuiltinFunction("stoptime", SemanticType::Void, {})); + } + + SymbolTable symbols_; SemanticContext sema_; - bool seen_return_ = false; - size_t current_item_index_ = 0; - size_t total_items_ = 0; + const FunctionBinding* current_function_ = nullptr; + int loop_depth_ = 0; }; } // namespace +void SemanticContext::BindObjectUse(const SysYParser::LValContext* use, + ObjectBinding binding) { + object_uses_[use] = std::move(binding); +} + +const ObjectBinding* SemanticContext::ResolveObjectUse( + const SysYParser::LValContext* use) const { + auto it = object_uses_.find(use); + return it == object_uses_.end() ? nullptr : &it->second; +} + +void SemanticContext::BindFunctionCall(const SysYParser::UnaryExpContext* call, + FunctionBinding binding) { + function_calls_[call] = std::move(binding); +} + +const FunctionBinding* SemanticContext::ResolveFunctionCall( + const SysYParser::UnaryExpContext* call) const { + auto it = function_calls_.find(call); + return it == function_calls_.end() ? nullptr : &it->second; +} + +void SemanticContext::RegisterFunction(FunctionBinding binding) { + functions_[binding.name] = std::move(binding); +} + +const FunctionBinding* SemanticContext::ResolveFunction( + const std::string& name) const { + auto it = functions_.find(name); + return it == functions_.end() ? nullptr : &it->second; +} + SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { SemaVisitor visitor; comp_unit.accept(&visitor); diff --git a/src/sem/SymbolTable.cpp b/src/sem/SymbolTable.cpp index ffeea89..01b44bf 100644 --- a/src/sem/SymbolTable.cpp +++ b/src/sem/SymbolTable.cpp @@ -1,17 +1,39 @@ -// 维护局部变量声明的注册与查找。 +// 维护对象符号的注册与按作用域查找。 #include "sem/SymbolTable.h" -void SymbolTable::Add(const std::string& name, - SysYParser::VarDefContext* decl) { - table_[name] = decl; +#include + +SymbolTable::SymbolTable() : scopes_(1) {} + +void SymbolTable::EnterScope() { scopes_.emplace_back(); } + +void SymbolTable::ExitScope() { + if (scopes_.size() <= 1) { + throw std::runtime_error("symbol table scope underflow"); + } + scopes_.pop_back(); } -bool SymbolTable::Contains(const std::string& name) const { - return table_.find(name) != table_.end(); +bool SymbolTable::Add(const ObjectBinding& symbol) { + auto& scope = scopes_.back(); + return scope.emplace(symbol.name, symbol).second; } -SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { - auto it = table_.find(name); - return it == table_.end() ? nullptr : it->second; +bool SymbolTable::ContainsInCurrentScope(std::string_view name) const { + const auto& scope = scopes_.back(); + return scope.find(std::string(name)) != scope.end(); } + +const ObjectBinding* SymbolTable::Lookup(std::string_view name) const { + const std::string key(name); + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + auto found = it->find(key); + if (found != it->end()) { + return &found->second; + } + } + return nullptr; +} + +size_t SymbolTable::Depth() const { return scopes_.size(); }