#include "sem/Sema.h"

#include <any>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"

namespace {

// Placeholder extent used for the omitted first dimension of an array
// parameter (`int a[]`).
constexpr int kUnknownArrayDim = -1;

// Result of analysing one expression node.
struct ExprInfo {
  SemanticType type = SemanticType::Int;
  bool is_lvalue = false;        // names a fully indexed object
  bool is_const_object = false;  // the named object was declared const
  std::vector<int> dimensions;   // remaining array extents; empty for scalars
  bool has_const_value = false;  // const_value holds a folded constant
  ScalarConstant const_value;

  bool IsScalar() const {
    return dimensions.empty() && type != SemanticType::Void;
  }
  bool IsArray() const { return !dimensions.empty(); }
};

// Maps a `bType` node to its semantic type; throws on anything else.
SemanticType ParseBType(SysYParser::BTypeContext& ctx) {
  if (ctx.INT()) {
    return SemanticType::Int;
  }
  if (ctx.FLOAT()) {
    return SemanticType::Float;
  }
  throw std::runtime_error(FormatError("sema", "未知基础类型"));
}

// Maps a `funcType` node to its semantic type; throws on anything else.
SemanticType ParseFuncType(SysYParser::FuncTypeContext& ctx) {
  if (ctx.VOID()) {
    return SemanticType::Void;
  }
  if (ctx.INT()) {
    return SemanticType::Int;
  }
  if (ctx.FLOAT()) {
    return SemanticType::Float;
  }
  throw std::runtime_error(FormatError("sema", "未知函数返回类型"));
}

// Reduces a 64-bit intermediate to a 32-bit int with modular wrap-around.
// The unsigned->signed conversion is implementation-defined before C++20 but
// is modular on every mainstream two's-complement toolchain.
int WrapToInt(std::int64_t wide) {
  return static_cast<int>(static_cast<std::uint32_t>(wide));
}

int ConvertToInt(const ScalarConstant& value) {
  return static_cast<int>(value.number);
}

double ConvertToFloat(const ScalarConstant& value) { return value.number; }

bool IsNumericType(SemanticType type) {
  return type == SemanticType::Int || type == SemanticType::Float;
}

// Identical types always convert; otherwise both sides must be numeric.
bool CanImplicitlyConvert(SemanticType from, SemanticType to) {
  if (from == to) {
    return true;
  }
  return IsNumericType(from) && IsNumericType(to);
}

// Converts a folded constant to `to`; throws when the conversion is not
// allowed by CanImplicitlyConvert.
ScalarConstant CastConstant(const ScalarConstant& value, SemanticType to) {
  if (!CanImplicitlyConvert(value.type, to)) {
    throw std::runtime_error(FormatError("sema", "非法常量类型转换"));
  }
  ScalarConstant result;
  result.type = to;
  result.number = to == SemanticType::Int
                      ? static_cast<double>(ConvertToInt(value))
                      : ConvertToFloat(value);
  return result;
}

bool IsTrue(const ScalarConstant& value) {
  if (value.type == SemanticType::Float) {
    return value.number != 0.0;
  }
  return ConvertToInt(value) != 0;
}

ScalarConstant MakeInt(int value) {
  return ScalarConstant{SemanticType::Int, static_cast<double>(value)};
}

ScalarConstant MakeFloat(double value) {
  return ScalarConstant{SemanticType::Float, value};
}

const antlr4::Token* StartToken(const antlr4::ParserRuleContext* ctx) {
  return ctx ? ctx->getStart() : nullptr;
}

// Throws std::runtime_error carrying `msg`, with line/column information when
// `ctx` has a start token.
[[noreturn]] void ThrowSemaError(const antlr4::ParserRuleContext* ctx,
                                 std::string_view msg) {
  if (const auto* tok = StartToken(ctx)) {
    throw std::runtime_error(FormatErrorAt("sema", tok->getLine(),
                                           tok->getCharPositionInLine(), msg));
  }
  throw std::runtime_error(FormatError("sema", msg));
}

// Parses a decimal/octal/hex integer literal (base auto-detected).
// The literal is parsed wide and wrapped to 32 bits so that `2147483648`
// (the operand of `-2147483648`) is accepted instead of escaping as an
// uncaught std::out_of_range from std::stoi. Unparsable text or a value that
// does not fit in 32 bits is reported as a located sema error.
int ParseIntLiteral(SysYParser::IntConstContext& ctx) {
  const std::string text = ctx.getText();
  unsigned long long wide = 0;
  try {
    wide = std::stoull(text, nullptr, 0);
  } catch (const std::logic_error&) {
    // std::invalid_argument and std::out_of_range both derive from
    // std::logic_error.
    ThrowSemaError(&ctx, "非法整数字面量: " + text);
  }
  if (wide > 0xFFFFFFFFull) {
    ThrowSemaError(&ctx, "整数字面量超出范围: " + text);
  }
  return WrapToInt(static_cast<std::int64_t>(wide));
}

// Parses a floating literal; the whole text must be consumed by strtod.
double ParseFloatLiteral(SysYParser::FloatConstContext& ctx) {
  const std::string text = ctx.getText();
  char* end = nullptr;
  const double value = std::strtod(text.c_str(), &end);
  if (end == nullptr || *end != '\0') {
    throw std::runtime_error(FormatError("sema", "非法浮点字面量: " + text));
  }
  return value;
}

FunctionBinding MakeBuiltinFunction(std::string name,
                                    SemanticType return_type,
                                    std::vector<ObjectBinding> params) {
  FunctionBinding fn;
  fn.name = std::move(name);
  fn.return_type = return_type;
  fn.params = std::move(params);
  fn.is_builtin = true;
  return fn;
}

ObjectBinding MakeParam(std::string name, SemanticType type,
                        std::vector<int> dimensions = {},
                        bool is_array_param = false) {
  ObjectBinding param;
  param.name = std::move(name);
  param.type = type;
  param.decl_kind = ObjectBinding::DeclKind::Param;
  param.dimensions = std::move(dimensions);
  param.is_array_param = is_array_param;
  return param;
}

// Parse-tree visitor that performs semantic checks and records name bindings
// into a SemanticContext. Expression visitors return an ExprInfo wrapped in
// std::any; statement/declaration visitors return an empty std::any. All
// violations are reported by throwing std::runtime_error.
class SemaVisitor final : public SysYBaseVisitor {
 public:
  SemaVisitor() { RegisterBuiltins(); }

