#include "sem/Sema.h"

#include <any>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"

namespace {

// Value type produced by the ir::Type factory functions (GetInt32Type,
// GetFloatType, GetVoidType, GetArrayType, ...).
using IrTypePtr = std::decay_t<decltype(ir::Type::GetInt32Type())>;

// Returns the identifier of an lvalue.
// Throws std::runtime_error if the lvalue carries no identifier.
std::string GetLValueName(SysYParser::LValContext& lval) {
  if (!lval.Ident()) {
    throw std::runtime_error(FormatError("sema", "非法左值"));
  }
  return lval.Ident()->getText();
}

// Maps a BType node to its IR type; a missing or unrecognized node yields int.
IrTypePtr GetTypeFromBType(SysYParser::BTypeContext* ctx) {
  if (!ctx || ctx->Int()) return ir::Type::GetInt32Type();
  if (ctx->Float()) return ir::Type::GetFloatType();
  return ir::Type::GetInt32Type();
}

// Negates a 32-bit int with two's-complement wrap-around, avoiding the
// undefined behavior of `-INT_MIN`.
int WrapNegateInt32(int value) {
  return static_cast<int>(0u - static_cast<unsigned>(value));
}

// Parses a SysY integer literal: decimal, 0-prefixed octal, or 0x/0X hex.
// Values up to 2^32-1 are accepted and wrapped to 32 bits so that the idiom
// `-2147483648` (negation of the literal 2147483648) yields INT_MIN instead of
// making std::stoi throw std::out_of_range.
// Throws std::runtime_error for malformed or wider-than-32-bit literals.
int ParseSysYIntLiteral(const std::string& text) {
  unsigned long long value = 0;
  try {
    value = std::stoull(text, nullptr, 0);
  } catch (const std::exception&) {
    throw std::runtime_error(FormatError("sema", "非法整数字面量: " + text));
  }
  if (value > UINT32_MAX) {
    throw std::runtime_error(FormatError("sema", "非法整数字面量: " + text));
  }
  return static_cast<int>(static_cast<std::uint32_t>(value));
}

// Semantic-analysis visitor: maintains scopes in a SymbolTable, checks
// declarations and statements, and records per-expression type information
// (and implicit conversions) into a SemanticContext.
class SemaVisitor final : public SysYBaseVisitor {
 public:
  SemaVisitor() : table_() {}

  // Entry point. Collects every function signature first (so functions may
  // call each other regardless of order), then checks global declarations,
  // then function bodies, and finally verifies `main`.
  std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
    if (!ctx) {
      throw std::runtime_error(FormatError("sema", "缺少编译单元"));
    }
    table_.enterScope();  // global scope
    for (auto* func : ctx->funcDef()) {
      CollectFunctionDeclaration(func);
    }
    for (auto* decl : ctx->decl()) {
      if (decl) decl->accept(this);
    }
    for (auto* func : ctx->funcDef()) {
      if (func) func->accept(this);
    }
    CheckMainFunction();
    table_.exitScope();  // leave global scope
    return {};
  }

  // Checks one function definition: resolves the return type, opens a scope
  // for the parameters, checks the body, and rejects a non-void function that
  // contains no return statement at all.
  std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
    if (!ctx || !ctx->Ident()) {
      throw std::runtime_error(FormatError("sema", "函数定义缺少标识符"));
    }
    const std::string name = ctx->Ident()->getText();
    const IrTypePtr return_type = ResolveFuncReturnType(ctx->funcType());

    // Consulted by return statements inside the body.
    current_func_return_type_ = return_type;
    current_func_has_return_ = false;

