From f3fe34801e7457409bf7fe50217061c75c727125 Mon Sep 17 00:00:00 2001 From: ftt <> Date: Wed, 25 Mar 2026 19:40:06 +0800 Subject: [PATCH] =?UTF-8?q?feat(sem)=E6=8F=90=E4=BA=A4=E8=AF=AD=E4=B9=89?= =?UTF-8?q?=E5=88=86=E6=9E=90B=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/irgen/IRGen.h | 58 -- include/sem/Sema.h | 88 ++- src/CMakeLists.txt | 2 - src/irgen/CMakeLists.txt | 14 +- src/irgen/IRGenDecl.cpp | 107 --- src/irgen/IRGenDriver.cpp | 15 - src/irgen/IRGenExp.cpp | 80 --- src/irgen/IRGenFunc.cpp | 87 --- src/irgen/IRGenStmt.cpp | 39 - src/main.cpp | 6 +- src/sem/Sema.cpp | 1427 ++++++++++++++++++++++++++++++++----- 11 files changed, 1325 insertions(+), 598 deletions(-) delete mode 100644 include/irgen/IRGen.h delete mode 100644 src/irgen/IRGenDecl.cpp delete mode 100644 src/irgen/IRGenDriver.cpp delete mode 100644 src/irgen/IRGenExp.cpp delete mode 100644 src/irgen/IRGenFunc.cpp delete mode 100644 src/irgen/IRGenStmt.cpp diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h deleted file mode 100644 index 231ba90..0000000 --- a/include/irgen/IRGen.h +++ /dev/null @@ -1,58 +0,0 @@ -// 将语法树翻译为 IR。 -// 实现拆分在 IRGenFunc/IRGenStmt/IRGenExp/IRGenDecl。 - -#pragma once - -#include -#include -#include -#include - -#include "SysYBaseVisitor.h" -#include "SysYParser.h" -#include "ir/IR.h" -#include "sem/Sema.h" - -namespace ir { -class Module; -class Function; -class IRBuilder; -class Value; -} - -class IRGenImpl final : public SysYBaseVisitor { - public: - IRGenImpl(ir::Module& module, const SemanticContext& sema); - - std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; - std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; - std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; - std::any visitDecl(SysYParser::DeclContext* ctx) override; - std::any visitStmt(SysYParser::StmtContext* ctx) override; - std::any visitVarDef(SysYParser::VarDefContext* ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override; - std::any visitVarExp(SysYParser::VarExpContext* ctx) override; - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override; - - private: - enum class BlockFlow { - Continue, - Terminated, - }; - - BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); - ir::Value* EvalExpr(SysYParser::ExpContext& expr); - - ir::Module& module_; - const SemanticContext& sema_; - ir::Function* func_; - ir::IRBuilder builder_; - // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 - std::unordered_map storage_map_; -}; - -std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, - const SemanticContext& sema); diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 9ac057b..c79c401 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -2,29 +2,83 @@ #pragma once #include +#include +#include #include "SysYParser.h" +#include "ir/IR.h" +// 表达式信息结构 +struct ExprInfo { + std::shared_ptr type = nullptr; + bool is_lvalue = false; + bool is_const = false; + bool is_const_int = false; // 是否是整型常量 + int const_int_value = 0; + float const_float_value = 0.0f; + antlr4::ParserRuleContext* node = nullptr; // 对应的语法树节点 +}; + +// 语义分析上下文:存储分析过程中产生的信息 class SemanticContext { - public: - void BindVarUse(SysYParser::VarContext* use, - SysYParser::VarDefContext* decl) { - var_uses_[use] = decl; - } - - SysYParser::VarDefContext* ResolveVarUse( - const SysYParser::VarContext* use) const { - auto it = var_uses_.find(use); - return it == var_uses_.end() ? nullptr : it->second; - } - - private: - std::unordered_map - var_uses_; +public: + // ----- 变量使用绑定(使用 LValContext 而不是 VarContext)----- + void BindVarUse(SysYParser::LValContext* use, + SysYParser::VarDefContext* decl) { + var_uses_[use] = decl; + } + + SysYParser::VarDefContext* ResolveVarUse( + const SysYParser::LValContext* use) const { + auto it = var_uses_.find(use); + return it == var_uses_.end() ? nullptr : it->second; + } + + // ----- 表达式类型信息存储 ----- + void SetExprType(antlr4::ParserRuleContext* node, const ExprInfo& info) { + ExprInfo copy = info; + copy.node = node; + expr_types_[node] = copy; + } + + ExprInfo* GetExprType(antlr4::ParserRuleContext* node) { + auto it = expr_types_.find(node); + return it == expr_types_.end() ? nullptr : &it->second; + } + + const ExprInfo* GetExprType(antlr4::ParserRuleContext* node) const { + auto it = expr_types_.find(node); + return it == expr_types_.end() ? nullptr : &it->second; + } + + // ----- 隐式转换标记(供 IR 生成使用)----- + struct ConversionInfo { + antlr4::ParserRuleContext* node; + std::shared_ptr from_type; + std::shared_ptr to_type; + }; + + void AddConversion(antlr4::ParserRuleContext* node, + std::shared_ptr from, + std::shared_ptr to) { + conversions_.push_back({node, from, to}); + } + + const std::vector& GetConversions() const { return conversions_; } + +private: + // 变量使用映射 - 使用 LValContext 作为键 + std::unordered_map var_uses_; + + // 表达式类型映射 + std::unordered_map expr_types_; + + // 隐式转换列表 + std::vector conversions_; }; // 目前仅检查: // - 变量先声明后使用 // - 局部变量不允许重复定义 -SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); +SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index acb9400..c4ec6d9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,7 +5,6 @@ add_subdirectory(ir) add_subdirectory(frontend) if(NOT COMPILER_PARSE_ONLY) add_subdirectory(sem) - add_subdirectory(irgen) add_subdirectory(mir) endif() @@ -20,7 +19,6 @@ target_link_libraries(compiler PRIVATE if(NOT COMPILER_PARSE_ONLY) target_link_libraries(compiler PRIVATE sem - irgen mir ) target_compile_definitions(compiler PRIVATE COMPILER_PARSE_ONLY=0) diff --git a/src/irgen/CMakeLists.txt b/src/irgen/CMakeLists.txt index d440bde..b28b04f 100644 --- a/src/irgen/CMakeLists.txt +++ b/src/irgen/CMakeLists.txt @@ -1,13 +1,3 @@ -add_library(irgen STATIC - IRGenDriver.cpp - IRGenFunc.cpp - IRGenStmt.cpp - IRGenExp.cpp - IRGenDecl.cpp -) -target_link_libraries(irgen PUBLIC - build_options - ${ANTLR4_RUNTIME_TARGET} - ir -) + + diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp deleted file mode 100644 index 0eb62ae..0000000 --- a/src/irgen/IRGenDecl.cpp +++ /dev/null @@ -1,107 +0,0 @@ -#include "irgen/IRGen.h" - -#include - -#include "SysYParser.h" -#include "ir/IR.h" -#include "utils/Log.h" - -namespace { - -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("irgen", "非法左值")); - } - return lvalue.ID()->getText(); -} - -} // namespace - -std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少语句块")); - } - for (auto* item : ctx->blockItem()) { - if (item) { - if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { - // 当前语法要求 return 为块内最后一条语句;命中后可停止生成。 - break; - } - } - } - return {}; -} - -IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( - SysYParser::BlockItemContext& item) { - return std::any_cast(item.accept(this)); -} - -std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少块内项")); - } - if (ctx->decl()) { - ctx->decl()->accept(this); - return BlockFlow::Continue; - } - if (ctx->stmt()) { - return ctx->stmt()->accept(this); - } - throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); -} - -// 变量声明的 IR 生成目前也是最小实现: -// - 先检查声明的基础类型,当前仅支持局部 int; -// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。 -// -// 和更完整的版本相比,这里还没有: -// - 一个 Decl 中多个变量定义的顺序处理; -// - const、数组、全局变量等不同声明形态; -// - 更丰富的类型系统。 -std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少变量声明")); - } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); - } - auto* var_def = ctx->varDef(); - if (!var_def) { - throw std::runtime_error(FormatError("irgen", "非法变量声明")); - } - var_def->accept(this); - return {}; -} - - -// 当前仍是教学用的最小版本,因此这里只支持: -// - 局部 int 变量; -// - 标量初始化; -// - 一个 VarDef 对应一个槽位。 -std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少变量定义")); - } - if (!ctx->lValue()) { - throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); - } - GetLValueName(*ctx->lValue()); - if (storage_map_.find(ctx) != storage_map_.end()) { - throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); - } - auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); - storage_map_[ctx] = slot; - - ir::Value* init = nullptr; - if (auto* init_value = ctx->initValue()) { - if (!init_value->exp()) { - throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); - } - init = EvalExpr(*init_value->exp()); - } else { - init = builder_.CreateConstInt(0); - } - builder_.CreateStore(init, slot); - return {}; -} diff --git a/src/irgen/IRGenDriver.cpp b/src/irgen/IRGenDriver.cpp deleted file mode 100644 index ff94412..0000000 --- a/src/irgen/IRGenDriver.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include "irgen/IRGen.h" - -#include - -#include "SysYParser.h" -#include "ir/IR.h" -#include "utils/Log.h" - -std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, - const SemanticContext& sema) { - auto module = std::make_unique(); - IRGenImpl gen(*module, sema); - tree.accept(&gen); - return module; -} diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp deleted file mode 100644 index cf4797c..0000000 --- a/src/irgen/IRGenExp.cpp +++ /dev/null @@ -1,80 +0,0 @@ -#include "irgen/IRGen.h" - -#include - -#include "SysYParser.h" -#include "ir/IR.h" -#include "utils/Log.h" - -// 表达式生成当前也只实现了很小的一个子集。 -// 目前支持: -// - 整数字面量 -// - 普通局部变量读取 -// - 括号表达式 -// - 二元加法 -// -// 还未支持: -// - 减乘除与一元运算 -// - 赋值表达式 -// - 函数调用 -// - 数组、指针、下标访问 -// - 条件与比较表达式 -// - ... -ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { - return std::any_cast(expr.accept(this)); -} - - -std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "非法括号表达式")); - } - return EvalExpr(*ctx->exp()); -} - - -std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); - } - return static_cast( - builder_.CreateConstInt(std::stoi(ctx->number()->getText()))); -} - -// 变量使用的处理流程: -// 1. 先通过语义分析结果把变量使用绑定回声明; -// 2. 再通过 storage_map_ 找到该声明对应的栈槽位; -// 3. 最后生成 load,把内存中的值读出来。 -// -// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 -std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) { - if (!ctx || !ctx->var() || !ctx->var()->ID()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); - } - auto* decl = sema_.ResolveVarUse(ctx->var()); - if (!decl) { - throw std::runtime_error( - FormatError("irgen", - "变量使用缺少语义绑定: " + ctx->var()->ID()->getText())); - } - auto it = storage_map_.find(decl); - if (it == storage_map_.end()) { - throw std::runtime_error( - FormatError("irgen", - "变量声明缺少存储槽位: " + ctx->var()->ID()->getText())); - } - return static_cast( - builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); -} - - -std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("irgen", "非法加法表达式")); - } - ir::Value* lhs = EvalExpr(*ctx->exp(0)); - ir::Value* rhs = EvalExpr(*ctx->exp(1)); - return static_cast( - builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, - module_.GetContext().NextTemp())); -} diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp deleted file mode 100644 index 4912d03..0000000 --- a/src/irgen/IRGenFunc.cpp +++ /dev/null @@ -1,87 +0,0 @@ -#include "irgen/IRGen.h" - -#include - -#include "SysYParser.h" -#include "ir/IR.h" -#include "utils/Log.h" - -namespace { - -void VerifyFunctionStructure(const ir::Function& func) { - // 当前 IRGen 仍是单入口、顺序生成;这里在生成结束后补一层块终结校验。 - for (const auto& bb : func.GetBlocks()) { - if (!bb || !bb->HasTerminator()) { - throw std::runtime_error( - FormatError("irgen", "基本块未正确终结: " + - (bb ? bb->GetName() : std::string("")))); - } - } -} - -} // namespace - -IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) - : module_(module), - sema_(sema), - func_(nullptr), - builder_(module.GetContext(), nullptr) {} - -// 编译单元的 IR 生成当前只实现了最小功能: -// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容; -// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR; -// -// 当前还没有实现: -// - 多个函数定义的遍历与生成; -// - 全局变量、全局常量的 IR 生成。 -std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少编译单元")); - } - auto* func = ctx->funcDef(); - if (!func) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); - } - func->accept(this); - return {}; -} - -// 函数 IR 生成当前实现了: -// 1. 获取函数名; -// 2. 检查函数返回类型; -// 3. 在 Module 中创建 Function; -// 4. 将 builder 插入点设置到入口基本块; -// 5. 继续生成函数体。 -// -// 当前还没有实现: -// - 通用函数返回类型处理; -// - 形参列表遍历与参数类型收集; -// - FunctionType 这样的函数类型对象; -// - Argument/形式参数 IR 对象; -// - 入口块中的参数初始化逻辑。 -// ... - -// 因此这里目前只支持最小的“无参 int 函数”生成。 -std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); - } - if (!ctx->blockStmt()) { - throw std::runtime_error(FormatError("irgen", "函数体为空")); - } - if (!ctx->ID()) { - throw std::runtime_error(FormatError("irgen", "缺少函数名")); - } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数")); - } - - func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); - builder_.SetInsertPoint(func_->GetEntry()); - storage_map_.clear(); - - ctx->blockStmt()->accept(this); - // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 - VerifyFunctionStructure(*func_); - return {}; -} diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp deleted file mode 100644 index 751550c..0000000 --- a/src/irgen/IRGenStmt.cpp +++ /dev/null @@ -1,39 +0,0 @@ -#include "irgen/IRGen.h" - -#include - -#include "SysYParser.h" -#include "ir/IR.h" -#include "utils/Log.h" - -// 语句生成当前只实现了最小子集。 -// 目前支持: -// - return ; -// -// 还未支持: -// - 赋值语句 -// - if / while 等控制流 -// - 空语句、块语句嵌套分发之外的更多语句形态 - -std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少语句")); - } - if (ctx->returnStmt()) { - return ctx->returnStmt()->accept(this); - } - throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); -} - - -std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); - } - if (!ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); - } - ir::Value* v = EvalExpr(*ctx->exp()); - builder_.CreateRet(v); - return BlockFlow::Terminated; -} diff --git a/src/main.cpp b/src/main.cpp index 88ed747..fc40f04 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -6,7 +6,7 @@ #include "frontend/SyntaxTreePrinter.h" #if !COMPILER_PARSE_ONLY #include "ir/IR.h" -#include "irgen/IRGen.h" +//#include "irgen/IRGen.h" #include "mir/MIR.h" #include "sem/Sema.h" #endif @@ -35,7 +35,7 @@ int main(int argc, char** argv) { } auto sema = RunSema(*comp_unit); - auto module = GenerateIR(*comp_unit, sema); + /*auto module = GenerateIR(*comp_unit, sema); if (opts.emit_ir) { ir::IRPrinter printer; if (need_blank_line) { @@ -53,7 +53,7 @@ int main(int argc, char** argv) { std::cout << "\n"; } mir::PrintAsm(*machine_func, std::cout); - } + }*/ #else if (opts.emit_ir || opts.emit_asm) { throw std::runtime_error( diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 745374c..68f9b0d 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "SysYBaseVisitor.h" #include "sem/SymbolTable.h" @@ -10,191 +11,1261 @@ namespace { -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("sema", "非法左值")); - } - return lvalue.ID()->getText(); +// 获取左值名称的辅助函数 +std::string GetLValueName(SysYParser::LValContext& lval) { + if (!lval.Ident()) { + throw std::runtime_error(FormatError("sema", "非法左值")); + } + return lval.Ident()->getText(); +} + +// 从 BTypeContext 获取类型 +std::shared_ptr GetTypeFromBType(SysYParser::BTypeContext* ctx) { + if (!ctx) return ir::Type::GetInt32Type(); + if (ctx->Int()) return ir::Type::GetInt32Type(); + if (ctx->Float()) return ir::Type::GetFloatType(); + return ir::Type::GetInt32Type(); } +// 语义分析 Visitor class SemaVisitor final : public SysYBaseVisitor { - public: - std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少编译单元")); - } - auto* func = ctx->funcDef(); - if (!func || !func->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - if (!func->ID() || func->ID()->getText() != "main") { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - func->accept(this); - if (!seen_return_) { - throw std::runtime_error( - FormatError("sema", "main 函数必须包含 return 语句")); - } - return {}; - } - - std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { - if (!ctx || !ctx->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); - } - const auto& items = ctx->blockStmt()->blockItem(); - if (items.empty()) { - throw std::runtime_error( - FormatError("sema", "main 函数不能为空,且必须以 return 结束")); - } - ctx->blockStmt()->accept(this); - return {}; - } - - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少语句块")); - } - const auto& items = ctx->blockItem(); - for (size_t i = 0; i < items.size(); ++i) { - auto* item = items[i]; - if (!item) { - continue; - } - if (seen_return_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } - current_item_index_ = i; - total_items_ = items.size(); - item->accept(this); - } - return {}; - } - - std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - if (ctx->decl()) { - ctx->decl()->accept(this); - return {}; - } - if (ctx->stmt()) { - ctx->stmt()->accept(this); - return {}; - } - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - - std::any visitDecl(SysYParser::DeclContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); - } - auto* var_def = ctx->varDef(); - if (!var_def || !var_def->lValue()) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - const std::string name = GetLValueName(*var_def->lValue()); - if (table_.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); - } - if (auto* init = var_def->initValue()) { - if (!init->exp()) { - throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化")); - } - init->exp()->accept(this); - } - table_.Add(name, var_def); - return {}; - } - - std::any visitStmt(SysYParser::StmtContext* ctx) override { - if (!ctx || !ctx->returnStmt()) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - ctx->returnStmt()->accept(this); - return {}; - } - - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "return 缺少表达式")); - } - ctx->exp()->accept(this); - seen_return_ = true; - if (current_item_index_ + 1 != total_items_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } - return {}; - } - - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "非法括号表达式")); - } - ctx->exp()->accept(this); - return {}; - } - - std::any visitVarExp(SysYParser::VarExpContext* ctx) override { - if (!ctx || !ctx->var()) { - throw std::runtime_error(FormatError("sema", "非法变量表达式")); - } - ctx->var()->accept(this); - return {}; - } - - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); - } - return {}; - } - - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); - } - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; - } - - std::any visitVar(SysYParser::VarContext* ctx) override { - if (!ctx || !ctx->ID()) { - throw std::runtime_error(FormatError("sema", "非法变量引用")); +public: + SemaVisitor() : table_() {} + + std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "缺少编译单元")); + } + table_.enterScope(); // 创建全局作用域 + for (auto* func : ctx->funcDef()) { // 收集所有函数声明(处理互相调用) + CollectFunctionDeclaration(func); + } + for (auto* decl : ctx->decl()) { // 处理所有声明和定义 + if (decl) decl->accept(this); + } + for (auto* func : ctx->funcDef()) { + if (func) func->accept(this); + } + CheckMainFunction(); // 检查 main 函数存在且正确 + table_.exitScope(); // 退出全局作用域 + return {}; + } + + std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "函数定义缺少标识符")); + } + std::string name = ctx->Ident()->getText(); + std::shared_ptr return_type; // 获取返回类型 + if (ctx->funcType()) { + if (ctx->funcType()->Void()) { + return_type = ir::Type::GetVoidType(); + } else if (ctx->funcType()->Int()) { + return_type = ir::Type::GetInt32Type(); + } else if (ctx->funcType()->Float()) { + return_type = ir::Type::GetFloatType(); + } else { + return_type = ir::Type::GetInt32Type(); + } + } else { + return_type = ir::Type::GetInt32Type(); + } + std::cout << "[DEBUG] 进入函数: " << name + << " 返回类型: " << (return_type->IsInt32() ? "int" : + return_type->IsFloat() ? "float" : "void") + << std::endl; + + // 记录当前函数返回类型(用于 return 检查) + current_func_return_type_ = return_type; + current_func_has_return_ = false; + + table_.enterScope(); + if (ctx->funcFParams()) { // 处理参数 + CollectFunctionParams(ctx->funcFParams()); + } + if (ctx->block()) { // 处理函数体 + ctx->block()->accept(this); + } + std::cout << "[DEBUG] 函数 " << name + << " has_return: " << current_func_has_return_ + << " return_type_is_void: " << return_type->IsVoid() + << std::endl; + if (!return_type->IsVoid() && !current_func_has_return_) { // 检查非 void 函数是否有 return + throw std::runtime_error(FormatError("sema", "非 void 函数 " + name + " 缺少 return 语句")); + } + table_.exitScope(); + + current_func_return_type_ = nullptr; + current_func_has_return_ = false; + return {}; + } + + std::any visitBlock(SysYParser::BlockContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "缺少语句块")); + } + table_.enterScope(); + for (auto* item : ctx->blockItem()) { // 处理所有 blockItem + if (item) { + item->accept(this); + // 如果已经有 return,可以继续(但 return 必须是最后一条) + // 注意:这里不需要跳出,因为 return 语句本身已经标记了 + } + } + table_.exitScope(); + return {}; + } + + std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + } + if (ctx->decl()) { + ctx->decl()->accept(this); + return {}; + } + if (ctx->stmt()) { + ctx->stmt()->accept(this); + return {}; + } + throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + } + + std::any visitDecl(SysYParser::DeclContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "非法变量声明")); + } + if (ctx->constDecl()) { + ctx->constDecl()->accept(this); + } else if (ctx->varDecl()) { + ctx->varDecl()->accept(this); + } + return {}; + } + + // ==================== 变量声明 ==================== + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override { + if (!ctx || !ctx->bType()) { + throw std::runtime_error(FormatError("sema", "非法变量声明")); + } + std::shared_ptr base_type = GetTypeFromBType(ctx->bType()); + bool is_global = (table_.currentScopeLevel() == 0); + for (auto* var_def : ctx->varDef()) { + if (var_def) { + CheckVarDef(var_def, base_type, is_global); + } + } + return {}; + } + + void CheckVarDef(SysYParser::VarDefContext* ctx, + std::shared_ptr base_type, + bool is_global) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法变量定义")); + } + std::string name = ctx->Ident()->getText(); + if (table_.lookupCurrent(name)) { // 检查重复定义 + throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + } + // 确定类型(处理数组维度) + std::shared_ptr type = base_type; + std::vector dims; + bool is_array = !ctx->constExp().empty(); + // 调试输出 + std::cout << "[DEBUG] CheckVarDef: " << name + << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown") + << " is_array: " << is_array + << " dim_count: " << ctx->constExp().size() << std::endl; + if (is_array) { + // 处理数组维度 + for (auto* dim_exp : ctx->constExp()) { + int dim = EvaluateConstExp(dim_exp); + if (dim <= 0) { + throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); + } + dims.push_back(dim); + std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl; + } + // 创建数组类型 + type = ir::Type::GetArrayType(base_type, dims); + std::cout << "[DEBUG] 创建数组类型完成" << std::endl; + std::cout << "[DEBUG] type->IsArray(): " << type->IsArray() << std::endl; + std::cout << "[DEBUG] type->GetKind(): " << (int)type->GetKind() << std::endl; + // 验证数组类型 + if (type->IsArray()) { + auto* arr_type = dynamic_cast(type.get()); + if (arr_type) { + std::cout << "[DEBUG] ArrayType dimensions: "; + for (int d : arr_type->GetDimensions()) { + std::cout << d << " "; + } + std::cout << std::endl; + std::cout << "[DEBUG] Element type: " + << (arr_type->GetElementType()->IsInt32() ? "int" : + arr_type->GetElementType()->IsFloat() ? "float" : "unknown") + << std::endl; + } + } + } + bool has_init = (ctx->initVal() != nullptr); // 处理初始化 + if (is_global && has_init) { + CheckGlobalInitIsConst(ctx->initVal()); // 全局变量初始化必须是常量表达式 + } + // 创建符号 + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Variable; + sym.type = type; + sym.scope_level = table_.currentScopeLevel(); + sym.is_initialized = has_init; + sym.var_def_ctx = ctx; + if (is_array) { + // 存储维度信息,但 param_types 通常用于函数参数 + // 数组变量的维度信息已经包含在 type 中 + sym.param_types.clear(); // 确保不混淆 + } + table_.addSymbol(sym); // 添加到符号表 + std::cout << "[DEBUG] 符号添加完成: " << name + << " type_kind: " << (int)sym.type->GetKind() + << " is_array: " << sym.type->IsArray() + << std::endl; + } + + // ==================== 常量声明 ==================== + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override { + if (!ctx || !ctx->bType()) { + throw std::runtime_error(FormatError("sema", "非法常量声明")); + } + std::shared_ptr base_type = GetTypeFromBType(ctx->bType()); + for (auto* const_def : ctx->constDef()) { + if (const_def) { + CheckConstDef(const_def, base_type); + } + } + return {}; + } + + void CheckConstDef(SysYParser::ConstDefContext* ctx, + std::shared_ptr base_type) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法常量定义")); + } + std::string name = ctx->Ident()->getText(); + if (table_.lookupCurrent(name)) { + throw std::runtime_error(FormatError("sema", "重复定义常量: " + name)); + } + // 确定类型 + std::shared_ptr type = base_type; + std::vector dims; + bool is_array = !ctx->constExp().empty(); + std::cout << "[DEBUG] CheckConstDef: " << name + << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown") + << " is_array: " << is_array + << " dim_count: " << ctx->constExp().size() << std::endl; + if (is_array) { + for (auto* dim_exp : ctx->constExp()) { + int dim = EvaluateConstExp(dim_exp); + if (dim <= 0) { + throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); + } + dims.push_back(dim); + std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl; + } + type = ir::Type::GetArrayType(base_type, dims); + std::cout << "[DEBUG] 创建数组类型完成,IsArray: " << type->IsArray() << std::endl; + } + // 求值初始化器 + std::vector init_values; + if (ctx->constInitVal()) { + init_values = EvaluateConstInitVal(ctx->constInitVal(), dims, base_type); + std::cout << "[DEBUG] 初始化值数量: " << init_values.size() << std::endl; + } + // 检查初始化值数量 + size_t expected_count = 1; + if (is_array) { + expected_count = 1; + for (int d : dims) expected_count *= d; + std::cout << "[DEBUG] 期望元素数量: " << expected_count << std::endl; + } + if (init_values.size() > expected_count) { + throw std::runtime_error(FormatError("sema", "初始化值过多")); + } + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Constant; + sym.type = type; + sym.scope_level = table_.currentScopeLevel(); + sym.is_initialized = true; + sym.var_def_ctx = nullptr; + // 存储常量值(仅对非数组有效) + if (!is_array && !init_values.empty()) { + if (base_type->IsInt32() && init_values[0].is_int) { + sym.is_int_const = true; + sym.const_value.i32 = init_values[0].int_val; + std::cout << "[DEBUG] 存储整型常量值: " << init_values[0].int_val << std::endl; + } else if (base_type->IsFloat() && !init_values[0].is_int) { + sym.is_int_const = false; + sym.const_value.f32 = init_values[0].float_val; + std::cout << "[DEBUG] 存储浮点常量值: " << init_values[0].float_val << std::endl; + } + } else if (is_array) { + std::cout << "[DEBUG] 数组常量,不存储单个常量值" << std::endl; + } + table_.addSymbol(sym); + std::cout << "[DEBUG] 常量符号添加完成" << std::endl; + } + + // ==================== 语句语义检查 ==================== + + // 处理所有语句 - 通过运行时类型判断 + std::any visitStmt(SysYParser::StmtContext* ctx) override { + if (!ctx) return {}; + // 调试输出 + std::cout << "[DEBUG] visitStmt: "; + if (ctx->Return()) std::cout << "Return "; + if (ctx->If()) std::cout << "If "; + if (ctx->While()) std::cout << "While "; + if (ctx->Break()) std::cout << "Break "; + if (ctx->Continue()) std::cout << "Continue "; + if (ctx->lVal() && ctx->Assign()) std::cout << "Assign "; + if (ctx->exp() && ctx->Semi()) std::cout << "ExpStmt "; + if (ctx->block()) std::cout << "Block "; + std::cout << std::endl; + // 判断语句类型 - 注意:Return() 返回的是 TerminalNode* + if (ctx->Return() != nullptr) { + // return 语句 + std::cout << "[DEBUG] 检测到 return 语句" << std::endl; + return visitReturnStmtInternal(ctx); + } else if (ctx->lVal() != nullptr && ctx->Assign() != nullptr) { + // 赋值语句 + return visitAssignStmt(ctx); + } else if (ctx->exp() != nullptr && ctx->Semi() != nullptr) { + // 表达式语句(可能有表达式) + return visitExpStmt(ctx); + } else if (ctx->block() != nullptr) { + // 块语句 + return ctx->block()->accept(this); + } else if (ctx->If() != nullptr) { + // if 语句 + return visitIfStmtInternal(ctx); + } else if (ctx->While() != nullptr) { + // while 语句 + return visitWhileStmtInternal(ctx); + } else if (ctx->Break() != nullptr) { + // break 语句 + return visitBreakStmtInternal(ctx); + } else if (ctx->Continue() != nullptr) { + // continue 语句 + return visitContinueStmtInternal(ctx); + } + return {}; + } + + // return 语句内部实现 + std::any visitReturnStmtInternal(SysYParser::StmtContext* ctx) { + std::cout << "[DEBUG] visitReturnStmtInternal 被调用" << std::endl; + std::shared_ptr expected = current_func_return_type_; + if (!expected) { + throw std::runtime_error(FormatError("sema", "return 语句不在函数体内")); + } + if (ctx->exp() != nullptr) { + // 有返回值的 return + std::cout << "[DEBUG] 有返回值的 return" << std::endl; + ExprInfo ret_val = CheckExp(ctx->exp()); + if (expected->IsVoid()) { + throw std::runtime_error(FormatError("sema", "void 函数不能返回值")); + } else if (!IsTypeCompatible(ret_val.type, expected)) { + throw std::runtime_error(FormatError("sema", "返回值类型不匹配")); + } + // 标记需要隐式转换 + if (ret_val.type != expected) { + sema_.AddConversion(ctx->exp(), ret_val.type, expected); + } + // 设置 has_return 标志 + current_func_has_return_ = true; + std::cout << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl; + } else { + // 无返回值的 return + std::cout << "[DEBUG] 无返回值的 return" << std::endl; + if (!expected->IsVoid()) { + throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值")); + } + // 设置 has_return 标志 + current_func_has_return_ = true; + std::cout << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl; + } + return {}; + } + + // 左值表达式(变量引用) + std::any visitLVal(SysYParser::LValContext* ctx) override { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法变量引用")); + } + std::string name = ctx->Ident()->getText(); + auto* sym = table_.lookup(name); + if (!sym) { + throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); + } + // 检查数组访问 + bool is_array_access = !ctx->exp().empty(); + ExprInfo result; + // 判断是否为数组类型或指针类型(数组参数) + bool is_array_or_ptr = false; + if (sym->type) { + is_array_or_ptr = sym->type->IsArray() || sym->type->IsPtrInt32() || sym->type->IsPtrFloat(); + } + // 调试输出 + std::cout << "[DEBUG] visitLVal: " << name + << " kind: " << (int)sym->kind + << " type_kind: " << (sym->type ? (int)sym->type->GetKind() : -1) + << " is_array_or_ptr: " << is_array_or_ptr + << " subscript_count: " << ctx->exp().size() + << std::endl; + if (is_array_or_ptr) { + if (!is_array_access) { + throw std::runtime_error(FormatError("sema", "数组变量必须使用下标访问: " + name)); + } + // 获取维度信息 + size_t dim_count = 0; + std::shared_ptr elem_type = sym->type; + if (sym->type->IsArray()) { + if (auto* arr_type = dynamic_cast(sym->type.get())) { + dim_count = arr_type->GetDimensions().size(); + // 获取元素类型 + elem_type = arr_type->GetElementType(); + } + } else if (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) { + // 指针类型,只接受一个下标 + dim_count = 1; + // 指针解引用后的类型 + if (sym->type->IsPtrInt32()) { + elem_type = ir::Type::GetInt32Type(); + } else if (sym->type->IsPtrFloat()) { + elem_type = ir::Type::GetFloatType(); + } + } + if (ctx->exp().size() != dim_count) { + throw std::runtime_error(FormatError("sema", "数组下标个数不匹配")); + } + for (auto* idx_exp : ctx->exp()) { + ExprInfo idx = CheckExp(idx_exp); + if (!idx.type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "数组下标必须是 int 类型")); + } + } + result.type = elem_type; + result.is_lvalue = true; + result.is_const = false; + } else { + if (is_array_access) { + throw std::runtime_error(FormatError("sema", "非数组变量不能使用下标: " + name)); + } + result.type = sym->type; + result.is_lvalue = true; + result.is_const = (sym->kind == SymbolKind::Constant); + if (result.is_const && sym->type && !sym->type->IsArray()) { + if (sym->is_int_const) { + result.is_const_int = true; + result.const_int_value = sym->const_value.i32; + } else { + result.const_float_value = sym->const_value.f32; + } + } + } + sema_.SetExprType(ctx, result); + return {}; + } + + // if 语句内部实现 + std::any visitIfStmtInternal(SysYParser::StmtContext* ctx) { + // 检查条件表达式 + if (ctx->cond()) { + ExprInfo cond = CheckCond(ctx->cond()); + if (!cond.type || !cond.type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "if 条件必须是 int 类型")); + } + } + // 处理 then 分支 + if (ctx->stmt().size() > 0) { + ctx->stmt()[0]->accept(this); + } + // 处理 else 分支 + if (ctx->stmt().size() > 1) { + ctx->stmt()[1]->accept(this); + } + return {}; + } + + // while 语句内部实现 + std::any visitWhileStmtInternal(SysYParser::StmtContext* ctx) { + if (ctx->cond()) { // 检查条件表达式 + ExprInfo cond = CheckCond(ctx->cond()); + if (!cond.type || !cond.type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "while 条件必须是 int 类型")); + } + } + loop_stack_.push_back({true, ctx}); // 进入循环上下文 + if (ctx->stmt().size() > 0) { // 处理循环体 + ctx->stmt()[0]->accept(this); + } + loop_stack_.pop_back(); // 退出循环上下文 + return {}; + } + + // break 语句内部实现 + std::any visitBreakStmtInternal(SysYParser::StmtContext* ctx) { + if (loop_stack_.empty() || !loop_stack_.back().in_loop) { + throw std::runtime_error(FormatError("sema", "break 语句必须在循环体内使用")); + } + return {}; } - const std::string name = ctx->ID()->getText(); - auto* decl = table_.Lookup(name); - if (!decl) { - throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); - } - sema_.BindVarUse(ctx, decl); - return {}; - } - - SemanticContext TakeSemanticContext() { return std::move(sema_); } + + // continue 语句内部实现 + std::any visitContinueStmtInternal(SysYParser::StmtContext* ctx) { + if (loop_stack_.empty() || !loop_stack_.back().in_loop) { + throw std::runtime_error(FormatError("sema", "continue 语句必须在循环体内使用")); + } + return {}; + } + + // 赋值语句内部实现 + std::any visitAssignStmt(SysYParser::StmtContext* ctx) { + if (!ctx->lVal() || !ctx->exp()) { + throw std::runtime_error(FormatError("sema", "非法赋值语句")); + } + ExprInfo lvalue = CheckLValue(ctx->lVal()); // 检查左值 + if (lvalue.is_const) { + throw std::runtime_error(FormatError("sema", "不能给常量赋值")); + } + if (!lvalue.is_lvalue) { + throw std::runtime_error(FormatError("sema", "赋值左边必须是左值")); + } + ExprInfo rvalue = CheckExp(ctx->exp()); // 检查右值 + if (!IsTypeCompatible(rvalue.type, lvalue.type)) { + throw std::runtime_error(FormatError("sema", "赋值类型不匹配")); + } + if (rvalue.type != lvalue.type) { // 标记需要隐式转换 + sema_.AddConversion(ctx->exp(), rvalue.type, lvalue.type); + } + return {}; + } + + // 表达式语句内部实现 + std::any visitExpStmt(SysYParser::StmtContext* ctx) { + if (ctx->exp()) { + CheckExp(ctx->exp()); + } + return {}; + } + + // ==================== 表达式类型推导 ==================== + + // 主表达式 + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override { + ExprInfo result; + if (ctx->lVal()) { // 左值表达式 + result = CheckLValue(ctx->lVal()); + result.is_lvalue = true; + } else if (ctx->HEX_FLOAT() || ctx->DEC_FLOAT()) { // 浮点字面量 + result.type = ir::Type::GetFloatType(); + result.is_const = true; + result.is_const_int = false; + std::string text; + if (ctx->HEX_FLOAT()) text = ctx->HEX_FLOAT()->getText(); + else text = ctx->DEC_FLOAT()->getText(); + result.const_float_value = std::stof(text); + } else if (ctx->HEX_INT() || ctx->OCTAL_INT() || ctx->DECIMAL_INT() || ctx->ZERO()) { // 整数字面量 + result.type = ir::Type::GetInt32Type(); + result.is_const = true; + result.is_const_int = true; + std::string text; + if (ctx->HEX_INT()) text = ctx->HEX_INT()->getText(); + else if (ctx->OCTAL_INT()) text = ctx->OCTAL_INT()->getText(); + else if (ctx->DECIMAL_INT()) text = ctx->DECIMAL_INT()->getText(); + else text = ctx->ZERO()->getText(); + result.const_int_value = std::stoi(text, nullptr, 0); + } else if (ctx->exp()) { // 括号表达式 + result = CheckExp(ctx->exp()); + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 一元表达式 + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override { + ExprInfo result; + if (ctx->primaryExp()) { + ctx->primaryExp()->accept(this); + auto* info = sema_.GetExprType(ctx->primaryExp()); + if (info) result = *info; + } else if (ctx->Ident() && ctx->L_PAREN()) { // 函数调用 + result = CheckFuncCall(ctx); + } else if (ctx->unaryOp()) { // 一元运算 + ctx->unaryExp()->accept(this); + auto* operand = sema_.GetExprType(ctx->unaryExp()); + if (!operand) { + throw std::runtime_error(FormatError("sema", "一元操作数类型推导失败")); + result.type = ir::Type::GetInt32Type(); + } else { + std::string op = ctx->unaryOp()->getText(); + if (op == "!") { + if (!operand->type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "逻辑非操作数必须是 int 类型")); + } + result.type = ir::Type::GetInt32Type(); + result.is_const = operand->is_const; + if (operand->is_const && operand->is_const_int) { + result.is_const_int = true; + result.const_int_value = (operand->const_int_value == 0) ? 1 : 0; + } + } else { + if (!operand->type->IsInt32() && !operand->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "正负号操作数必须是算术类型")); + } + result.type = operand->type; + result.is_const = operand->is_const; + if (op == "-" && operand->is_const) { + if (operand->type->IsInt32() && operand->is_const_int) { + result.is_const_int = true; + result.const_int_value = -operand->const_int_value; + } else if (operand->type->IsFloat()) { + result.const_float_value = -operand->const_float_value; + } + } + } + result.is_lvalue = false; + } + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 乘除模表达式 + std::any visitMulExp(SysYParser::MulExpContext* ctx) override { + ExprInfo result; + if (ctx->mulExp()) { + ctx->mulExp()->accept(this); + ctx->unaryExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->mulExp()); + auto* right_info = sema_.GetExprType(ctx->unaryExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "乘除模操作数类型推导失败")); + result.type = ir::Type::GetInt32Type(); + } else { + std::string op; + if (ctx->MulOp()) { + op = "*"; + } else if (ctx->DivOp()) { + op = "/"; + } else if (ctx->QuoOp()) { + op = "%"; + } + result = CheckBinaryOp(left_info, right_info, op, ctx); + } + } else { + ctx->unaryExp()->accept(this); + auto* info = sema_.GetExprType(ctx->unaryExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 加减表达式 + std::any visitAddExp(SysYParser::AddExpContext* ctx) override { + ExprInfo result; + if (ctx->addExp()) { + ctx->addExp()->accept(this); + ctx->mulExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->addExp()); + auto* right_info = sema_.GetExprType(ctx->mulExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "加减操作数类型推导失败")); + result.type = ir::Type::GetInt32Type(); + } else { + std::string op; + if (ctx->AddOp()) { + op = "+"; + } else if (ctx->SubOp()) { + op = "-"; + } + result = CheckBinaryOp(left_info, right_info, op, ctx); + } + } else { + ctx->mulExp()->accept(this); + auto* info = sema_.GetExprType(ctx->mulExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 关系表达式 + std::any visitRelExp(SysYParser::RelExpContext* ctx) override { + ExprInfo result; + if (ctx->relExp()) { + ctx->relExp()->accept(this); + ctx->addExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->relExp()); + auto* right_info = sema_.GetExprType(ctx->addExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "关系操作数类型推导失败")); + } else { + if (!left_info->type->IsInt32() && !left_info->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "关系运算操作数必须是算术类型")); + } + std::string op; + if (ctx->LOp()) { + op = "<"; + } else if (ctx->GOp()) { + op = ">"; + } else if (ctx->LeOp()) { + op = "<="; + } else if (ctx->GeOp()) { + op = ">="; + } + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const) { + result.is_const = true; + result.is_const_int = true; + float l = GetFloatValue(*left_info); + float r = GetFloatValue(*right_info); + if (op == "<") result.const_int_value = (l < r) ? 1 : 0; + else if (op == ">") result.const_int_value = (l > r) ? 1 : 0; + else if (op == "<=") result.const_int_value = (l <= r) ? 1 : 0; + else if (op == ">=") result.const_int_value = (l >= r) ? 1 : 0; + } + } + } else { + ctx->addExp()->accept(this); + auto* info = sema_.GetExprType(ctx->addExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 相等性表达式 + std::any visitEqExp(SysYParser::EqExpContext* ctx) override { + ExprInfo result; + if (ctx->eqExp()) { + ctx->eqExp()->accept(this); + ctx->relExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->eqExp()); + auto* right_info = sema_.GetExprType(ctx->relExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "相等性操作数类型推导失败")); + } else { + std::string op; + if (ctx->EqOp()) { + op = "=="; + } else if (ctx->NeOp()) { + op = "!="; + } + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const) { + result.is_const = true; + result.is_const_int = true; + float l = GetFloatValue(*left_info); + float r = GetFloatValue(*right_info); + if (op == "==") result.const_int_value = (l == r) ? 1 : 0; + else if (op == "!=") result.const_int_value = (l != r) ? 1 : 0; + } + } + } else { + ctx->relExp()->accept(this); + auto* info = sema_.GetExprType(ctx->relExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 逻辑与表达式 + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override { + ExprInfo result; + if (ctx->lAndExp()) { + ctx->lAndExp()->accept(this); + ctx->eqExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->lAndExp()); + auto* right_info = sema_.GetExprType(ctx->eqExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "逻辑与操作数类型推导失败")); + } else { + if (!left_info->type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "逻辑与左操作数必须是 int 类型")); + } + if (!right_info->type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "逻辑与右操作数必须是 int 类型")); + } + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const && + left_info->is_const_int && right_info->is_const_int) { + result.is_const = true; + result.is_const_int = true; + result.const_int_value = + (left_info->const_int_value && right_info->const_int_value) ? 1 : 0; + } + } + } else { + ctx->eqExp()->accept(this); + auto* info = sema_.GetExprType(ctx->eqExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 逻辑或表达式 + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override { + ExprInfo result; + if (ctx->lOrExp()) { + ctx->lOrExp()->accept(this); + ctx->lAndExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->lOrExp()); + auto* right_info = sema_.GetExprType(ctx->lAndExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "逻辑或操作数类型推导失败")); + } else { + if (!left_info->type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "逻辑或左操作数必须是 int 类型")); + } + if (!right_info->type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "逻辑或右操作数必须是 int 类型")); + } + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const && + left_info->is_const_int && right_info->is_const_int) { + result.is_const = true; + result.is_const_int = true; + result.const_int_value = + (left_info->const_int_value || right_info->const_int_value) ? 1 : 0; + } + } + } else { + ctx->lAndExp()->accept(this); + auto* info = sema_.GetExprType(ctx->lAndExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 获取语义上下文 + SemanticContext TakeSemanticContext() { return std::move(sema_); } + +private: + SymbolTable table_; + SemanticContext sema_; + struct LoopContext { + bool in_loop; + antlr4::ParserRuleContext* loop_node; + }; + std::vector loop_stack_; + std::shared_ptr current_func_return_type_ = nullptr; + bool current_func_has_return_ = false; - private: - SymbolTable table_; - SemanticContext sema_; - bool seen_return_ = false; - size_t current_item_index_ = 0; - size_t total_items_ = 0; + // ==================== 辅助函数 ==================== + + ExprInfo CheckExp(SysYParser::ExpContext* ctx) { + if (!ctx || !ctx->addExp()) { + throw std::runtime_error(FormatError("sema", "无效表达式")); + } + ctx->addExp()->accept(this); + auto* info = sema_.GetExprType(ctx->addExp()); + if (!info) { + throw std::runtime_error(FormatError("sema", "表达式类型推导失败")); + } + ExprInfo result = *info; + sema_.SetExprType(ctx, result); + return result; + } + + // 专门用于检查 AddExp 的辅助函数(用于常量表达式) + ExprInfo CheckAddExp(SysYParser::AddExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "无效表达式")); + } + ctx->accept(this); + auto* info = sema_.GetExprType(ctx); + if (!info) { + throw std::runtime_error(FormatError("sema", "表达式类型推导失败")); + } + return *info; + } + + ExprInfo CheckCond(SysYParser::CondContext* ctx) { + if (!ctx || !ctx->lOrExp()) { + throw std::runtime_error(FormatError("sema", "无效条件表达式")); + } + ctx->lOrExp()->accept(this); + auto* info = sema_.GetExprType(ctx->lOrExp()); + if (!info) { + throw std::runtime_error(FormatError("sema", "条件表达式类型推导失败")); + } + return *info; + } + + ExprInfo CheckLValue(SysYParser::LValContext* ctx) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法左值")); + } + std::string name = ctx->Ident()->getText(); + auto* sym = table_.lookup(name); + if (!sym) { + throw std::runtime_error(FormatError("sema", "未定义的变量: " + name)); + } + bool is_const = (sym->kind == SymbolKind::Constant); + bool is_array_or_ptr = false; + if (sym->type) { + is_array_or_ptr = sym->type->IsArray() || + sym->type->IsPtrInt32() || + sym->type->IsPtrFloat(); + } + size_t dim_count = 0; + std::shared_ptr elem_type = sym->type; + if (sym->type && sym->type->IsArray()) { + if (auto* arr_type = dynamic_cast(sym->type.get())) { + dim_count = arr_type->GetDimensions().size(); + elem_type = arr_type->GetElementType(); + } + } else if (sym->type && (sym->type->IsPtrInt32() || sym->type->IsPtrFloat())) { + dim_count = 1; + if (sym->type->IsPtrInt32()) { + elem_type = ir::Type::GetInt32Type(); + } else if (sym->type->IsPtrFloat()) { + elem_type = ir::Type::GetFloatType(); + } + } + size_t subscript_count = ctx->exp().size(); + if (is_array_or_ptr) { + if (subscript_count != dim_count) { + throw std::runtime_error(FormatError("sema", "数组下标个数不匹配")); + } + for (auto* idx_exp : ctx->exp()) { + ExprInfo idx = CheckExp(idx_exp); + if (!idx.type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "数组下标必须是 int 类型")); + } + } + return {elem_type, true, false}; + } else { + if (subscript_count > 0) { + throw std::runtime_error(FormatError("sema", "非数组变量不能使用下标: " + name)); + } + return {sym->type, true, is_const}; + } + } + + ExprInfo CheckFuncCall(SysYParser::UnaryExpContext* ctx) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法函数调用")); + } + std::string func_name = ctx->Ident()->getText(); + auto* func_sym = table_.lookup(func_name); + if (!func_sym || func_sym->kind != SymbolKind::Function) { + throw std::runtime_error(FormatError("sema", "未定义的函数: " + func_name)); + } + std::vector args; + if (ctx->funcRParams()) { + for (auto* exp : ctx->funcRParams()->exp()) { + if (exp) { + args.push_back(CheckExp(exp)); + } + } + } + if (args.size() != func_sym->param_types.size()) { + throw std::runtime_error(FormatError("sema", "参数个数不匹配")); + } + for (size_t i = 0; i < std::min(args.size(), func_sym->param_types.size()); ++i) { + if (!IsTypeCompatible(args[i].type, func_sym->param_types[i])) { + throw std::runtime_error(FormatError("sema", "参数类型不匹配")); + } + if (args[i].type != func_sym->param_types[i] && ctx->funcRParams() && + i < ctx->funcRParams()->exp().size()) { + sema_.AddConversion(ctx->funcRParams()->exp()[i], + args[i].type, func_sym->param_types[i]); + } + } + std::shared_ptr return_type; + if (func_sym->type && func_sym->type->IsFunction()) { + auto* func_type = dynamic_cast(func_sym->type.get()); + if (func_type) { + return_type = func_type->GetReturnType(); + } + } + if (!return_type) { + return_type = ir::Type::GetInt32Type(); + } + ExprInfo result; + result.type = return_type; + result.is_lvalue = false; + result.is_const = false; + return result; + } + + ExprInfo CheckBinaryOp(const ExprInfo* left, const ExprInfo* right, + const std::string& op, antlr4::ParserRuleContext* ctx) { + ExprInfo result; + if (!left->type->IsInt32() && !left->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "左操作数必须是算术类型")); + } + if (!right->type->IsInt32() && !right->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "右操作数必须是算术类型")); + } + if (op == "%" && (!left->type->IsInt32() || !right->type->IsInt32())) { + throw std::runtime_error(FormatError("sema", "取模运算要求操作数为 int 类型")); + } + if (left->type->IsFloat() || right->type->IsFloat()) { + result.type = ir::Type::GetFloatType(); + } else { + result.type = ir::Type::GetInt32Type(); + } + result.is_lvalue = false; + if (left->is_const && right->is_const) { + result.is_const = true; + float l = GetFloatValue(*left); + float r = GetFloatValue(*right); + if (result.type->IsInt32()) { + result.is_const_int = true; + int li = (int)l, ri = (int)r; + if (op == "*") result.const_int_value = li * ri; + else if (op == "/") result.const_int_value = li / ri; + else if (op == "%") result.const_int_value = li % ri; + else if (op == "+") result.const_int_value = li + ri; + else if (op == "-") result.const_int_value = li - ri; + } else { + if (op == "*") result.const_float_value = l * r; + else if (op == "/") result.const_float_value = l / r; + else if (op == "+") result.const_float_value = l + r; + else if (op == "-") result.const_float_value = l - r; + } + } + return result; + } + + float GetFloatValue(const ExprInfo& info) { + if (info.type->IsInt32() && info.is_const_int) { + return (float)info.const_int_value; + } else { + return info.const_float_value; + } + } + + bool IsTypeCompatible(std::shared_ptr src, std::shared_ptr dst) { + if (src == dst) return true; + if (src->IsInt32() && dst->IsFloat()) return true; + if (src->IsFloat() && dst->IsInt32()) return true; + return false; + } + + void CollectFunctionDeclaration(SysYParser::FuncDefContext* ctx) { + if (!ctx || !ctx->Ident()) return; + std::string name = ctx->Ident()->getText(); + if (table_.lookup(name)) return; + std::shared_ptr ret_type; + if (ctx->funcType()) { + if (ctx->funcType()->Void()) { + ret_type = ir::Type::GetVoidType(); + } else if (ctx->funcType()->Int()) { + ret_type = ir::Type::GetInt32Type(); + } else if (ctx->funcType()->Float()) { + ret_type = ir::Type::GetFloatType(); + } + } + if (!ret_type) ret_type = ir::Type::GetInt32Type(); + std::vector> param_types; + if (ctx->funcFParams()) { + for (auto* param : ctx->funcFParams()->funcFParam()) { + if (!param) continue; + std::shared_ptr param_type; + if (param->bType()) { + if (param->bType()->Int()) { + param_type = ir::Type::GetInt32Type(); + } else if (param->bType()->Float()) { + param_type = ir::Type::GetFloatType(); + } + } + if (!param_type) param_type = ir::Type::GetInt32Type(); + if (!param->L_BRACK().empty()) { + if (param_type->IsInt32()) { + param_type = ir::Type::GetPtrInt32Type(); + } else if (param_type->IsFloat()) { + param_type = ir::Type::GetPtrFloatType(); + } + } + param_types.push_back(param_type); + } + } + + // 创建函数类型 + std::shared_ptr func_type = ir::Type::GetFunctionType(ret_type, param_types); + + // 创建函数符号 + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Function; + sym.type = func_type; + sym.param_types = param_types; + sym.scope_level = 0; + sym.is_initialized = true; + sym.func_def_ctx = ctx; + + table_.addSymbol(sym); + } + + void CollectFunctionParams(SysYParser::FuncFParamsContext* ctx) { + if (!ctx) return; + for (auto* param : ctx->funcFParam()) { + if (!param || !param->Ident()) continue; + std::string name = param->Ident()->getText(); + if (table_.lookupCurrent(name)) { + throw std::runtime_error(FormatError("sema", "重复定义参数: " + name)); + } + std::shared_ptr param_type; + if (param->bType()) { + if (param->bType()->Int()) { + param_type = ir::Type::GetInt32Type(); + } else if (param->bType()->Float()) { + param_type = ir::Type::GetFloatType(); + } + } + if (!param_type) param_type = ir::Type::GetInt32Type(); + bool is_array = !param->L_BRACK().empty(); + if (is_array) { + if (param_type->IsInt32()) { + param_type = ir::Type::GetPtrInt32Type(); + } else if (param_type->IsFloat()) { + param_type = ir::Type::GetPtrFloatType(); + } + } + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Parameter; + sym.type = param_type; + sym.scope_level = table_.currentScopeLevel(); + sym.is_initialized = true; + sym.var_def_ctx = nullptr; + table_.addSymbol(sym); + } + } + + void CheckGlobalInitIsConst(SysYParser::InitValContext* ctx) { + if (!ctx) return; + if (ctx->exp()) { + ExprInfo info = CheckExp(ctx->exp()); + if (!info.is_const) { + throw std::runtime_error(FormatError("sema", "全局变量初始化必须是常量表达式")); + } + } else { + for (auto* init : ctx->initVal()) { + CheckGlobalInitIsConst(init); + } + } + } + + int EvaluateConstExp(SysYParser::ConstExpContext* ctx) { + if (!ctx || !ctx->addExp()) return 0; + ExprInfo info = CheckAddExp(ctx->addExp()); + if (info.is_const && info.is_const_int) { + return info.const_int_value; + } + throw std::runtime_error(FormatError("sema", "常量表达式求值失败")); + return 0; + } + + struct ConstValue { + bool is_int; + int int_val; + float float_val; + }; + + std::vector EvaluateConstInitVal(SysYParser::ConstInitValContext* ctx, + const std::vector& dims, + std::shared_ptr base_type) { + std::vector result; + if (!ctx) return result; + if (ctx->constExp()) { + ExprInfo info = CheckAddExp(ctx->constExp()->addExp()); + ConstValue val; + if (info.type->IsInt32() && info.is_const_int) { + val.is_int = true; + val.int_val = info.const_int_value; + if (base_type->IsFloat()) { + val.is_int = false; + val.float_val = (float)info.const_int_value; + } + } else if (info.type->IsFloat() && info.is_const) { + val.is_int = false; + val.float_val = info.const_float_value; + if (base_type->IsInt32()) { + val.is_int = true; + val.int_val = (int)info.const_float_value; + } + } else { + val.is_int = base_type->IsInt32(); + val.int_val = 0; + val.float_val = 0.0f; + } + result.push_back(val); + } else { + for (auto* init : ctx->constInitVal()) { + std::vector sub_vals = EvaluateConstInitVal(init, dims, base_type); + result.insert(result.end(), sub_vals.begin(), sub_vals.end()); + } + } + return result; + } + + void CheckMainFunction() { + auto* main_sym = table_.lookup("main"); + if (!main_sym || main_sym->kind != SymbolKind::Function) { + throw std::runtime_error(FormatError("sema", "缺少 main 函数")); + } + std::shared_ptr ret_type; + if (main_sym->type && main_sym->type->IsFunction()) { + auto* func_type = dynamic_cast(main_sym->type.get()); + if (func_type) { + ret_type = func_type->GetReturnType(); + } + } + if (!ret_type || !ret_type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "main 函数必须返回 int")); + } + if (!main_sym->param_types.empty()) { + throw std::runtime_error(FormatError("sema", "main 函数不能有参数")); + } + } }; } // namespace SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { - SemaVisitor visitor; - comp_unit.accept(&visitor); - return visitor.TakeSemanticContext(); -} + SemaVisitor visitor; + comp_unit.accept(&visitor); + return visitor.TakeSemanticContext(); +} \ No newline at end of file