  // Entry point: pre-collects function signatures, visits every top-level
  // item in order, then validates that a user-defined `int main()` exists.
  std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "缺少编译单元");
    }
    CollectFunctions(*ctx);
    for (auto* item : ctx->topLevelItem()) {
      if (!item) {
        continue;
      }
      item->accept(this);
    }
    const FunctionBinding* main = sema_.ResolveFunction("main");
    if (!main || main->is_builtin) {
      ThrowSemaError(ctx, "缺少 main 函数定义");
    }
    if (main->return_type != SemanticType::Int || !main->params.empty()) {
      ThrowSemaError(main->func_def, "main 函数必须是无参 int main()");
    }
    return {};
  }

  std::any visitTopLevelItem(SysYParser::TopLevelItemContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "缺少顶层定义");
    }
    if (ctx->decl()) {
      ctx->decl()->accept(this);
      return {};
    }
    if (ctx->funcDef()) {
      ctx->funcDef()->accept(this);
      return {};
    }
    ThrowSemaError(ctx, "暂不支持的顶层定义");
  }

  std::any visitDecl(SysYParser::DeclContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "缺少声明");
    }
    if (ctx->constDecl()) {
      ctx->constDecl()->accept(this);
      return {};
    }
    if (ctx->varDecl()) {
      ctx->varDecl()->accept(this);
      return {};
    }
    ThrowSemaError(ctx, "非法声明");
  }

  std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
    if (!ctx || !ctx->bType()) {
      ThrowSemaError(ctx, "非法常量声明");
    }
    const SemanticType type = ParseBType(*ctx->bType());
    for (auto* def : ctx->constDef()) {
      DeclareConst(*def, type);
    }
    return {};
  }

  std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override {
    if (!ctx || !ctx->bType()) {
      ThrowSemaError(ctx, "非法变量声明");
    }
    const SemanticType type = ParseBType(*ctx->bType());
    for (auto* def : ctx->varDef()) {
      DeclareVar(*def, type);
    }
    return {};
  }

  // Checks a function body with its parameters declared in a fresh scope.
  // The signature itself was registered earlier by CollectFunctions.
  std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
    if (!ctx || !ctx->ID() || !ctx->funcType() || !ctx->block()) {
      ThrowSemaError(ctx, "非法函数定义");
    }
    const FunctionBinding* binding =
        sema_.ResolveFunction(ctx->ID()->getText());
    if (!binding) {
      ThrowSemaError(ctx, "函数未完成预收集: " + ctx->ID()->getText());
    }
    const FunctionBinding* prev = current_function_;
    current_function_ = binding;
    symbols_.EnterScope();
    for (const auto& param : binding->params) {
      if (!symbols_.Add(param)) {
        ThrowSemaError(ctx, "函数形参重复定义: " + param.name);
      }
    }
    ctx->block()->accept(this);
    symbols_.ExitScope();
    current_function_ = prev;
    return {};
  }

  std::any visitBlock(SysYParser::BlockContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "缺少语句块");
    }
    symbols_.EnterScope();
    for (auto* item : ctx->blockItem()) {
      if (item) {
        item->accept(this);
      }
    }
    symbols_.ExitScope();
    return {};
  }

  std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "缺少块内语句");
    }
    if (ctx->decl()) {
      ctx->decl()->accept(this);
      return {};
    }
    if (ctx->stmt()) {
      ctx->stmt()->accept(this);
      return {};
    }
    ThrowSemaError(ctx, "非法块内语句");
  }

  // Dispatches on the statement form. Anything that matches none of the
  // branches below is accepted without further checks.
  std::any visitStmt(SysYParser::StmtContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "缺少语句");
    }
    if (ctx->BREAK()) {
      if (loop_depth_ == 0) {
        ThrowSemaError(ctx, "break 只能出现在循环内部");
      }
      return {};
    }
    if (ctx->CONTINUE()) {
      if (loop_depth_ == 0) {
        ThrowSemaError(ctx, "continue 只能出现在循环内部");
      }
      return {};
    }
    if (ctx->RETURN()) {
      CheckReturn(*ctx);
      return {};
    }
    if (ctx->WHILE()) {
      RequireScalar(ctx->cond(), EvalCond(*ctx->cond()),
                    "while 条件必须是标量表达式");
      ++loop_depth_;
      ctx->stmt(0)->accept(this);
      --loop_depth_;
      return {};
    }
    if (ctx->IF()) {
      RequireScalar(ctx->cond(), EvalCond(*ctx->cond()),
                    "if 条件必须是标量表达式");
      ctx->stmt(0)->accept(this);
      // Optional else branch.
      if (ctx->stmt().size() > 1 && ctx->stmt(1)) {
        ctx->stmt(1)->accept(this);
      }
      return {};
    }
    if (ctx->block()) {
      ctx->block()->accept(this);
      return {};
    }
    if (ctx->lVal() && ctx->ASSIGN()) {
      CheckAssignment(*ctx);
      return {};
    }
    if (ctx->exp()) {
      EvalExpr(*ctx->exp());
      return {};
    }
    return {};
  }

  std::any visitExp(SysYParser::ExpContext* ctx) override {
    if (!ctx || !ctx->addExp()) {
      ThrowSemaError(ctx, "非法表达式");
    }
    return EvalExpr(*ctx->addExp());
  }

  std::any visitCond(SysYParser::CondContext* ctx) override {
    if (!ctx || !ctx->lOrExp()) {
      ThrowSemaError(ctx, "非法条件表达式");
    }
    return EvalExpr(*ctx->lOrExp());
  }

  std::any visitLVal(SysYParser::LValContext* ctx) override {
    if (!ctx || !ctx->ID()) {
      ThrowSemaError(ctx, "非法左值");
    }
    return AnalyzeLVal(*ctx);
  }

  std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法基础表达式");
    }
    if (ctx->exp()) {
      return EvalExpr(*ctx->exp());
    }
    if (ctx->lVal()) {
      return AnalyzeLVal(*ctx->lVal());
    }
    if (ctx->number()) {
      return EvalExpr(*ctx->number());
    }
    ThrowSemaError(ctx, "非法基础表达式");
  }

  std::any visitNumber(SysYParser::NumberContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法数字字面量");
    }
    if (ctx->intConst()) {
      return EvalExpr(*ctx->intConst());
    }
    if (ctx->floatConst()) {
      return EvalExpr(*ctx->floatConst());
    }
    ThrowSemaError(ctx, "非法数字字面量");
  }

  std::any visitIntConst(SysYParser::IntConstContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法整数字面量");
    }
    ExprInfo expr;
    expr.type = SemanticType::Int;
    expr.has_const_value = true;
    expr.const_value = MakeInt(ParseIntLiteral(*ctx));
    return expr;
  }

  std::any visitFloatConst(SysYParser::FloatConstContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法浮点字面量");
    }
    ExprInfo expr;
    expr.type = SemanticType::Float;
    expr.has_const_value = true;
    expr.const_value = MakeFloat(ParseFloatLiteral(*ctx));
    return expr;
  }

  // Handles primary expressions, function calls and unary +/-.
  std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法一元表达式");
    }
    if (ctx->primaryExp()) {
      return EvalExpr(*ctx->primaryExp());
    }
    if (ctx->ID()) {
      return AnalyzeCall(*ctx);
    }
    if (ctx->addUnaryOp() && ctx->unaryExp()) {
      ExprInfo operand = EvalExpr(*ctx->unaryExp());
      RequireScalar(ctx->unaryExp(), operand, "一元运算要求标量操作数");
      ExprInfo result;
      result.type = operand.type;
      if (operand.has_const_value) {
        if (ctx->addUnaryOp()->SUB()) {
          result.has_const_value = true;
          if (operand.const_value.type == SemanticType::Int) {
            // Negate in 64 bits and wrap so that -INT_MIN stays a
            // representable int instead of an out-of-range double.
            const std::int64_t wide = ConvertToInt(operand.const_value);
            result.const_value = MakeInt(WrapToInt(-wide));
          } else {
            result.const_value = operand.const_value;
            result.const_value.number = -result.const_value.number;
          }
        } else if (ctx->addUnaryOp()->ADD()) {
          result.has_const_value = true;
          result.const_value = operand.const_value;
        }
      }
      return result;
    }
    ThrowSemaError(ctx, "非法一元表达式");
  }

  std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法乘法表达式");
    }
    // `mulExp : unaryExp` form (no MUL/DIV/MOD token): forward to unaryExp.
    if (!ctx->MUL() && !ctx->DIV() && !ctx->MOD()) {
      return EvalExpr(*ctx->unaryExp());
    }
    // Otherwise this is `mulExp (MUL|DIV|MOD) unaryExp`.
    ExprInfo lhs = EvalExpr(*ctx->mulExp());
    ExprInfo rhs = EvalExpr(*ctx->unaryExp());
    return EvalArithmetic(*ctx, lhs, rhs,
                          ctx->MUL() ? '*' : (ctx->DIV() ? '/' : '%'));
  }

  std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法加法表达式");
    }
    // `addExp : mulExp` form (no ADD/SUB token): forward to mulExp.
    if (!ctx->ADD() && !ctx->SUB()) {
      return EvalExpr(*ctx->mulExp());
    }
    // Otherwise this is `addExp (ADD|SUB) mulExp`.
    ExprInfo lhs = EvalExpr(*ctx->addExp());
    ExprInfo rhs = EvalExpr(*ctx->mulExp());
    return EvalArithmetic(*ctx, lhs, rhs, ctx->ADD() ? '+' : '-');
  }

  std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法关系表达式");
    }
    if (ctx->relExp() == nullptr) {
      return EvalExpr(*ctx->addExp());
    }
    ExprInfo lhs = EvalExpr(*ctx->relExp());
    ExprInfo rhs = EvalExpr(*ctx->addExp());
    return EvalCompare(*ctx, lhs, rhs);
  }

  std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法相等表达式");
    }
    if (ctx->eqExp() == nullptr) {
      return EvalExpr(*ctx->relExp());
    }
    ExprInfo lhs = EvalExpr(*ctx->eqExp());
    ExprInfo rhs = EvalExpr(*ctx->relExp());
    return EvalCompare(*ctx, lhs, rhs);
  }

  // Either forwards to eqExp or applies logical negation to a nested
  // condUnaryExp; the negated result is always Int.
  std::any visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法条件一元表达式");
    }
    if (ctx->eqExp()) {
      return EvalExpr(*ctx->eqExp());
    }
    ExprInfo operand = EvalExpr(*ctx->condUnaryExp());
    RequireScalar(ctx->condUnaryExp(), operand, "逻辑非要求标量操作数");
    ExprInfo result;
    result.type = SemanticType::Int;
    if (operand.has_const_value) {
      result.has_const_value = true;
      result.const_value = MakeInt(IsTrue(operand.const_value) ? 0 : 1);
    }
    return result;
  }

  std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法逻辑与表达式");
    }
    if (ctx->lAndExp() == nullptr) {
      return EvalExpr(*ctx->condUnaryExp());
    }
    ExprInfo lhs = EvalExpr(*ctx->lAndExp());
    ExprInfo rhs = EvalExpr(*ctx->condUnaryExp());
    return EvalLogical(*ctx, lhs, rhs, true);
  }

  std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override {
    if (!ctx) {
      ThrowSemaError(ctx, "非法逻辑或表达式");
    }
    if (ctx->lOrExp() == nullptr) {
      return EvalExpr(*ctx->lAndExp());
    }
    ExprInfo lhs = EvalExpr(*ctx->lOrExp());
    ExprInfo rhs = EvalExpr(*ctx->lAndExp());
    return EvalLogical(*ctx, lhs, rhs, false);
  }

  // A constExp must fold to a scalar compile-time constant.
  std::any visitConstExp(SysYParser::ConstExpContext* ctx) override {
    if (!ctx || !ctx->addExp()) {
      ThrowSemaError(ctx, "非法常量表达式");
    }
    ExprInfo expr = EvalExpr(*ctx->addExp());
    if (!expr.IsScalar() || !expr.has_const_value) {
      ThrowSemaError(ctx, "要求编译期常量表达式");
    }
    return expr;
  }

  // Moves the accumulated bindings out of the visitor.
  SemanticContext TakeSemanticContext() { return std::move(sema_); }

 private:
  // Visits an expression node and unwraps the ExprInfo it returns.
  ExprInfo EvalExpr(antlr4::tree::ParseTree& node) {
    return std::any_cast<ExprInfo>(node.accept(this));
  }

  ExprInfo EvalCond(SysYParser::CondContext& cond) { return EvalExpr(cond); }

  // Resolves an identifier (with optional subscripts), records the binding,
  // and describes the resulting object or sub-array.
  ExprInfo AnalyzeLVal(SysYParser::LValContext& ctx) {
    const std::string name = ctx.ID()->getText();
    const ObjectBinding* symbol = symbols_.Lookup(name);
    if (!symbol) {
      ThrowSemaError(&ctx, "使用了未声明的标识符:" + name);
    }
    sema_.BindObjectUse(&ctx, *symbol);
    if (ctx.exp().size() > symbol->dimensions.size()) {
      ThrowSemaError(&ctx, "数组下标过多: " + name);
    }
    for (auto* exp : ctx.exp()) {
      ExprInfo index = EvalExpr(*exp);
      RequireScalar(exp, index, "数组下标必须是标量表达式");
    }
    ExprInfo result;
    result.type = symbol->type;
    result.is_const_object =
        symbol->decl_kind == ObjectBinding::DeclKind::Const;
    result.is_lvalue = ctx.exp().size() == symbol->dimensions.size();
    // Dimensions not consumed by subscripts remain on the result.
    result.dimensions.assign(symbol->dimensions.begin() + ctx.exp().size(),
                             symbol->dimensions.end());
    if (result.dimensions.empty() && symbol->has_const_value) {
      result.has_const_value = true;
      result.const_value = symbol->const_value;
    }
    return result;
  }

  // Checks a call expression: the callee must be a function that is not
  // shadowed by an object, and the arguments must match its parameters.
  ExprInfo AnalyzeCall(SysYParser::UnaryExpContext& ctx) {
    const std::string name = ctx.ID()->getText();
    if (const ObjectBinding* object = symbols_.Lookup(name)) {
      ThrowSemaError(&ctx, "标识符不是函数: " + object->name);
    }
    const FunctionBinding* fn = sema_.ResolveFunction(name);
    if (!fn) {
      ThrowSemaError(&ctx, "调用了未定义的函数: " + name);
    }
    std::vector<ExprInfo> args;
    if (ctx.funcRParams()) {
      for (auto* exp : ctx.funcRParams()->exp()) {
        args.push_back(EvalExpr(*exp));
      }
    }
    if (args.size() != fn->params.size()) {
      ThrowSemaError(&ctx, "函数参数个数不匹配: " + name);
    }
    for (size_t i = 0; i < args.size(); ++i) {
      CheckArgument(ctx, fn->params[i], args[i], i);
    }
    sema_.BindFunctionCall(&ctx, *fn);
    ExprInfo result;
    result.type = fn->return_type;
    return result;
  }

  // Validates one argument against its parameter. `index` is zero-based;
  // messages report it one-based.
  void CheckArgument(const antlr4::ParserRuleContext& call_site,
                     const ObjectBinding& param, const ExprInfo& arg,
                     size_t index) {
    if (param.dimensions.empty()) {
      if (!arg.IsScalar()) {
        ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) +
                                       " 个参数需要标量实参");
      }
      if (!CanImplicitlyConvert(arg.type, param.type)) {
        ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) +
                                       " 个参数类型不匹配");
      }
      return;
    }
    if (!arg.IsArray()) {
      ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) +
                                     " 个参数需要数组实参");
    }
    if (arg.type != param.type ||
        arg.dimensions.size() != param.dimensions.size()) {
      ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) +
                                     " 个数组参数类型不匹配");
    }
    // The leading dimension is never compared; the rest must match exactly.
    for (size_t dim = 1; dim < param.dimensions.size(); ++dim) {
      if (param.dimensions[dim] != kUnknownArrayDim &&
          arg.dimensions[dim] != param.dimensions[dim]) {
        ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) +
                                       " 个数组参数维度不匹配");
      }
    }
  }

  void CheckAssignment(SysYParser::StmtContext& ctx) {
    ExprInfo lhs = AnalyzeLVal(*ctx.lVal());
    if (!lhs.IsScalar() || !lhs.is_lvalue) {
      ThrowSemaError(&ctx, "赋值语句左侧必须是可写标量左值");
    }
    if (lhs.is_const_object) {
      ThrowSemaError(&ctx, "不能给 const 对象赋值");
    }
    ExprInfo rhs = EvalExpr(*ctx.exp());
    RequireScalar(ctx.exp(), rhs, "赋值语句右侧必须是标量表达式");
    if (!CanImplicitlyConvert(rhs.type, lhs.type)) {
      ThrowSemaError(&ctx, "赋值语句两侧类型不兼容");
    }
  }

  void CheckReturn(SysYParser::StmtContext& ctx) {
    if (!current_function_) {
      ThrowSemaError(&ctx, "return 语句不在函数内部");
    }
    if (current_function_->return_type == SemanticType::Void) {
      if (ctx.exp()) {
        ThrowSemaError(&ctx, "void 函数不能返回值");
      }
      return;
    }
    if (!ctx.exp()) {
      ThrowSemaError(&ctx, "非 void 函数必须返回值");
    }
    ExprInfo expr = EvalExpr(*ctx.exp());
    RequireScalar(ctx.exp(), expr, "return 表达式必须是标量");
    if (!CanImplicitlyConvert(expr.type, current_function_->return_type)) {
      ThrowSemaError(&ctx, "return 表达式类型与函数返回类型不匹配");
    }
  }

  // Declares one const object. The initializer is validated before the name
  // is added to the scope; scalar consts also record their folded value.
  void DeclareConst(SysYParser::ConstDefContext& ctx, SemanticType type) {
    ObjectBinding symbol;
    symbol.name = ctx.ID()->getText();
    symbol.type = type;
    symbol.decl_kind = ObjectBinding::DeclKind::Const;
    symbol.const_def = &ctx;
    symbol.dimensions = EvalArrayDims(ctx.constIndex(), true);
    if (symbols_.ContainsInCurrentScope(symbol.name)) {
      ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name);
    }
    if (symbols_.Depth() == 1 && sema_.ResolveFunction(symbol.name)) {
      ThrowSemaError(&ctx, "全局对象与函数重名: " + symbol.name);
    }
    if (!ctx.constInitVal()) {
      ThrowSemaError(&ctx, "const 对象缺少初始化");
    }
    if (symbol.dimensions.empty()) {
      symbol.const_value = ValidateConstInitScalar(*ctx.constInitVal(), type);
      symbol.has_const_value = true;
    } else {
      ValidateConstInitAggregate(*ctx.constInitVal(), type);
    }
    if (!symbols_.Add(symbol)) {
      ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name);
    }
  }

  // Declares one variable. The name is added to the scope before its
  // initializer is validated; depth 1 is treated as global scope, where
  // initializers must be compile-time constants.
  void DeclareVar(SysYParser::VarDefContext& ctx, SemanticType type) {
    ObjectBinding symbol;
    symbol.name = ctx.ID()->getText();
    symbol.type = type;
    symbol.decl_kind = ObjectBinding::DeclKind::Var;
    symbol.var_def = &ctx;
    symbol.dimensions = EvalArrayDims(ctx.constIndex(), true);
    if (symbols_.ContainsInCurrentScope(symbol.name)) {
      ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name);
    }
    if (symbols_.Depth() == 1 && sema_.ResolveFunction(symbol.name)) {
      ThrowSemaError(&ctx, "全局对象与函数重名: " + symbol.name);
    }
    if (!symbols_.Add(symbol)) {
      ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name);
    }
    if (!ctx.initVal()) {
      return;
    }
    if (symbol.dimensions.empty()) {
      ValidateVarInitScalar(*ctx.initVal(), type, symbols_.Depth() == 1);
    } else {
      ValidateVarInitAggregate(*ctx.initVal(), type, symbols_.Depth() == 1);
    }
  }

  // Folds each `[constExp]` into an int extent. With `require_positive`,
  // extents <= 0 are rejected.
  std::vector<int> EvalArrayDims(
      const std::vector<SysYParser::ConstIndexContext*>& indices,
      bool require_positive) {
    std::vector<int> dims;
    dims.reserve(indices.size());
    for (auto* index : indices) {
      if (!index || !index->constExp()) {
        ThrowSemaError(index, "数组维度缺少常量表达式");
      }
      ExprInfo expr = EvalExpr(*index->constExp());
      if (!expr.IsScalar() || !expr.has_const_value) {
        ThrowSemaError(index, "数组维度必须是整型常量表达式");
      }
      const int dim =
          ConvertToInt(CastConstant(expr.const_value, SemanticType::Int));
      if (require_positive && dim <= 0) {
        ThrowSemaError(index, "数组维度必须为正整数");
      }
      dims.push_back(dim);
    }
    return dims;
  }

  // Returns the initializer's constant converted to `target_type`.
  ScalarConstant ValidateConstInitScalar(SysYParser::ConstInitValContext& init,
                                         SemanticType target_type) {
    if (!init.constExp()) {
      ThrowSemaError(&init, "标量 const 初始化必须是常量表达式");
    }
    ExprInfo expr = EvalExpr(*init.constExp());
    if (!expr.IsScalar() || !expr.has_const_value) {
      ThrowSemaError(&init, "标量 const 初始化必须是常量表达式");
    }
    return CastConstant(expr.const_value, target_type);
  }

  // Recursively checks that every leaf of a const array initializer is a
  // constant convertible to `target_type`. The list's shape is not checked.
  void ValidateConstInitAggregate(SysYParser::ConstInitValContext& init,
                                  SemanticType target_type) {
    if (init.constExp()) {
      ExprInfo expr = EvalExpr(*init.constExp());
      if (!expr.IsScalar() || !expr.has_const_value) {
        ThrowSemaError(&init, "数组 const 初始化要求常量表达式");
      }
      // Called for its throwing type check; the value itself is not stored.
      CastConstant(expr.const_value, target_type);
      return;
    }
    for (auto* nested : init.constInitVal()) {
      if (nested) {
        ValidateConstInitAggregate(*nested, target_type);
      }
    }
  }

  void ValidateVarInitScalar(SysYParser::InitValContext& init,
                             SemanticType target_type, bool require_constant) {
    if (!init.exp()) {
      ThrowSemaError(&init, "标量初始化非法");
    }
    ExprInfo expr = EvalExpr(*init.exp());
    RequireScalar(&init, expr, "标量初始化要求标量表达式");
    if (!CanImplicitlyConvert(expr.type, target_type)) {
      ThrowSemaError(&init, "初始化表达式类型不兼容");
    }
    if (require_constant && !expr.has_const_value) {
      ThrowSemaError(&init, "全局变量初始化要求编译期常量");
    }
  }

  // Recursively checks every leaf of an array initializer list.
  void ValidateVarInitAggregate(SysYParser::InitValContext& init,
                                SemanticType target_type,
                                bool require_constant) {
    if (init.exp()) {
      ExprInfo expr = EvalExpr(*init.exp());
      RequireScalar(&init, expr, "数组初始化元素必须是标量表达式");
      if (!CanImplicitlyConvert(expr.type, target_type)) {
        ThrowSemaError(&init, "数组初始化元素类型不兼容");
      }
      if (require_constant && !expr.has_const_value) {
        ThrowSemaError(&init, "全局数组初始化要求编译期常量");
      }
      return;
    }
    for (auto* nested : init.initVal()) {
      if (nested) {
        ValidateVarInitAggregate(*nested, target_type, require_constant);
      }
    }
  }

  // Type-checks a binary arithmetic expression (`op` is one of + - * / %)
  // and folds it when both operands are constants.
  //  - `%` with a Float operand is rejected whether or not the operands are
  //    constant.
  //  - Integer folding is done in 64 bits and wrapped to 32 bits, so overflow
  //    and INT_MIN / -1 cannot trigger undefined behaviour in the compiler.
  //  - An integer `/` or `%` with a constant zero divisor is left unfolded
  //    (has_const_value stays false) rather than evaluated.
  ExprInfo EvalArithmetic(const antlr4::ParserRuleContext& ctx,
                          const ExprInfo& lhs, const ExprInfo& rhs, char op) {
    RequireScalar(&ctx, lhs, "算术运算要求标量操作数");
    RequireScalar(&ctx, rhs, "算术运算要求标量操作数");
    ExprInfo result;
    result.type =
        lhs.type == SemanticType::Float || rhs.type == SemanticType::Float
            ? SemanticType::Float
            : SemanticType::Int;
    if (op == '%' && result.type == SemanticType::Float) {
      ThrowSemaError(&ctx, "浮点数不支持取模运算");
    }
    if (!lhs.has_const_value || !rhs.has_const_value) {
      return result;
    }
    const ScalarConstant lc = CastConstant(lhs.const_value, result.type);
    const ScalarConstant rc = CastConstant(rhs.const_value, result.type);
    if (result.type == SemanticType::Float) {
      double value = 0.0;
      switch (op) {
        case '+': value = lc.number + rc.number; break;
        case '-': value = lc.number - rc.number; break;
        case '*': value = lc.number * rc.number; break;
        case '/': value = lc.number / rc.number; break;
        default: break;
      }
      result.has_const_value = true;
      result.const_value = MakeFloat(value);
      return result;
    }
    const std::int64_t li = ConvertToInt(lc);
    const std::int64_t ri = ConvertToInt(rc);
    if ((op == '/' || op == '%') && ri == 0) {
      return result;  // not a constant expression; do not evaluate
    }
    std::int64_t wide = 0;
    switch (op) {
      case '+': wide = li + ri; break;
      case '-': wide = li - ri; break;
      case '*': wide = li * ri; break;
      case '/': wide = li / ri; break;
      case '%': wide = li % ri; break;
      default: break;
    }
    result.has_const_value = true;
    result.const_value = MakeInt(WrapToInt(wide));
    return result;
  }

  // Type-checks a relational or equality expression; the result is Int.
  // The operator is recovered from the concrete context type of `ctx`.
  ExprInfo EvalCompare(antlr4::ParserRuleContext& ctx, const ExprInfo& lhs,
                       const ExprInfo& rhs) {
    RequireScalar(&ctx, lhs, "比较运算要求标量操作数");
    RequireScalar(&ctx, rhs, "比较运算要求标量操作数");
    ExprInfo result;
    result.type = SemanticType::Int;
    if (!lhs.has_const_value || !rhs.has_const_value) {
      return result;
    }
    const SemanticType promoted =
        lhs.type == SemanticType::Float || rhs.type == SemanticType::Float
            ? SemanticType::Float
            : SemanticType::Int;
    const ScalarConstant lc = CastConstant(lhs.const_value, promoted);
    const ScalarConstant rc = CastConstant(rhs.const_value, promoted);
    bool value = false;
    if (auto* rel = dynamic_cast<SysYParser::RelExpContext*>(&ctx)) {
      if (rel->LT()) value = lc.number < rc.number;
      if (rel->GT()) value = lc.number > rc.number;
      if (rel->LE()) value = lc.number <= rc.number;
      if (rel->GE()) value = lc.number >= rc.number;
    } else if (auto* eq = dynamic_cast<SysYParser::EqExpContext*>(&ctx)) {
      if (eq->EQ()) value = lc.number == rc.number;
      if (eq->NE()) value = lc.number != rc.number;
    }
    result.has_const_value = true;
    result.const_value = MakeInt(value ? 1 : 0);
    return result;
  }

  // Type-checks `&&` (is_and) or `||`; folds only when both sides are
  // constants.
  ExprInfo EvalLogical(const antlr4::ParserRuleContext& ctx,
                       const ExprInfo& lhs, const ExprInfo& rhs, bool is_and) {
    RequireScalar(&ctx, lhs, "逻辑运算要求标量操作数");
    RequireScalar(&ctx, rhs, "逻辑运算要求标量操作数");
    ExprInfo result;
    result.type = SemanticType::Int;
    if (!lhs.has_const_value || !rhs.has_const_value) {
      return result;
    }
    const bool value =
        is_and ? (IsTrue(lhs.const_value) && IsTrue(rhs.const_value))
               : (IsTrue(lhs.const_value) || IsTrue(rhs.const_value));
    result.has_const_value = true;
    result.const_value = MakeInt(value ? 1 : 0);
    return result;
  }

  void RequireScalar(const antlr4::ParserRuleContext* ctx,
                     const ExprInfo& expr, std::string_view message) {
    if (!expr.IsScalar()) {
      ThrowSemaError(ctx, message);
    }
  }

  // Registers the signature of every top-level function before any body is
  // visited, so calls may refer to functions defined later in the file.
  void CollectFunctions(SysYParser::CompUnitContext& ctx) {
    for (auto* item : ctx.topLevelItem()) {
      if (!item || !item->funcDef()) {
        continue;
      }
      FunctionBinding fn = BuildFunctionSignature(*item->funcDef());
      if (sema_.ResolveFunction(fn.name)) {
        ThrowSemaError(item->funcDef(), "重复定义函数: " + fn.name);
      }
      if (symbols_.ContainsInCurrentScope(fn.name)) {
        ThrowSemaError(item->funcDef(), "函数与全局对象重名: " + fn.name);
      }
      sema_.RegisterFunction(std::move(fn));
    }
  }

  FunctionBinding BuildFunctionSignature(SysYParser::FuncDefContext& ctx) {
    FunctionBinding fn;
    fn.name = ctx.ID()->getText();
    fn.return_type = ParseFuncType(*ctx.funcType());
    fn.func_def = &ctx;
    if (ctx.funcFParams()) {
      for (auto* param : ctx.funcFParams()->funcFParam()) {
        fn.params.push_back(BuildParamBinding(*param));
      }
    }
    return fn;
  }

  // Builds the binding for one formal parameter. Array parameters get
  // kUnknownArrayDim as their first extent followed by the declared ones.
  // NOTE(review): this runs from CollectFunctions, i.e. before any top-level
  // declaration has been visited, so an extent that names a global constant
  // is reported as undeclared — confirm whether that is intended.
  ObjectBinding BuildParamBinding(SysYParser::FuncFParamContext& ctx) {
    if (!ctx.ID() || !ctx.bType()) {
      ThrowSemaError(&ctx, "非法函数形参");
    }
    ObjectBinding param;
    param.name = ctx.ID()->getText();
    param.type = ParseBType(*ctx.bType());
    param.decl_kind = ObjectBinding::DeclKind::Param;
    param.func_param = &ctx;
    if (!ctx.LBRACK().empty()) {
      param.is_array_param = true;
      param.dimensions.push_back(kUnknownArrayDim);
      for (auto* exp : ctx.exp()) {
        ExprInfo dim = EvalExpr(*exp);
        if (!dim.IsScalar() || !dim.has_const_value) {
          ThrowSemaError(&ctx, "数组形参维度必须是整型常量表达式");
        }
        const int value =
            ConvertToInt(CastConstant(dim.const_value, SemanticType::Int));
        if (value <= 0) {
          ThrowSemaError(&ctx, "数组形参维度必须为正整数");
        }
        param.dimensions.push_back(value);
      }
    }
    return param;
  }

  // Registers the runtime-library functions as builtin bindings.
  void RegisterBuiltins() {
    sema_.RegisterFunction(
        MakeBuiltinFunction("getint", SemanticType::Int, {}));
    sema_.RegisterFunction(MakeBuiltinFunction("getch", SemanticType::Int, {}));
    sema_.RegisterFunction(
        MakeBuiltinFunction("getfloat", SemanticType::Float, {}));
    sema_.RegisterFunction(MakeBuiltinFunction(
        "getarray", SemanticType::Int,
        {MakeParam("a", SemanticType::Int, {kUnknownArrayDim}, true)}));
    sema_.RegisterFunction(MakeBuiltinFunction(
        "getfarray", SemanticType::Int,
        {MakeParam("a", SemanticType::Float, {kUnknownArrayDim}, true)}));
    sema_.RegisterFunction(MakeBuiltinFunction(
        "putint", SemanticType::Void, {MakeParam("x", SemanticType::Int)}));
    sema_.RegisterFunction(MakeBuiltinFunction(
        "putch", SemanticType::Void, {MakeParam("x", SemanticType::Int)}));
    sema_.RegisterFunction(MakeBuiltinFunction(
        "putfloat", SemanticType::Void, {MakeParam("x", SemanticType::Float)}));
    sema_.RegisterFunction(MakeBuiltinFunction(
        "putarray", SemanticType::Void,
        {MakeParam("n", SemanticType::Int),
         MakeParam("a", SemanticType::Int, {kUnknownArrayDim}, true)}));
    sema_.RegisterFunction(MakeBuiltinFunction(
        "putfarray", SemanticType::Void,
        {MakeParam("n", SemanticType::Int),
         MakeParam("a", SemanticType::Float, {kUnknownArrayDim}, true)}));
    sema_.RegisterFunction(
        MakeBuiltinFunction("starttime", SemanticType::Void, {}));
    sema_.RegisterFunction(
        MakeBuiltinFunction("stoptime", SemanticType::Void, {}));
  }

  SymbolTable symbols_;
  SemanticContext sema_;
  const FunctionBinding* current_function_ = nullptr;  // function being checked
  int loop_depth_ = 0;  // number of enclosing while loops
};

}  // namespace