    table_.enterScope();  // parameter scope
    if (ctx->funcFParams()) {
      CollectFunctionParams(ctx->funcFParams());
    }
    // NOTE(review): visitBlock opens a further nested scope for the body, so
    // a body-level local may shadow a parameter (C rejects that) — confirm.
    if (ctx->block()) {
      ctx->block()->accept(this);
    }
    if (!return_type->IsVoid() && !current_func_has_return_) {
      throw std::runtime_error(FormatError("sema", "非 void 函数 " + name + " 缺少 return 语句"));
    }
    table_.exitScope();
    current_func_return_type_ = nullptr;
    current_func_has_return_ = false;
    return {};
  }

  // Checks every item of a block inside a fresh scope.
  std::any visitBlock(SysYParser::BlockContext* ctx) override {
    if (!ctx) {
      throw std::runtime_error(FormatError("sema", "缺少语句块"));
    }
    table_.enterScope();
    for (auto* item : ctx->blockItem()) {
      if (item) item->accept(this);
    }
    table_.exitScope();
    return {};
  }

  // Dispatches a block item to its declaration or statement.
  std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
    if (!ctx) {
      throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
    }
    if (ctx->decl()) {
      ctx->decl()->accept(this);
      return {};
    }
    if (ctx->stmt()) {
      ctx->stmt()->accept(this);
      return {};
    }
    throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
  }

  // Dispatches a declaration to the constant or variable handler.
  std::any visitDecl(SysYParser::DeclContext* ctx) override {
    if (!ctx) {
      throw std::runtime_error(FormatError("sema", "非法变量声明"));
    }
    if (ctx->constDecl()) {
      ctx->constDecl()->accept(this);
    } else if (ctx->varDecl()) {
      ctx->varDecl()->accept(this);
    }
    return {};
  }

  // ==================== Variable declarations ====================

  // Checks each definition of a variable declaration; scope level 0 is global.
  std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override {
    if (!ctx || !ctx->bType()) {
      throw std::runtime_error(FormatError("sema", "非法变量声明"));
    }
    const IrTypePtr base_type = GetTypeFromBType(ctx->bType());
    const bool is_global = (table_.currentScopeLevel() == 0);
    for (auto* var_def : ctx->varDef()) {
      if (var_def) CheckVarDef(var_def, base_type, is_global);
    }
    return {};
  }

  // Checks one variable definition and adds it to the current scope.
  // - Rejects a redefinition within the same scope.
  // - Array dimensions must be positive constant expressions.
  // - Global initializers must be constant; local initializers are traversed
  //   so that undefined names inside them are reported.
  void CheckVarDef(SysYParser::VarDefContext* ctx, IrTypePtr base_type, bool is_global) {
    if (!ctx || !ctx->Ident()) {
      throw std::runtime_error(FormatError("sema", "非法变量定义"));
    }
    const std::string name = ctx->Ident()->getText();
    if (table_.lookupCurrent(name)) {
      throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
    }

    const bool is_array = !ctx->constExp().empty();
    IrTypePtr type = base_type;
    if (is_array) {
      std::vector<int> dims = EvaluateArrayDimensions(ctx->constExp());
      type = ir::Type::GetArrayType(base_type, dims);
    }

    const bool has_init = (ctx->initVal() != nullptr);
    if (has_init) {
      if (is_global) {
        CheckGlobalInitIsConst(ctx->initVal());
      } else {
        // NOTE(review): this only type-checks the sub-expressions; no
        // implicit int/float conversion of the initializer is recorded here.
        ctx->initVal()->accept(this);
      }
    }

    Symbol sym;
    sym.name = name;
    sym.kind = SymbolKind::Variable;
    sym.type = type;  // array dimensions live inside the array type
    sym.scope_level = table_.currentScopeLevel();
    sym.is_initialized = has_init;
    sym.var_def_ctx = ctx;
    table_.addSymbol(sym);
  }

  // ==================== Constant declarations ====================

  // Checks each definition of a constant declaration.
  std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
    if (!ctx || !ctx->bType()) {
      throw std::runtime_error(FormatError("sema", "非法常量声明"));
    }
    const IrTypePtr base_type = GetTypeFromBType(ctx->bType());
    for (auto* const_def : ctx->constDef()) {
      if (const_def) CheckConstDef(const_def, base_type);
    }
    return {};
  }

  // Checks one constant definition and adds it to the current scope.
  // The initializer is evaluated at compile time and must not supply more
  // elements than the declared shape holds. A scalar constant stores its value,
  // converted to the declared type (e.g. `const int a = 2.5;` stores 2), so
  // later uses can be folded.
  void CheckConstDef(SysYParser::ConstDefContext* ctx, IrTypePtr base_type) {
    if (!ctx || !ctx->Ident()) {
      throw std::runtime_error(FormatError("sema", "非法常量定义"));
    }
    const std::string name = ctx->Ident()->getText();
    if (table_.lookupCurrent(name)) {
      throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
    }

    const bool is_array = !ctx->constExp().empty();
    std::vector<int> dims;
    IrTypePtr type = base_type;
    if (is_array) {
      dims = EvaluateArrayDimensions(ctx->constExp());
      type = ir::Type::GetArrayType(base_type, dims);
    }

    Symbol sym;
    sym.name = name;
    sym.kind = SymbolKind::Constant;
    sym.type = type;
    sym.scope_level = table_.currentScopeLevel();
    sym.is_initialized = true;
    sym.var_def_ctx = nullptr;

    if (ctx->constInitVal()) {
      const auto init_values = EvaluateConstInitVal(ctx->constInitVal(), dims, base_type);
      std::size_t expected_count = 1;
      for (int d : dims) expected_count *= static_cast<std::size_t>(d);
      if (init_values.size() > expected_count) {
        throw std::runtime_error(FormatError("sema", "初始化值过多"));
      }
      // Only scalar constants carry a folded value in the symbol.
      if (!is_array && !init_values.empty()) {
        const auto& init = init_values[0];
        if (base_type->IsInt32()) {
          sym.is_int_const = true;
          sym.const_value.i32 = init.is_int ? init.int_val : static_cast<int>(init.float_val);
        } else if (base_type->IsFloat()) {
          sym.is_int_const = false;
          sym.const_value.f32 = init.is_int ? static_cast<float>(init.int_val) : init.float_val;
        }
      }
    }
    table_.addSymbol(sym);
  }

  // ==================== Statements ====================

  // Dispatches a statement by inspecting which tokens/children it has.
  // Return() yields a TerminalNode*, so it is tested against nullptr.
  std::any visitStmt(SysYParser::StmtContext* ctx) override {
    if (!ctx) return {};
    if (ctx->Return() != nullptr) {
      return visitReturnStmtInternal(ctx);
    }
    if (ctx->lVal() != nullptr && ctx->Assign() != nullptr) {
      return visitAssignStmt(ctx);
    }
    if (ctx->exp() != nullptr && ctx->Semi() != nullptr) {
      return visitExpStmt(ctx);
    }
    if (ctx->block() != nullptr) {
      return ctx->block()->accept(this);
    }
    if (ctx->If() != nullptr) {
      return visitIfStmtInternal(ctx);
    }
    if (ctx->While() != nullptr) {
      return visitWhileStmtInternal(ctx);
    }
    if (ctx->Break() != nullptr) {
      return visitBreakStmtInternal(ctx);
    }
    if (ctx->Continue() != nullptr) {
      return visitContinueStmtInternal(ctx);
    }
    return {};  // empty statement
  }

  // Checks a return statement against the enclosing function's return type,
  // recording an implicit conversion when the value's type differs, and marks
  // the function as containing a return.
  std::any visitReturnStmtInternal(SysYParser::StmtContext* ctx) {
    const IrTypePtr expected = current_func_return_type_;
    if (!expected) {
      throw std::runtime_error(FormatError("sema", "return 语句不在函数体内"));
    }
    if (ctx->exp() != nullptr) {
      ExprInfo ret_val = CheckExp(ctx->exp());
      if (expected->IsVoid()) {
        throw std::runtime_error(FormatError("sema", "void 函数不能返回值"));
      }
      if (!IsTypeCompatible(ret_val.type, expected)) {
        throw std::runtime_error(FormatError("sema", "返回值类型不匹配"));
      }
      if (ret_val.type != expected) {
        sema_.AddConversion(ctx->exp(), ret_val.type, expected);
      }
    } else if (!expected->IsVoid()) {
      throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值"));
    }
    current_func_has_return_ = true;
    return {};
  }

  // Resolves a variable reference and records its ExprInfo on `ctx`.
  // - Subscripted access must supply exactly one int index per dimension.
  // - A bare array name decays to a pointer to its element type.
  // - A scalar constant carries its folded value.
  std::any visitLVal(SysYParser::LValContext* ctx) override {
    if (!ctx || !ctx->Ident()) {
      throw std::runtime_error(FormatError("sema", "非法变量引用"));
    }
    const std::string name = ctx->Ident()->getText();
    auto* sym = table_.lookup(name);
    if (!sym) {
      throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
    }

    const bool is_array_access = !ctx->exp().empty();
    const bool is_array = sym->type && sym->type->IsArray();
    // Array parameters are represented as element pointers.
    const bool is_ptr = sym->type && (sym->type->IsPtrInt32() || sym->type->IsPtrFloat());

    ExprInfo result;
    if (is_array || is_ptr) {
      std::size_t dim_count = 0;
      IrTypePtr elem_type = sym->type;
      if (is_array) {
        if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
          dim_count = arr_type->GetDimensions().size();
          elem_type = arr_type->GetElementType();
        }
      } else {
        dim_count = 1;
        if (sym->type->IsPtrInt32()) {
          elem_type = ir::Type::GetInt32Type();
        } else {
          elem_type = ir::Type::GetFloatType();
        }
      }

      if (is_array_access) {
        if (ctx->exp().size() != dim_count) {
          throw std::runtime_error(FormatError("sema", "数组下标个数不匹配"));
        }
        for (auto* idx_exp : ctx->exp()) {
          if (!CheckExp(idx_exp).type->IsInt32()) {
            throw std::runtime_error(FormatError("sema", "数组下标必须是 int 类型"));
          }
        }
        result.type = elem_type;
        result.is_lvalue = true;
        result.is_const = false;
      } else if (is_array) {
        // The array name is used as an address: decay to an element pointer.
        result.type = ir::Type::GetPtrInt32Type();
        if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
          if (arr_type->GetElementType()->IsFloat()) {
            result.type = ir::Type::GetPtrFloatType();
          }
        }
        result.is_lvalue = false;
        result.is_const = true;
      } else {
        result.type = sym->type;
        result.is_lvalue = true;
        result.is_const = (sym->kind == SymbolKind::Constant);
      }
    } else {
      if (is_array_access) {
        throw std::runtime_error(FormatError("sema", "非数组变量不能使用下标: " + name));
      }
      result.type = sym->type;
      result.is_lvalue = true;
      result.is_const = (sym->kind == SymbolKind::Constant);
      if (result.is_const && sym->type) {
        if (sym->is_int_const) {
          result.is_const_int = true;
          result.const_int_value = sym->const_value.i32;
        } else {
          result.const_float_value = sym->const_value.f32;
        }
      }
    }
    sema_.SetExprType(ctx, result);
    return {};
  }

  // Checks the condition (CheckCond already enforces/records int conversion),
  // then the then-branch and the optional else-branch.
  std::any visitIfStmtInternal(SysYParser::StmtContext* ctx) {
    if (ctx->cond()) {
      CheckCond(ctx->cond());
    }
    if (ctx->stmt().size() > 0) {
      ctx->stmt()[0]->accept(this);
    }
    if (ctx->stmt().size() > 1) {
      ctx->stmt()[1]->accept(this);
    }
    return {};
  }

  // Checks the loop condition, then the body with the loop pushed on
  // loop_stack_ so that break/continue inside it are accepted.
  std::any visitWhileStmtInternal(SysYParser::StmtContext* ctx) {
    if (ctx->cond()) {
      CheckCond(ctx->cond());
    }
    loop_stack_.push_back({true, ctx});
    if (ctx->stmt().size() > 0) {
      ctx->stmt()[0]->accept(this);
    }
    loop_stack_.pop_back();
    return {};
  }

  // `break` is only legal inside a loop body.
  std::any visitBreakStmtInternal(SysYParser::StmtContext* /*ctx*/) {
    if (loop_stack_.empty() || !loop_stack_.back().in_loop) {
      throw std::runtime_error(FormatError("sema", "break 语句必须在循环体内使用"));
    }
    return {};
  }

  // `continue` is only legal inside a loop body.
  std::any visitContinueStmtInternal(SysYParser::StmtContext* /*ctx*/) {
    if (loop_stack_.empty() || !loop_stack_.back().in_loop) {
      throw std::runtime_error(FormatError("sema", "continue 语句必须在循环体内使用"));
    }
    return {};
  }

  // Checks `lval = exp;`: the target must be a non-constant lvalue and the
  // value's type compatible; an implicit conversion is recorded if needed.
  std::any visitAssignStmt(SysYParser::StmtContext* ctx) {
    if (!ctx->lVal() || !ctx->exp()) {
      throw std::runtime_error(FormatError("sema", "非法赋值语句"));
    }
    ExprInfo lvalue = CheckLValue(ctx->lVal());
    if (lvalue.is_const) {
      throw std::runtime_error(FormatError("sema", "不能给常量赋值"));
    }
    if (!lvalue.is_lvalue) {
      throw std::runtime_error(FormatError("sema", "赋值左边必须是左值"));
    }
    ExprInfo rvalue = CheckExp(ctx->exp());
    if (!IsTypeCompatible(rvalue.type, lvalue.type)) {
      throw std::runtime_error(FormatError("sema", "赋值类型不匹配"));
    }
    if (rvalue.type != lvalue.type) {
      sema_.AddConversion(ctx->exp(), rvalue.type, lvalue.type);
    }
    return {};
  }

  // Checks the expression of an expression statement, if any.
  std::any visitExpStmt(SysYParser::StmtContext* ctx) {
    if (ctx->exp()) {
      CheckExp(ctx->exp());
    }
    return {};
  }

  // ==================== Expression typing ====================

  // Primary expression: lvalue, float/int literal, or parenthesized expression.
  std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
    ExprInfo result;
    if (ctx->lVal()) {
      result = CheckLValue(ctx->lVal());
      // NOTE(review): this forces is_lvalue even where CheckLValue may report
      // false (e.g. a decayed array name) — confirm downstream expectations.
      result.is_lvalue = true;
    } else if (ctx->HEX_FLOAT() || ctx->DEC_FLOAT()) {
      const std::string text =
          ctx->HEX_FLOAT() ? ctx->HEX_FLOAT()->getText() : ctx->DEC_FLOAT()->getText();
      result.type = ir::Type::GetFloatType();
      result.is_const = true;
      result.is_const_int = false;
      // strtof instead of std::stof: out-of-range literals (e.g. 1e-50) must
      // not throw std::out_of_range.
      result.const_float_value = std::strtof(text.c_str(), nullptr);
    } else if (ctx->HEX_INT() || ctx->OCTAL_INT() || ctx->DECIMAL_INT() || ctx->ZERO()) {
      std::string text;
      if (ctx->HEX_INT()) {
        text = ctx->HEX_INT()->getText();
      } else if (ctx->OCTAL_INT()) {
        text = ctx->OCTAL_INT()->getText();
      } else if (ctx->DECIMAL_INT()) {
        text = ctx->DECIMAL_INT()->getText();
      } else {
        text = ctx->ZERO()->getText();
      }
      result.type = ir::Type::GetInt32Type();
      result.is_const = true;
      result.is_const_int = true;
      result.const_int_value = ParseSysYIntLiteral(text);
    } else if (ctx->exp()) {
      result = CheckExp(ctx->exp());
    }
    sema_.SetExprType(ctx, result);
    return {};
  }

  // Unary expression: primary expression, function call, or `!`/`+`/`-`
  // applied to an operand, folding constants where possible.
  std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
    ExprInfo result;
    if (ctx->primaryExp()) {
      ctx->primaryExp()->accept(this);
      if (auto* info = sema_.GetExprType(ctx->primaryExp())) result = *info;
    } else if (ctx->Ident() && ctx->L_PAREN()) {
      result = CheckFuncCall(ctx);
    } else if (ctx->unaryOp()) {
      ctx->unaryExp()->accept(this);
      auto* operand = sema_.GetExprType(ctx->unaryExp());
      if (!operand) {
        throw std::runtime_error(FormatError("sema", "一元操作数类型推导失败"));
      }
      const std::string op = ctx->unaryOp()->getText();
      if (op == "!") {
        // The operand must be int; a float is implicitly converted first.
        // NOTE(review): the conversion truncates, whereas C treats any non-zero
        // float (e.g. 0.5) as true — confirm how IR generation lowers this.
        CoerceLogicalOperandToInt(ctx->unaryExp(), operand,
                                  "逻辑非操作数必须是 int 类型或可以转换为 int 的 float 类型");
        result.type = ir::Type::GetInt32Type();
        result.is_lvalue = false;
        result.is_const = operand->is_const;
        if (operand->is_const && operand->is_const_int) {
          result.is_const_int = true;
          result.const_int_value = (operand->const_int_value == 0) ? 1 : 0;
        }
      } else {
        // Unary plus/minus: the operand must be arithmetic.
        if (!operand->type->IsInt32() && !operand->type->IsFloat()) {
          throw std::runtime_error(FormatError("sema", "正负号操作数必须是算术类型"));
        }
        const bool negate = (op == "-");
        result.type = operand->type;
        result.is_lvalue = false;
        result.is_const = operand->is_const;
        if (operand->is_const) {
          if (operand->type->IsInt32() && operand->is_const_int) {
            result.is_const_int = true;
            result.const_int_value =
                negate ? WrapNegateInt32(operand->const_int_value) : operand->const_int_value;
          } else if (operand->type->IsFloat()) {
            result.const_float_value =
                negate ? -operand->const_float_value : operand->const_float_value;
          }
        }
      }
    }
    sema_.SetExprType(ctx, result);
    return {};
  }

  // Multiplicative expression (`*`, `/`, `%`); typing is done by CheckBinaryOp.
  std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
    if (!ctx->mulExp()) {
      return ForwardSingleChildInfo(ctx, ctx->unaryExp());
    }
    ctx->mulExp()->accept(this);
    ctx->unaryExp()->accept(this);
    auto* left_info = sema_.GetExprType(ctx->mulExp());
    auto* right_info = sema_.GetExprType(ctx->unaryExp());
    if (!left_info || !right_info) {
      throw std::runtime_error(FormatError("sema", "乘除模操作数类型推导失败"));
    }
    std::string op;
    if (ctx->MulOp()) {
      op = "*";
    } else if (ctx->DivOp()) {
      op = "/";
    } else if (ctx->QuoOp()) {
      op = "%";
    }
    ExprInfo result = CheckBinaryOp(left_info, right_info, op, ctx);
    sema_.SetExprType(ctx, result);
    return {};
  }

  // Additive expression (`+`, `-`); typing is done by CheckBinaryOp.
  std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
    if (!ctx->addExp()) {
      return ForwardSingleChildInfo(ctx, ctx->mulExp());
    }
    ctx->addExp()->accept(this);
    ctx->mulExp()->accept(this);
    auto* left_info = sema_.GetExprType(ctx->addExp());
    auto* right_info = sema_.GetExprType(ctx->mulExp());
    if (!left_info || !right_info) {
      throw std::runtime_error(FormatError("sema", "加减操作数类型推导失败"));
    }
    std::string op;
    if (ctx->AddOp()) {
      op = "+";
    } else if (ctx->SubOp()) {
      op = "-";
    }
    ExprInfo result = CheckBinaryOp(left_info, right_info, op, ctx);
    sema_.SetExprType(ctx, result);
    return {};
  }

  // Relational expression (`<`, `>`, `<=`, `>=`). Both operands must be
  // arithmetic; the result is int and is folded when both sides are constant.
  std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
    if (!ctx->relExp()) {
      return ForwardSingleChildInfo(ctx, ctx->addExp());
    }
    ctx->relExp()->accept(this);
    ctx->addExp()->accept(this);
    auto* left_info = sema_.GetExprType(ctx->relExp());
    auto* right_info = sema_.GetExprType(ctx->addExp());
    if (!left_info || !right_info) {
      throw std::runtime_error(FormatError("sema", "关系操作数类型推导失败"));
    }
    if (!IsArithmeticInfo(*left_info) || !IsArithmeticInfo(*right_info)) {
      throw std::runtime_error(FormatError("sema", "关系运算操作数必须是算术类型"));
    }
    ExprInfo result;
    result.type = ir::Type::GetInt32Type();
    result.is_lvalue = false;
    if (left_info->is_const && right_info->is_const) {
      result.is_const = true;
      result.is_const_int = true;
      if (ctx->LOp()) {
        result.const_int_value = FoldConstComparison(
            *left_info, *right_info, [](auto a, auto b) { return a < b; });
      } else if (ctx->GOp()) {
        result.const_int_value = FoldConstComparison(
            *left_info, *right_info, [](auto a, auto b) { return a > b; });
      } else if (ctx->LeOp()) {
        result.const_int_value = FoldConstComparison(
            *left_info, *right_info, [](auto a, auto b) { return a <= b; });
      } else if (ctx->GeOp()) {
        result.const_int_value = FoldConstComparison(
            *left_info, *right_info, [](auto a, auto b) { return a >= b; });
      }
    }
    sema_.SetExprType(ctx, result);
    return {};
  }

  // Equality expression (`==`, `!=`). Both operands must be arithmetic; the
  // result is int and is folded when both sides are constant.
  std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
    if (!ctx->eqExp()) {
      return ForwardSingleChildInfo(ctx, ctx->relExp());
    }
    ctx->eqExp()->accept(this);
    ctx->relExp()->accept(this);
    auto* left_info = sema_.GetExprType(ctx->eqExp());
    auto* right_info = sema_.GetExprType(ctx->relExp());
    if (!left_info || !right_info) {
      throw std::runtime_error(FormatError("sema", "相等性操作数类型推导失败"));
    }
    if (!IsArithmeticInfo(*left_info) || !IsArithmeticInfo(*right_info)) {
      throw std::runtime_error(FormatError("sema", "相等性运算操作数必须是算术类型"));
    }
    ExprInfo result;
    result.type = ir::Type::GetInt32Type();
    result.is_lvalue = false;
    if (left_info->is_const && right_info->is_const) {
      result.is_const = true;
      result.is_const_int = true;
      if (ctx->EqOp()) {
        result.const_int_value = FoldConstComparison(
            *left_info, *right_info, [](auto a, auto b) { return a == b; });
      } else if (ctx->NeOp()) {
        result.const_int_value = FoldConstComparison(
            *left_info, *right_info, [](auto a, auto b) { return a != b; });
      }
    }
    sema_.SetExprType(ctx, result);
    return {};
  }

  // Logical AND. Every operand is coerced to int (float -> int conversion
  // recorded); constant operands are folded.
  std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override {
    if (!ctx->lAndExp()) {
      ctx->eqExp()->accept(this);
      if (auto* info = sema_.GetExprType(ctx->eqExp())) {
        // Even a lone operand must be int, since it feeds a condition.
        CoerceLogicalOperandToInt(ctx->eqExp(), info,
                                  "逻辑与操作数必须是 int 类型或可以转换为 int 的 float 类型");
        sema_.SetExprType(ctx, *info);
      }
      return {};
    }
    ctx->lAndExp()->accept(this);
    ctx->eqExp()->accept(this);
    auto* left_info = sema_.GetExprType(ctx->lAndExp());
    auto* right_info = sema_.GetExprType(ctx->eqExp());
    if (!left_info || !right_info) {
      throw std::runtime_error(FormatError("sema", "逻辑与操作数类型推导失败"));
    }
    CoerceLogicalOperandToInt(ctx->lAndExp(), left_info,
                              "逻辑与左操作数必须是 int 类型或可以转换为 int 的 float 类型");
    CoerceLogicalOperandToInt(ctx->eqExp(), right_info,
                              "逻辑与右操作数必须是 int 类型或可以转换为 int 的 float 类型");
    ExprInfo result;
    result.type = ir::Type::GetInt32Type();
    result.is_lvalue = false;
    if (left_info->is_const && right_info->is_const && left_info->is_const_int &&
        right_info->is_const_int) {
      result.is_const = true;
      result.is_const_int = true;
      result.const_int_value =
          (left_info->const_int_value && right_info->const_int_value) ? 1 : 0;
    }
    sema_.SetExprType(ctx, result);
    return {};
  }

  // Logical OR. Every operand is coerced to int (float -> int conversion
  // recorded); constant operands are folded.
  std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override {
    if (!ctx->lOrExp()) {
      ctx->lAndExp()->accept(this);
      if (auto* info = sema_.GetExprType(ctx->lAndExp())) {
        // Even a lone operand must be int, since it feeds a condition.
        CoerceLogicalOperandToInt(ctx->lAndExp(), info,
                                  "逻辑或操作数必须是 int 类型或可以转换为 int 的 float 类型");
        sema_.SetExprType(ctx, *info);
      }
      return {};
    }
    ctx->lOrExp()->accept(this);
    ctx->lAndExp()->accept(this);
    auto* left_info = sema_.GetExprType(ctx->lOrExp());
    auto* right_info = sema_.GetExprType(ctx->lAndExp());
    if (!left_info || !right_info) {
      throw std::runtime_error(FormatError("sema", "逻辑或操作数类型推导失败"));
    }
    CoerceLogicalOperandToInt(ctx->lOrExp(), left_info,
                              "逻辑或左操作数必须是 int 类型或可以转换为 int 的 float 类型");
    CoerceLogicalOperandToInt(ctx->lAndExp(), right_info,
                              "逻辑或右操作数必须是 int 类型或可以转换为 int 的 float 类型");
    ExprInfo result;
    result.type = ir::Type::GetInt32Type();
    result.is_lvalue = false;
    if (left_info->is_const && right_info->is_const && left_info->is_const_int &&
        right_info->is_const_int) {
      result.is_const = true;
      result.is_const_int = true;
      result.const_int_value =
          (left_info->const_int_value || right_info->const_int_value) ? 1 : 0;
    }
    sema_.SetExprType(ctx, result);
    return {};
  }

  // Hands the accumulated semantic information to the caller (moved out).
  SemanticContext TakeSemanticContext() { return std::move(sema_); }

 private:
  SymbolTable table_;
  SemanticContext sema_;
  struct LoopContext {
    bool in_loop;
    antlr4::ParserRuleContext* loop_node;
  };
  std::vector<LoopContext> loop_stack_;
  IrTypePtr current_func_return_type_ = nullptr;
  bool current_func_has_return_ = false;

  // ==================== Helpers ====================

  // Maps a function's declared return type; a missing node defaults to int.
  static IrTypePtr ResolveFuncReturnType(SysYParser::FuncTypeContext* func_type) {
    if (func_type) {
      if (func_type->Void()) return ir::Type::GetVoidType();
      if (func_type->Int()) return ir::Type::GetInt32Type();
      if (func_type->Float()) return ir::Type::GetFloatType();
    }
    return ir::Type::GetInt32Type();
  }

  // True when the expression's type is int or float.
  static bool IsArithmeticInfo(const ExprInfo& info) {
    return info.type->IsInt32() || info.type->IsFloat();
  }

  // Evaluates array dimension expressions; each must be a positive constant.
  template <typename DimExps>
  std::vector<int> EvaluateArrayDimensions(const DimExps& dim_exps) {
    std::vector<int> dims;
    dims.reserve(dim_exps.size());
    for (auto* dim_exp : dim_exps) {
      const int dim = EvaluateConstExp(dim_exp);
      if (dim <= 0) {
        throw std::runtime_error(FormatError("sema", "数组维度必须为正整数"));
      }
      dims.push_back(dim);
    }
    return dims;
  }

  // Visits a pass-through child and copies its ExprInfo (if any) onto `ctx`.
  template <typename Node, typename Child>
  std::any ForwardSingleChildInfo(Node* ctx, Child* child) {
    child->accept(this);
    if (auto* info = sema_.GetExprType(child)) {
      sema_.SetExprType(ctx, *info);
    }
    return {};
  }

  // Makes an operand of a logical operator int.
  // - An int operand is left alone.
  // - A float operand gets an implicit float->int conversion recorded on
  //   `node`; a float constant is folded (truncated) into const_int_value
  //   BEFORE being flagged as an int constant.
  // - Any other type throws `message`.
  template <typename Node>
  void CoerceLogicalOperandToInt(Node* node, ExprInfo* info, const char* message) {
    if (info->type->IsInt32()) return;
    if (!info->type->IsFloat()) {
      throw std::runtime_error(FormatError("sema", message));
    }
    sema_.AddConversion(node, info->type, ir::Type::GetInt32Type());
    info->type = ir::Type::GetInt32Type();
    if (info->is_const && !info->is_const_int) {
      info->const_int_value = static_cast<int>(info->const_float_value);
      info->is_const_int = true;
    }
  }

  // Folds a comparison of two constants to 0/1. Two int constants are compared
  // exactly (going through float loses precision above 2^24); any other
  // combination is compared as floats.
  template <typename Compare>
  int FoldConstComparison(ExprInfo& lhs, ExprInfo& rhs, Compare compare) {
    if (lhs.is_const_int && rhs.is_const_int && lhs.type->IsInt32() && rhs.type->IsInt32()) {
      return compare(lhs.const_int_value, rhs.const_int_value) ? 1 : 0;
    }
    return compare(GetFloatValue(lhs), GetFloatValue(rhs)) ? 1 : 0;
  }

  // Types an expression, records its ExprInfo on `ctx`, and returns it.
  ExprInfo CheckExp(SysYParser::ExpContext* ctx) {
    if (!ctx || !ctx->addExp()) {
      throw std::runtime_error(FormatError("sema", "无效表达式"));
    }
    ctx->addExp()->accept(this);
    auto* info = sema_.GetExprType(ctx->addExp());
    if (!info) {
      throw std::runtime_error(FormatError("sema", "表达式类型推导失败"));
    }
    ExprInfo result = *info;
    sema_.SetExprType(ctx, result);
    return result;
  }

  // Types an AddExp directly (used for constant expressions).
  ExprInfo CheckAddExp(SysYParser::AddExpContext* ctx) {
    if (!ctx) {
      throw std::runtime_error(FormatError("sema", "无效表达式"));
    }
    ctx->accept(this);
    auto* info = sema_.GetExprType(ctx);
    if (!info) {
      throw std::runtime_error(FormatError("sema", "表达式类型推导失败"));
    }
    return *info;
  }

  // Types a condition; lOrExp has already coerced float operands to int, so
  // this only re-verifies that the result is int.
  ExprInfo CheckCond(SysYParser::CondContext* ctx) {
    if (!ctx || !ctx->lOrExp()) {
      throw std::runtime_error(FormatError("sema", "无效条件表达式"));
    }
    ctx->lOrExp()->accept(this);
    auto* info = sema_.GetExprType(ctx->lOrExp());
    if (!info) {
      throw std::runtime_error(FormatError("sema", "条件表达式类型推导失败"));
    }
    ExprInfo result = *info;
    if (!result.type->IsInt32()) {
      throw std::runtime_error(FormatError("sema", "条件表达式必须是 int 类型"));
    }
    return result;
  }

  // Resolves an lvalue reference and validates any subscripts against the
  // array (or array-parameter pointer) rank.
  ExprInfo CheckLValue(SysYParser::LValContext* ctx) {
    if (!ctx || !ctx->Ident()) {
      throw std::runtime_error(FormatError("sema", "非法左值"));
    }
    std::string name = ctx->Ident()->getText();
    auto* sym = table_.lookup(name);
    if (!sym) {
      throw std::runtime_error(FormatError("sema", "未定义的变量: " + name));
    }
    bool is_array_access = !ctx->exp().empty();
    bool is_const = (sym->kind == SymbolKind::Constant);
    bool is_array_or_ptr = false;
    if (sym->type) {
      is_array_or_ptr = sym->type->IsArray() || sym->type->IsPtrInt32() || sym->type->IsPtrFloat();
    }
    size_t dim_count = 0;
    IrTypePtr elem_type = sym->type;
    if (sym->type && sym->type->IsArray()) {
      if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
        dim_count = arr_type->GetDimensions().size();
        elem_type = arr_type->GetElementType();
      }
    } else if (sym->type && (sym->type->IsPtrInt32() || sym->type->IsPtrFloat())) {
      // Array parameter: a pointer indexed with a single subscript.
      dim_count = 1;
      if (sym->type->IsPtrInt32()) {
        elem_type = ir::Type::GetInt32Type();
      } else if (sym->type->IsPtrFloat()) {
        elem_type = ir::Type::GetFloatType();
      }
    }
    size_t subscript_count = ctx->exp().size();
    if (is_array_or_ptr) {
      if (subscript_count > 0) {
        // Subscripted access: one int index per dimension is required.
        if (subscript_count != dim_count) {
          throw std::runtime_error(FormatError("sema", "数组下标个数不匹配"));
        }
        for (auto* idx_exp : ctx->exp()) {
          ExprInfo idx = CheckExp(idx_exp);
          if (!idx.type->IsInt32()) {
            throw std::runtime_error(FormatError("sema", "数组下标必须是 int 类型"));
          }
        }
        return {elem_type, true, false};
} else { // 没有下标访问 if (sym->type->IsArray()) { // 数组名作为地址(右值) if (auto* arr_type = dynamic_cast(sym->type.get())) { if (arr_type->GetElementType()->IsInt32()) { return {ir::Type::GetPtrInt32Type(), false, true}; } else if (arr_type->GetElementType()->IsFloat()) { return {ir::Type::GetPtrFloatType(), false, true}; } } return {ir::Type::GetPtrInt32Type(), false, true}; } else { // 指针类型(如函数参数)可以不带下标使用 return {sym->type, true, is_const}; } } } else { if (subscript_count > 0) { throw std::runtime_error(FormatError("sema", "非数组变量不能使用下标: " + name)); } return {sym->type, true, is_const}; } } ExprInfo CheckFuncCall(SysYParser::UnaryExpContext* ctx) { if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("sema", "非法函数调用")); } std::string func_name = ctx->Ident()->getText(); std::cout << "[DEBUG] CheckFuncCall: " << func_name << std::endl; auto* func_sym = table_.lookup(func_name); if (!func_sym || func_sym->kind != SymbolKind::Function) { throw std::runtime_error(FormatError("sema", "未定义的函数: " + func_name)); } std::vector args; if (ctx->funcRParams()) { std::cout << "[DEBUG] 处理函数调用参数:" << std::endl; for (auto* exp : ctx->funcRParams()->exp()) { if (exp) { args.push_back(CheckExp(exp)); } } } if (args.size() != func_sym->param_types.size()) { throw std::runtime_error(FormatError("sema", "参数个数不匹配")); } for (size_t i = 0; i < std::min(args.size(), func_sym->param_types.size()); ++i) { std::cout << "[DEBUG] 检查参数 " << i << ": 实参类型 " << (int)args[i].type->GetKind() << " 形参类型 " << (int)func_sym->param_types[i]->GetKind() << std::endl; if (!IsTypeCompatible(args[i].type, func_sym->param_types[i])) { throw std::runtime_error(FormatError("sema", "参数类型不匹配")); } if (args[i].type != func_sym->param_types[i] && ctx->funcRParams() && i < ctx->funcRParams()->exp().size()) { sema_.AddConversion(ctx->funcRParams()->exp()[i], args[i].type, func_sym->param_types[i]); } } std::shared_ptr return_type; if (func_sym->type && func_sym->type->IsFunction()) { auto* func_type = 
dynamic_cast(func_sym->type.get()); if (func_type) { return_type = func_type->GetReturnType(); } } if (!return_type) { return_type = ir::Type::GetInt32Type(); } ExprInfo result; result.type = return_type; result.is_lvalue = false; result.is_const = false; return result; } ExprInfo CheckBinaryOp(const ExprInfo* left, const ExprInfo* right, const std::string& op, antlr4::ParserRuleContext* ctx) { ExprInfo result; if (!left->type->IsInt32() && !left->type->IsFloat()) { throw std::runtime_error(FormatError("sema", "左操作数必须是算术类型")); } if (!right->type->IsInt32() && !right->type->IsFloat()) { throw std::runtime_error(FormatError("sema", "右操作数必须是算术类型")); } if (op == "%" && (!left->type->IsInt32() || !right->type->IsInt32())) { throw std::runtime_error(FormatError("sema", "取模运算要求操作数为 int 类型")); } if (left->type->IsFloat() || right->type->IsFloat()) { result.type = ir::Type::GetFloatType(); } else { result.type = ir::Type::GetInt32Type(); } result.is_lvalue = false; if (left->is_const && right->is_const) { result.is_const = true; float l = GetFloatValue(*left); float r = GetFloatValue(*right); if (result.type->IsInt32()) { result.is_const_int = true; int li = (int)l, ri = (int)r; if (op == "*") result.const_int_value = li * ri; else if (op == "/") result.const_int_value = li / ri; else if (op == "%") result.const_int_value = li % ri; else if (op == "+") result.const_int_value = li + ri; else if (op == "-") result.const_int_value = li - ri; } else { if (op == "*") result.const_float_value = l * r; else if (op == "/") result.const_float_value = l / r; else if (op == "+") result.const_float_value = l + r; else if (op == "-") result.const_float_value = l - r; } } return result; } float GetFloatValue(const ExprInfo& info) { if (info.type->IsInt32() && info.is_const_int) { return (float)info.const_int_value; } else { return info.const_float_value; } } bool IsTypeCompatible(std::shared_ptr src, std::shared_ptr dst) { if (src == dst) return true; if (src->IsInt32() && 
dst->IsFloat()) return true; if (src->IsFloat() && dst->IsInt32()) return true; return false; } void CollectFunctionDeclaration(SysYParser::FuncDefContext* ctx) { if (!ctx || !ctx->Ident()) return; std::string name = ctx->Ident()->getText(); if (table_.lookup(name)) return; std::shared_ptr ret_type; if (ctx->funcType()) { if (ctx->funcType()->Void()) { ret_type = ir::Type::GetVoidType(); } else if (ctx->funcType()->Int()) { ret_type = ir::Type::GetInt32Type(); } else if (ctx->funcType()->Float()) { ret_type = ir::Type::GetFloatType(); } } if (!ret_type) ret_type = ir::Type::GetInt32Type(); std::vector> param_types; if (ctx->funcFParams()) { for (auto* param : ctx->funcFParams()->funcFParam()) { if (!param) continue; std::shared_ptr param_type; if (param->bType()) { if (param->bType()->Int()) { param_type = ir::Type::GetInt32Type(); } else if (param->bType()->Float()) { param_type = ir::Type::GetFloatType(); } } if (!param_type) param_type = ir::Type::GetInt32Type(); if (!param->L_BRACK().empty()) { if (param_type->IsInt32()) { param_type = ir::Type::GetPtrInt32Type(); } else if (param_type->IsFloat()) { param_type = ir::Type::GetPtrFloatType(); } } param_types.push_back(param_type); } } // 创建函数类型 std::shared_ptr func_type = ir::Type::GetFunctionType(ret_type, param_types); // 创建函数符号 Symbol sym; sym.name = name; sym.kind = SymbolKind::Function; sym.type = func_type; sym.param_types = param_types; sym.scope_level = 0; sym.is_initialized = true; sym.func_def_ctx = ctx; table_.addSymbol(sym); } void CollectFunctionParams(SysYParser::FuncFParamsContext* ctx) { if (!ctx) return; for (auto* param : ctx->funcFParam()) { if (!param || !param->Ident()) continue; std::string name = param->Ident()->getText(); if (table_.lookupCurrent(name)) { throw std::runtime_error(FormatError("sema", "重复定义参数: " + name)); } std::shared_ptr param_type; if (param->bType()) { if (param->bType()->Int()) { param_type = ir::Type::GetInt32Type(); } else if (param->bType()->Float()) { param_type = 
ir::Type::GetFloatType(); } } if (!param_type) param_type = ir::Type::GetInt32Type(); bool is_array = !param->L_BRACK().empty(); if (is_array) { if (param_type->IsInt32()) { param_type = ir::Type::GetPtrInt32Type(); } else if (param_type->IsFloat()) { param_type = ir::Type::GetPtrFloatType(); } std::cout << "[DEBUG] 数组参数: " << name << " 类型转换为指针" << std::endl; } Symbol sym; sym.name = name; sym.kind = SymbolKind::Parameter; sym.type = param_type; sym.scope_level = table_.currentScopeLevel(); sym.is_initialized = true; sym.var_def_ctx = nullptr; table_.addSymbol(sym); std::cout << "[DEBUG] 添加参数: " << name << " type_kind: " << (int)param_type->GetKind() << std::endl; } } void CheckGlobalInitIsConst(SysYParser::InitValContext* ctx) { if (!ctx) return; if (ctx->exp()) { ExprInfo info = CheckExp(ctx->exp()); if (!info.is_const) { throw std::runtime_error(FormatError("sema", "全局变量初始化必须是常量表达式")); } } else { for (auto* init : ctx->initVal()) { CheckGlobalInitIsConst(init); } } } int EvaluateConstExp(SysYParser::ConstExpContext* ctx) { if (!ctx || !ctx->addExp()) return 0; ExprInfo info = CheckAddExp(ctx->addExp()); if (info.is_const && info.is_const_int) { return info.const_int_value; } throw std::runtime_error(FormatError("sema", "常量表达式求值失败")); return 0; } struct ConstValue { bool is_int; int int_val; float float_val; }; std::vector EvaluateConstInitVal(SysYParser::ConstInitValContext* ctx, const std::vector& dims, std::shared_ptr base_type) { std::vector result; if (!ctx) return result; if (ctx->constExp()) { ExprInfo info = CheckAddExp(ctx->constExp()->addExp()); ConstValue val; if (info.type->IsInt32() && info.is_const_int) { val.is_int = true; val.int_val = info.const_int_value; if (base_type->IsFloat()) { val.is_int = false; val.float_val = (float)info.const_int_value; } } else if (info.type->IsFloat() && info.is_const) { val.is_int = false; val.float_val = info.const_float_value; if (base_type->IsInt32()) { val.is_int = true; val.int_val = 
(int)info.const_float_value; } } else { val.is_int = base_type->IsInt32(); val.int_val = 0; val.float_val = 0.0f; } result.push_back(val); } else { for (auto* init : ctx->constInitVal()) { std::vector sub_vals = EvaluateConstInitVal(init, dims, base_type); result.insert(result.end(), sub_vals.begin(), sub_vals.end()); } } return result; } void CheckMainFunction() { auto* main_sym = table_.lookup("main"); if (!main_sym || main_sym->kind != SymbolKind::Function) { throw std::runtime_error(FormatError("sema", "缺少 main 函数")); } std::shared_ptr ret_type; if (main_sym->type && main_sym->type->IsFunction()) { auto* func_type = dynamic_cast(main_sym->type.get()); if (func_type) { ret_type = func_type->GetReturnType(); } } if (!ret_type || !ret_type->IsInt32()) { throw std::runtime_error(FormatError("sema", "main 函数必须返回 int")); } if (!main_sym->param_types.empty()) { throw std::runtime_error(FormatError("sema", "main 函数不能有参数")); } } }; } // namespace SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { SemaVisitor visitor; comp_unit.accept(&visitor); return visitor.TakeSemanticContext(); }