// Records (or overwrites) the object an lVal node resolves to.
void SemanticContext::BindObjectUse(const SysYParser::LValContext* use,
                                    ObjectBinding binding) {
  object_uses_[use] = std::move(binding);
}

// Returns the binding recorded for `use`, or nullptr if none.
const ObjectBinding* SemanticContext::ResolveObjectUse(
    const SysYParser::LValContext* use) const {
  auto it = object_uses_.find(use);
  return it == object_uses_.end() ? nullptr : &it->second;
}

// Records (or overwrites) the function a call node resolves to.
void SemanticContext::BindFunctionCall(const SysYParser::UnaryExpContext* call,
                                       FunctionBinding binding) {
  function_calls_[call] = std::move(binding);
}

// Returns the binding recorded for `call`, or nullptr if none.
const FunctionBinding* SemanticContext::ResolveFunctionCall(
    const SysYParser::UnaryExpContext* call) const {
  auto it = function_calls_.find(call);
  return it == function_calls_.end() ? nullptr : &it->second;
}

// Registers a function under its name, replacing any previous entry.
void SemanticContext::RegisterFunction(FunctionBinding binding) {
  functions_[binding.name] = std::move(binding);
}

// Returns the function registered under `name`, or nullptr if none.
const FunctionBinding* SemanticContext::ResolveFunction(
    const std::string& name) const {
  auto it = functions_.find(name);
  return it == functions_.end() ? nullptr : &it->second;
}

// Runs semantic analysis over a parsed compilation unit and returns the
// collected bindings. Throws std::runtime_error on the first violation.
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
  SemaVisitor visitor;
  comp_unit.accept(&visitor);
  return visitor.TakeSemanticContext();
}