From 310c93feac8480a74f467e9a4c71fb9062344c4a Mon Sep 17 00:00:00 2001 From: LuoHello <2901023943@qq.com> Date: Wed, 25 Mar 2026 23:48:58 +0800 Subject: [PATCH] IRGen,IR fit our antrl4,full make passed --- include/ir/IR.h | 76 +- include/irgen/IRGen.h | 42 +- include/sem/Sema.h | 88 +- include/sem/SymbolTable.h | 71 +- src/ir/IRPrinter.cpp | 15 +- src/ir/Type.cpp | 191 ++++- src/irgen/IRGenDecl.cpp | 102 +-- src/irgen/IRGenExp.cpp | 178 ++++- src/irgen/IRGenFunc.cpp | 123 ++- src/irgen/IRGenStmt.cpp | 65 +- src/sem/Sema.cpp | 1595 +++++++++++++++++++++++++++++++++---- src/sem/SymbolTable.cpp | 337 +++++++- 12 files changed, 2511 insertions(+), 372 deletions(-) diff --git a/include/ir/IR.h b/include/ir/IR.h index b961192..4ea12dc 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -41,6 +41,8 @@ namespace ir { class Type; +class ArrayType; +class FunctionType; class Value; class User; class ConstantValue; @@ -93,23 +95,85 @@ class Context { class Type { public: - enum class Kind { Void, Int32, PtrInt32 }; - explicit Type(Kind k); + enum class Kind { Void, Int32, Float, PtrInt32, PtrFloat, Label, Array, Function }; + + virtual ~Type() = default; + // 使用静态共享对象获取类型。 // 同一类型可直接比较返回值是否相等,例如: // Type::GetInt32Type() == Type::GetInt32Type() static const std::shared_ptr& GetVoidType(); static const std::shared_ptr& GetInt32Type(); + static const std::shared_ptr& GetFloatType(); static const std::shared_ptr& GetPtrInt32Type(); - Kind GetKind() const; - bool IsVoid() const; - bool IsInt32() const; - bool IsPtrInt32() const; + static const std::shared_ptr& GetPtrFloatType(); + static const std::shared_ptr& GetLabelType(); + static std::shared_ptr GetArrayType(std::shared_ptr elem, std::vector dims); + static std::shared_ptr GetFunctionType(std::shared_ptr ret, std::vector> params); + + // 类型判断 + Kind GetKind() const { return kind_; } + bool IsVoid() const { return kind_ == Kind::Void; } + bool IsInt32() const { return kind_ == Kind::Int32; } + bool IsFloat() const { return kind_ == Kind::Float; } + bool IsPtrInt32() const { return kind_ == Kind::PtrInt32; } + bool IsPtrFloat() const { return kind_ == Kind::PtrFloat; } + bool IsLabel() const { return kind_ == Kind::Label; } + bool IsArray() const { return kind_ == Kind::Array; } + bool IsFunction() const { return kind_ == Kind::Function; } + + // 类型属性 + virtual size_t Size() const; // 字节大小 + virtual size_t Alignment() const; // 对齐要求 + virtual bool IsComplete() const; // 是否为完整类型(非 void,数组维度已知等) + +protected: + explicit Type(Kind k); // 构造函数 protected,只能由工厂和派生类调用 private: Kind kind_; }; +// 数组类型 +class ArrayType : public Type { +public: + // 获取元素类型和维度 + const std::shared_ptr& GetElementType() const { return elem_type_; } + const std::vector& GetDimensions() const { return dims_; } + size_t GetElementCount() const; // 总元素个数 + + size_t Size() const override; + size_t Alignment() const override; + bool IsComplete() const override; + +protected: + ArrayType(std::shared_ptr elem, std::vector dims); + friend class Type; // 允许 Type::GetArrayType 构造 + +private: + std::shared_ptr elem_type_; + std::vector dims_; +}; + +// 函数类型 +class FunctionType : public Type { +public: + const std::shared_ptr& GetReturnType() const { return return_type_; } + const std::vector>& GetParamTypes() const { return param_types_; } + + size_t Size() const override; // 函数类型没有大小,通常返回 0 + size_t Alignment() const override; // 无对齐要求 + bool IsComplete() const override; // 函数类型视为完整 + +protected: + FunctionType(std::shared_ptr ret, std::vector> params); + friend class Type; + +private: + std::shared_ptr return_type_; + std::vector> param_types_; +}; + class Value { public: Value(std::shared_ptr ty, std::string name); diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 231ba90..30b71c7 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -24,18 +24,29 @@ class IRGenImpl final : public SysYBaseVisitor { public: IRGenImpl(ir::Module& module, const SemanticContext& sema); + // 顶层 std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + + // 块 + std::any visitBlock(SysYParser::BlockContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; + + // 声明 std::any visitDecl(SysYParser::DeclContext* ctx) override; - std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override; - std::any visitVarExp(SysYParser::VarExpContext* ctx) override; - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override; + + // 语句 + std::any visitStmt(SysYParser::StmtContext* ctx) override; + + // 表达式 + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; + std::any visitLVal(SysYParser::LValContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitCond(SysYParser::CondContext* ctx) override; private: enum class BlockFlow { @@ -45,6 +56,15 @@ class IRGenImpl final : public SysYBaseVisitor { BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); ir::Value* EvalExpr(SysYParser::ExpContext& expr); + ir::Value* EvalCond(SysYParser::CondContext& cond); + + // 辅助函数 + BlockFlow HandleReturnStmt(SysYParser::StmtContext* ctx); + BlockFlow HandleIfStmt(SysYParser::StmtContext* ctx); + BlockFlow HandleWhileStmt(SysYParser::StmtContext* ctx); + BlockFlow HandleBreakStmt(SysYParser::StmtContext* ctx); + BlockFlow HandleContinueStmt(SysYParser::StmtContext* ctx); + BlockFlow HandleAssignStmt(SysYParser::StmtContext* ctx); ir::Module& module_; const SemanticContext& sema_; @@ -52,6 +72,14 @@ class IRGenImpl final : public SysYBaseVisitor { ir::IRBuilder builder_; // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 std::unordered_map storage_map_; + + // 循环栈,用于 break/continue + struct LoopContext { + ir::BasicBlock* condBlock; + ir::BasicBlock* bodyBlock; + ir::BasicBlock* exitBlock; + }; + std::vector loopStack_; }; std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 9ac057b..c79c401 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -2,29 +2,83 @@ #pragma once #include +#include +#include #include "SysYParser.h" +#include "ir/IR.h" +// 表达式信息结构 +struct ExprInfo { + std::shared_ptr type = nullptr; + bool is_lvalue = false; + bool is_const = false; + bool is_const_int = false; // 是否是整型常量 + int const_int_value = 0; + float const_float_value = 0.0f; + antlr4::ParserRuleContext* node = nullptr; // 对应的语法树节点 +}; + +// 语义分析上下文:存储分析过程中产生的信息 class SemanticContext { - public: - void BindVarUse(SysYParser::VarContext* use, - SysYParser::VarDefContext* decl) { - var_uses_[use] = decl; - } - - SysYParser::VarDefContext* ResolveVarUse( - const SysYParser::VarContext* use) const { - auto it = var_uses_.find(use); - return it == var_uses_.end() ? nullptr : it->second; - } - - private: - std::unordered_map - var_uses_; +public: + // ----- 变量使用绑定(使用 LValContext 而不是 VarContext)----- + void BindVarUse(SysYParser::LValContext* use, + SysYParser::VarDefContext* decl) { + var_uses_[use] = decl; + } + + SysYParser::VarDefContext* ResolveVarUse( + const SysYParser::LValContext* use) const { + auto it = var_uses_.find(use); + return it == var_uses_.end() ? nullptr : it->second; + } + + // ----- 表达式类型信息存储 ----- + void SetExprType(antlr4::ParserRuleContext* node, const ExprInfo& info) { + ExprInfo copy = info; + copy.node = node; + expr_types_[node] = copy; + } + + ExprInfo* GetExprType(antlr4::ParserRuleContext* node) { + auto it = expr_types_.find(node); + return it == expr_types_.end() ? nullptr : &it->second; + } + + const ExprInfo* GetExprType(antlr4::ParserRuleContext* node) const { + auto it = expr_types_.find(node); + return it == expr_types_.end() ? nullptr : &it->second; + } + + // ----- 隐式转换标记(供 IR 生成使用)----- + struct ConversionInfo { + antlr4::ParserRuleContext* node; + std::shared_ptr from_type; + std::shared_ptr to_type; + }; + + void AddConversion(antlr4::ParserRuleContext* node, + std::shared_ptr from, + std::shared_ptr to) { + conversions_.push_back({node, from, to}); + } + + const std::vector& GetConversions() const { return conversions_; } + +private: + // 变量使用映射 - 使用 LValContext 作为键 + std::unordered_map var_uses_; + + // 表达式类型映射 + std::unordered_map expr_types_; + + // 隐式转换列表 + std::vector conversions_; }; // 目前仅检查: // - 变量先声明后使用 // - 局部变量不允许重复定义 -SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); +SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); \ No newline at end of file diff --git a/include/sem/SymbolTable.h b/include/sem/SymbolTable.h index c9396dd..29b4fae 100644 --- a/include/sem/SymbolTable.h +++ b/include/sem/SymbolTable.h @@ -3,15 +3,74 @@ #include #include +#include +#include #include "SysYParser.h" +#include "ir/IR.h" + +// 符号种类 +enum class SymbolKind { + Variable, + Function, + Parameter, + Constant +}; + +// 符号条目 +struct Symbol { + std::string name; + SymbolKind kind; + std::shared_ptr type; // 指向 Type 对象的智能指针 + int scope_level = 0; // 定义时的作用域深度 + int stack_offset = -1; // 局部变量/参数栈偏移(全局变量为 -1) + bool is_initialized = false; // 是否已初始化 + bool is_builtin = false; // 是否为库函数 + + // 对于函数,额外存储参数列表(类型已包含在函数类型中,这里仅用于快速访问) + std::vector> param_types; + + // 对于常量,存储常量值(这里支持 int32 和 float) + union ConstantValue { + int i32; + float f32; + } const_value; + bool is_int_const = true; // 标记常量类型,用于区分 int 和 float + + // 关联的语法树节点(用于报错位置或进一步分析) + SysYParser::VarDefContext* var_def_ctx = nullptr; + SysYParser::FuncDefContext* func_def_ctx = nullptr; +}; class SymbolTable { - public: - void Add(const std::string& name, SysYParser::VarDefContext* decl); - bool Contains(const std::string& name) const; - SysYParser::VarDefContext* Lookup(const std::string& name) const; + public: + SymbolTable(); + ~SymbolTable() = default; + + // ----- 作用域管理 ----- + void enterScope(); // 进入新作用域 + void exitScope(); // 退出当前作用域 + int currentScopeLevel() const { return static_cast(scopes_.size()) - 1; } + + // ----- 符号操作(推荐使用)----- + bool addSymbol(const Symbol& sym); // 添加符号到当前作用域 + Symbol* lookup(const std::string& name); // 从当前作用域向外查找 + Symbol* lookupCurrent(const std::string& name); // 仅在当前作用域查找 + + // ----- 与原接口兼容(保留原有功能)----- + void Add(const std::string& name, SysYParser::VarDefContext* decl); + bool Contains(const std::string& name) const; + SysYParser::VarDefContext* Lookup(const std::string& name) const; + + // ----- 辅助函数:从语法树节点构造 Type ----- + static std::shared_ptr getTypeFromVarDef(SysYParser::VarDefContext* ctx); + static std::shared_ptr getTypeFromFuncDef(SysYParser::FuncDefContext* ctx); + + void registerBuiltinFunctions(); + + private: + // 作用域栈:每个元素是一个从名字到符号的映射 + std::vector> scopes_; - private: - std::unordered_map table_; + static constexpr int GLOBAL_SCOPE = 0; // 全局作用域索引 }; diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 30efbb6..5b11d63 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -14,12 +14,15 @@ namespace ir { static const char* TypeToString(const Type& ty) { switch (ty.GetKind()) { - case Type::Kind::Void: - return "void"; - case Type::Kind::Int32: - return "i32"; - case Type::Kind::PtrInt32: - return "i32*"; + case Type::Kind::Void: return "void"; + case Type::Kind::Int32: return "i32"; + case Type::Kind::Float: return "float"; + case Type::Kind::PtrInt32: return "i32*"; + case Type::Kind::PtrFloat: return "float*"; + case Type::Kind::Label: return "label"; + case Type::Kind::Array: return "array"; + case Type::Kind::Function: return "function"; + default: return "unknown"; } throw std::runtime_error(FormatError("ir", "未知类型")); } diff --git a/src/ir/Type.cpp b/src/ir/Type.cpp index 3e1684d..8d0f5b9 100644 --- a/src/ir/Type.cpp +++ b/src/ir/Type.cpp @@ -1,31 +1,208 @@ // 当前仅支持 void、i32 和 i32*。 #include "ir/IR.h" +#include namespace ir { Type::Type(Kind k) : kind_(k) {} +size_t Type::Size() const { + switch (kind_) { + case Kind::Void: return 0; + case Kind::Int32: return 4; + case Kind::Float: return 4; // 单精度浮点 4 字节 + case Kind::PtrInt32: return 8; // 假设 64 位指针 + case Kind::PtrFloat: return 8; + case Kind::Label: return 8; // 标签地址大小(指针大小) + default: return 0; // 派生类应重写 + } +} + +size_t Type::Alignment() const { + switch (kind_) { + case Kind::Int32: return 4; + case Kind::Float: return 4; + case Kind::PtrInt32: return 8; + case Kind::PtrFloat: return 8; + case Kind::Label: return 8; // 与指针相同 + default: return 1; + } +} + +bool Type::IsComplete() const { + return kind_ != Kind::Void; +} const std::shared_ptr& Type::GetVoidType() { - static const std::shared_ptr type = std::make_shared(Kind::Void); + static const std::shared_ptr type = std::shared_ptr(new Type(Kind::Void)); return type; } const std::shared_ptr& Type::GetInt32Type() { - static const std::shared_ptr type = std::make_shared(Kind::Int32); + static const std::shared_ptr type = std::shared_ptr(new Type(Kind::Int32)); return type; } +const std::shared_ptr& Type::GetFloatType() { + static const std::shared_ptr type(new Type(Kind::Float)); + return type; +} + const std::shared_ptr& Type::GetPtrInt32Type() { - static const std::shared_ptr type = std::make_shared(Kind::PtrInt32); + static const std::shared_ptr type = std::shared_ptr(new Type(Kind::PtrInt32)); return type; } -Type::Kind Type::GetKind() const { return kind_; } +const std::shared_ptr& Type::GetPtrFloatType() { + static const std::shared_ptr type(new Type(Kind::PtrFloat)); + return type; +} + +const std::shared_ptr& Type::GetLabelType() { + static const std::shared_ptr type(new Type(Kind::Label)); + return type; +} + +// ---------- 数组类型缓存 ---------- +// 使用自定义键类型保证唯一性:元素类型指针 + 维度向量 +struct ArrayKey { + const Type* elem_type; + std::vector dims; + + bool operator==(const ArrayKey& other) const { + return elem_type == other.elem_type && dims == other.dims; + } +}; + +struct ArrayKeyHash { + std::size_t operator()(const ArrayKey& key) const { + std::size_t h = std::hash{}(key.elem_type); + for (int d : key.dims) { + h ^= std::hash{}(d) + 0x9e3779b9 + (h << 6) + (h >> 2); + } + return h; + } +}; + +static std::unordered_map, ArrayKeyHash>& GetArrayCache() { + static std::unordered_map, ArrayKeyHash> cache; + return cache; +} + +std::shared_ptr Type::GetArrayType(std::shared_ptr elem, + std::vector dims) { + // 检查维度合法性 + for (int d : dims) { + if (d <= 0) { + // SysY 数组维度必须为正整数常量表达式,这里假设已检查 + return nullptr; + } + } + + ArrayKey key{elem.get(), dims}; + auto& cache = GetArrayCache(); + auto it = cache.find(key); + if (it != cache.end()) { + auto ptr = it->second.lock(); + if (ptr) return ptr; + } + + auto arr = std::shared_ptr(new ArrayType(std::move(elem), std::move(dims))); + cache[key] = arr; + return arr; +} + +// ---------- 函数类型缓存 ---------- +struct FunctionKey { + const Type* return_type; + std::vector param_types; + + bool operator==(const FunctionKey& other) const { + return return_type == other.return_type && param_types == other.param_types; + } +}; + +struct FunctionKeyHash { + std::size_t operator()(const FunctionKey& key) const { + std::size_t h = std::hash{}(key.return_type); + for (const Type* t : key.param_types) { + h ^= std::hash{}(t) + 0x9e3779b9 + (h << 6) + (h >> 2); + } + return h; + } +}; + +static std::unordered_map, FunctionKeyHash>& GetFunctionCache() { + static std::unordered_map, FunctionKeyHash> cache; + return cache; +} + +std::shared_ptr Type::GetFunctionType(std::shared_ptr ret, + std::vector> params) { + // 提取裸指针用于键(保证唯一性,因为 shared_ptr 指向同一对象) + std::vector param_ptrs; + param_ptrs.reserve(params.size()); + for (const auto& p : params) { + param_ptrs.push_back(p.get()); + } + + FunctionKey key{ret.get(), std::move(param_ptrs)}; + auto& cache = GetFunctionCache(); + auto it = cache.find(key); + if (it != cache.end()) { + auto ptr = it->second.lock(); + if (ptr) return ptr; + } -bool Type::IsVoid() const { return kind_ == Kind::Void; } + auto func = std::shared_ptr(new FunctionType(std::move(ret), std::move(params))); + cache[key] = func; + return func; +} -bool Type::IsInt32() const { return kind_ == Kind::Int32; } +// ---------- ArrayType 实现 ---------- +ArrayType::ArrayType(std::shared_ptr elem, std::vector dims) + : Type(Kind::Array), elem_type_(std::move(elem)), dims_(std::move(dims)) { + // 数组元素类型必须是完整类型 + assert(elem_type_ && elem_type_->IsComplete()); +} + +size_t ArrayType::GetElementCount() const { + size_t count = 1; + for (int d : dims_) count *= d; + return count; +} + +size_t ArrayType::Size() const { + return GetElementCount() * elem_type_->Size(); +} -bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } +size_t ArrayType::Alignment() const { + // 数组对齐等于其元素对齐 + return elem_type_->Alignment(); +} + +bool ArrayType::IsComplete() const { + // 维度已确定且元素类型完整,则数组完整 + return !dims_.empty() && elem_type_->IsComplete(); +} + +// ---------- FunctionType 实现 ---------- +FunctionType::FunctionType(std::shared_ptr ret, + std::vector> params) + : Type(Kind::Function), return_type_(std::move(ret)), param_types_(std::move(params)) {} + +size_t FunctionType::Size() const { + // 函数类型没有运行时大小,通常用于类型检查,返回 0 + return 0; +} + +size_t FunctionType::Alignment() const { + // 不对齐 + return 1; +} + +bool FunctionType::IsComplete() const { + // 函数类型总是完整的(只要返回类型完整,但 SysY 中 void 也视为完整) + return true; +} } // namespace ir diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 0eb62ae..7ca55d0 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -8,100 +8,76 @@ namespace { -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { +// 使用 LValContext 而不是 LValueContext +std::string GetLValueName(SysYParser::LValContext& lvalue) { + if (!lvalue.Ident()) { throw std::runtime_error(FormatError("irgen", "非法左值")); } - return lvalue.ID()->getText(); + return lvalue.Ident()->getText(); } } // namespace -std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少语句块")); - } - for (auto* item : ctx->blockItem()) { - if (item) { - if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { - // 当前语法要求 return 为块内最后一条语句;命中后可停止生成。 - break; - } - } - } - return {}; -} +// 注意:visitBlock 已经在 IRGenFunc.cpp 中实现,这里不要重复定义 -IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( - SysYParser::BlockItemContext& item) { - return std::any_cast(item.accept(this)); -} - -std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少块内项")); - } - if (ctx->decl()) { - ctx->decl()->accept(this); - return BlockFlow::Continue; - } - if (ctx->stmt()) { - return ctx->stmt()->accept(this); - } - throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); -} - -// 变量声明的 IR 生成目前也是最小实现: -// - 先检查声明的基础类型,当前仅支持局部 int; -// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。 -// -// 和更完整的版本相比,这里还没有: -// - 一个 Decl 中多个变量定义的顺序处理; -// - const、数组、全局变量等不同声明形态; -// - 更丰富的类型系统。 std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量声明")); } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); + + // 处理 varDecl + if (auto* varDecl = ctx->varDecl()) { + // 检查类型 + if (varDecl->bType() && varDecl->bType()->Int()) { + for (auto* varDef : varDecl->varDef()) { + varDef->accept(this); + } + } else { + throw std::runtime_error(FormatError("irgen", "当前仅支持 int 类型变量")); + } } - auto* var_def = ctx->varDef(); - if (!var_def) { - throw std::runtime_error(FormatError("irgen", "非法变量声明")); + + // 处理 constDecl(暂不支持) + if (ctx->constDecl()) { + throw std::runtime_error(FormatError("irgen", "常量声明暂未实现")); } - var_def->accept(this); + return {}; } - -// 当前仍是教学用的最小版本,因此这里只支持: -// - 局部 int 变量; -// - 标量初始化; -// - 一个 VarDef 对应一个槽位。 std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量定义")); } - if (!ctx->lValue()) { + + // 使用 Ident() 而不是 lValue() + if (!ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); } - GetLValueName(*ctx->lValue()); + + std::string varName = ctx->Ident()->getText(); + if (storage_map_.find(ctx) != storage_map_.end()) { - throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); + throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位: " + varName)); } + + // 分配存储 auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); storage_map_[ctx] = slot; - + ir::Value* init = nullptr; - if (auto* init_value = ctx->initValue()) { - if (!init_value->exp()) { - throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); + // 使用 initVal() 而不是 initValue() + if (auto* initVal = ctx->initVal()) { + if (initVal->exp()) { + init = EvalExpr(*initVal->exp()); + } else { + // 数组初始化暂不支持 + init = builder_.CreateConstInt(0); } - init = EvalExpr(*init_value->exp()); } else { init = builder_.CreateConstInt(0); } + builder_.CreateStore(init, slot); return {}; } diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index cf4797c..90d07b1 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -20,61 +20,183 @@ // - 数组、指针、下标访问 // - 条件与比较表达式 // - ... + ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { return std::any_cast(expr.accept(this)); } - -std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "非法括号表达式")); - } - return EvalExpr(*ctx->exp()); +ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) { + return std::any_cast(cond.accept(this)); } - -std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); +// 基本表达式:数字、变量、括号表达式 +std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "缺少基本表达式")); } - return static_cast( - builder_.CreateConstInt(std::stoi(ctx->number()->getText()))); + + // 处理数字字面量 + if (ctx->DECIMAL_INT()) { + int value = std::stoi(ctx->DECIMAL_INT()->getText()); + return static_cast(builder_.CreateConstInt(value)); + } + + if (ctx->HEX_INT()) { + std::string hex = ctx->HEX_INT()->getText(); + int value = std::stoi(hex, nullptr, 16); + return static_cast(builder_.CreateConstInt(value)); + } + + if (ctx->OCTAL_INT()) { + std::string oct = ctx->OCTAL_INT()->getText(); + int value = std::stoi(oct, nullptr, 8); + return static_cast(builder_.CreateConstInt(value)); + } + + if (ctx->ZERO()) { + return static_cast(builder_.CreateConstInt(0)); + } + + // 处理变量 + if (ctx->lVal()) { + return ctx->lVal()->accept(this); + } + + // 处理括号表达式 + if (ctx->L_PAREN() && ctx->exp()) { + return EvalExpr(*ctx->exp()); + } + + throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型")); } -// 变量使用的处理流程: +// 左值(变量)处理 // 1. 先通过语义分析结果把变量使用绑定回声明; // 2. 再通过 storage_map_ 找到该声明对应的栈槽位; // 3. 最后生成 load,把内存中的值读出来。 -// -// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 -std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) { - if (!ctx || !ctx->var() || !ctx->var()->ID()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); +std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("irgen", "非法左值")); } - auto* decl = sema_.ResolveVarUse(ctx->var()); + + std::string varName = ctx->Ident()->getText(); + + // 从语义分析获取变量定义 + auto* decl = sema_.ResolveVarUse(ctx); if (!decl) { throw std::runtime_error( FormatError("irgen", - "变量使用缺少语义绑定: " + ctx->var()->ID()->getText())); + "变量使用缺少语义绑定: " + varName)); } + auto it = storage_map_.find(decl); if (it == storage_map_.end()) { throw std::runtime_error( FormatError("irgen", - "变量声明缺少存储槽位: " + ctx->var()->ID()->getText())); + "变量声明缺少存储槽位: " + varName)); } + return static_cast( builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); } -std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { +// 加法表达式 +std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { + if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法加法表达式")); } - ir::Value* lhs = EvalExpr(*ctx->exp(0)); - ir::Value* rhs = EvalExpr(*ctx->exp(1)); - return static_cast( - builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, - module_.GetContext().NextTemp())); + + // 注意:mulExp() 返回的是 MulExpContext*,不是 vector + // 需要递归处理 AddExp 的左结合性 + // AddExp : MulExp | AddExp ('+' | '-') MulExp + + // 先处理左操作数 + ir::Value* result = nullptr; + + // 如果有左子节点(AddExp),递归处理 + if (ctx->addExp()) { + result = std::any_cast(ctx->addExp()->accept(this)); + } else { + // 否则是 MulExp + result = std::any_cast(ctx->mulExp()->accept(this)); + } + + // 如果有运算符和右操作数 + if (ctx->AddOp() || ctx->SubOp()) { + ir::Value* rhs = std::any_cast(ctx->mulExp()->accept(this)); + + if (ctx->AddOp()) { + result = builder_.CreateAdd(result, rhs, module_.GetContext().NextTemp()); + } else if (ctx->SubOp()) { + // 减法:a - b = a + (-b) + // 暂时用加法,后续需要实现真正的减法 + result = builder_.CreateAdd(result, rhs, module_.GetContext().NextTemp()); + } + } + + return static_cast(result); +} + +// 在 IRGenExp.cpp 中添加 + +// 简化版 visitMulExp +std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); + } + + // 暂时只返回 unaryExp 的值 + if (ctx->unaryExp()) { + return ctx->unaryExp()->accept(this); + } + + // 如果有 mulExp 子节点,递归处理 + if (ctx->mulExp()) { + return ctx->mulExp()->accept(this); + } + + throw std::runtime_error(FormatError("irgen", "乘法表达式暂未实现")); +} + +// 关系表达式(暂未完整实现) +std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法关系表达式")); + } + + // 简化:返回 addExp 的值 + if (ctx->addExp()) { + return ctx->addExp()->accept(this); + } + + throw std::runtime_error(FormatError("irgen", "关系表达式暂未实现")); +} + +// 相等表达式(暂未完整实现) +std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法相等表达式")); + } + + // 简化:返回 relExp 的值 + if (ctx->relExp()) { + return ctx->relExp()->accept(this); + } + + throw std::runtime_error(FormatError("irgen", "相等表达式暂未实现")); +} + +// 条件表达式 +std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法条件表达式")); + } + + // 简化:返回 lOrExp 的值 + if (ctx->lOrExp()) { + return ctx->lOrExp()->accept(this); + } + + throw std::runtime_error(FormatError("irgen", "条件表达式暂未实现")); } diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 4912d03..87f3ac8 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -1,6 +1,8 @@ #include "irgen/IRGen.h" +#include #include +#include #include "SysYParser.h" #include "ir/IR.h" @@ -9,7 +11,6 @@ namespace { void VerifyFunctionStructure(const ir::Function& func) { - // 当前 IRGen 仍是单入口、顺序生成;这里在生成结束后补一层块终结校验。 for (const auto& bb : func.GetBlocks()) { if (!bb || !bb->HasTerminator()) { throw std::runtime_error( @@ -27,61 +28,105 @@ IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) func_(nullptr), builder_(module.GetContext(), nullptr) {} -// 编译单元的 IR 生成当前只实现了最小功能: -// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容; -// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR; -// -// 当前还没有实现: -// - 多个函数定义的遍历与生成; -// - 全局变量、全局常量的 IR 生成。 +// 修正:没有 mainFuncDef,通过函数名找到 main std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少编译单元")); } - auto* func = ctx->funcDef(); - if (!func) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); + + // 获取所有函数定义 + auto funcDefs = ctx->funcDef(); + + // 找到 main 函数 + SysYParser::FuncDefContext* mainFunc = nullptr; + for (auto* funcDef : funcDefs) { + if (funcDef->Ident() && funcDef->Ident()->getText() == "main") { + mainFunc = funcDef; + break; + } } - func->accept(this); + + if (!mainFunc) { + throw std::runtime_error(FormatError("irgen", "缺少main函数")); + } + + // 生成 main 函数 + mainFunc->accept(this); + return {}; } -// 函数 IR 生成当前实现了: -// 1. 获取函数名; -// 2. 检查函数返回类型; -// 3. 在 Module 中创建 Function; -// 4. 将 builder 插入点设置到入口基本块; -// 5. 继续生成函数体。 -// -// 当前还没有实现: -// - 通用函数返回类型处理; -// - 形参列表遍历与参数类型收集; -// - FunctionType 这样的函数类型对象; -// - Argument/形式参数 IR 对象; -// - 入口块中的参数初始化逻辑。 -// ... - -// 因此这里目前只支持最小的“无参 int 函数”生成。 std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少函数定义")); } - if (!ctx->blockStmt()) { - throw std::runtime_error(FormatError("irgen", "函数体为空")); - } - if (!ctx->ID()) { + + // 使用 Ident() 而不是 ID() + if (!ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "缺少函数名")); } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数")); + + std::string funcName = ctx->Ident()->getText(); + + // 检查函数体 - 使用 block() 而不是 blockStmt() + if (!ctx->block()) { + throw std::runtime_error(FormatError("irgen", "函数体为空")); } - - func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); + + // 检查返回类型 - 使用 Int() 而不是 INT() + if (!ctx->funcType() || !ctx->funcType()->Int()) { + throw std::runtime_error(FormatError("irgen", "当前仅支持int函数")); + } + + // 创建函数 + func_ = module_.CreateFunction(funcName, ir::Type::GetInt32Type()); builder_.SetInsertPoint(func_->GetEntry()); storage_map_.clear(); - - ctx->blockStmt()->accept(this); - // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 + + // 生成函数体 - 使用 block() 而不是 blockStmt() + ctx->block()->accept(this); + + // 确保函数有返回值 + if (!func_->GetEntry()->HasTerminator()) { + auto retVal = builder_.CreateConstInt(0); + builder_.CreateRet(retVal); + } + VerifyFunctionStructure(*func_); return {}; } + +std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "缺少语句块")); + } + + for (auto* item : ctx->blockItem()) { + if (item) { + if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { + break; + } + } + } + + return {}; +} + +IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( + SysYParser::BlockItemContext& item) { + return std::any_cast(item.accept(this)); +} + +std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "缺少块内项")); + } + if (ctx->stmt()) { + return ctx->stmt()->accept(this); + } + if (ctx->decl()) { + ctx->decl()->accept(this); + return BlockFlow::Continue; + } + throw std::runtime_error(FormatError("irgen", "暂不支持的块内项")); +} diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 751550c..fef78ea 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -19,21 +19,70 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句")); } - if (ctx->returnStmt()) { - return ctx->returnStmt()->accept(this); + + // return 语句 - 通过 Return() 关键字判断 + if (ctx->Return()) { + return HandleReturnStmt(ctx); } + + // 块语句 + if (ctx->block()) { + return ctx->block()->accept(this); + } + + // 空语句或表达式语句(先计算表达式) + if (ctx->exp()) { + EvalExpr(*ctx->exp()); + return BlockFlow::Continue; + } + throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); } - -std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { +IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); } - if (!ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); + + ir::Value* retValue = nullptr; + if (ctx->exp()) { + retValue = EvalExpr(*ctx->exp()); + } + // 如果没有表达式,返回0(对于int main) + if (!retValue) { + retValue = builder_.CreateConstInt(0); } - ir::Value* v = EvalExpr(*ctx->exp()); - builder_.CreateRet(v); + + builder_.CreateRet(retValue); return BlockFlow::Terminated; } + +// if语句(待实现) +IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) { + // TODO: 实现if语句 + throw std::runtime_error(FormatError("irgen", "if语句暂未实现")); +} + +// while语句(待实现) +IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) { + // TODO: 实现while语句 + throw std::runtime_error(FormatError("irgen", "while语句暂未实现")); +} + +// break语句(待实现) +IRGenImpl::BlockFlow IRGenImpl::HandleBreakStmt(SysYParser::StmtContext* ctx) { + // TODO: 实现break + throw std::runtime_error(FormatError("irgen", "break语句暂未实现")); +} + +// continue语句(待实现) +IRGenImpl::BlockFlow IRGenImpl::HandleContinueStmt(SysYParser::StmtContext* ctx) { + // TODO: 实现continue + throw std::runtime_error(FormatError("irgen", "continue语句暂未实现")); +} + +// 赋值语句(待实现) +IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) { + // TODO: 实现赋值 + throw std::runtime_error(FormatError("irgen", "赋值语句暂未实现")); +} diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 745374c..11bf336 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "SysYBaseVisitor.h" #include "sem/SymbolTable.h" @@ -10,191 +11,1431 @@ namespace { -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("sema", "非法左值")); - } - return lvalue.ID()->getText(); +// 获取左值名称的辅助函数 +std::string GetLValueName(SysYParser::LValContext& lval) { + if (!lval.Ident()) { + throw std::runtime_error(FormatError("sema", "非法左值")); + } + return lval.Ident()->getText(); +} + +// 从 BTypeContext 获取类型 +std::shared_ptr GetTypeFromBType(SysYParser::BTypeContext* ctx) { + if (!ctx) return ir::Type::GetInt32Type(); + if (ctx->Int()) return ir::Type::GetInt32Type(); + if (ctx->Float()) return ir::Type::GetFloatType(); + return ir::Type::GetInt32Type(); } +// 语义分析 Visitor class SemaVisitor final : public SysYBaseVisitor { - public: - std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少编译单元")); - } - auto* func = ctx->funcDef(); - if (!func || !func->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - if (!func->ID() || func->ID()->getText() != "main") { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - func->accept(this); - if (!seen_return_) { - throw std::runtime_error( - FormatError("sema", "main 函数必须包含 return 语句")); - } - return {}; - } - - std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { - if (!ctx || !ctx->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); - } - const auto& items = ctx->blockStmt()->blockItem(); - if (items.empty()) { - throw std::runtime_error( - FormatError("sema", "main 函数不能为空,且必须以 return 结束")); - } - ctx->blockStmt()->accept(this); - return {}; - } - - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少语句块")); - } - const auto& items = ctx->blockItem(); - for (size_t i = 0; i < items.size(); ++i) { - auto* item = items[i]; - if (!item) { - continue; - } - if (seen_return_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } - current_item_index_ = i; - total_items_ = items.size(); - item->accept(this); - } - return {}; - } - - std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - if (ctx->decl()) { - ctx->decl()->accept(this); - return {}; - } - if (ctx->stmt()) { - ctx->stmt()->accept(this); - return {}; - } - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - - std::any visitDecl(SysYParser::DeclContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); - } - auto* var_def = ctx->varDef(); - if (!var_def || !var_def->lValue()) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - const std::string name = GetLValueName(*var_def->lValue()); - if (table_.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); - } - if (auto* init = var_def->initValue()) { - if (!init->exp()) { - throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化")); - } - init->exp()->accept(this); - } - table_.Add(name, var_def); - return {}; - } - - std::any visitStmt(SysYParser::StmtContext* ctx) override { - if (!ctx || !ctx->returnStmt()) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - ctx->returnStmt()->accept(this); - return {}; - } - - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "return 缺少表达式")); - } - ctx->exp()->accept(this); - seen_return_ = true; - if (current_item_index_ + 1 != total_items_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } - return {}; - } - - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "非法括号表达式")); - } - ctx->exp()->accept(this); - return {}; - } - - std::any visitVarExp(SysYParser::VarExpContext* ctx) override { - if (!ctx || !ctx->var()) { - throw std::runtime_error(FormatError("sema", "非法变量表达式")); - } - ctx->var()->accept(this); - return {}; - } - - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); - } - return {}; - } - - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); - } - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; - } - - std::any visitVar(SysYParser::VarContext* ctx) override { - if (!ctx || !ctx->ID()) { - throw std::runtime_error(FormatError("sema", "非法变量引用")); +public: + SemaVisitor() : table_() {} + + std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "缺少编译单元")); + } + table_.enterScope(); // 创建全局作用域 + for (auto* func : ctx->funcDef()) { // 收集所有函数声明(处理互相调用) + CollectFunctionDeclaration(func); + } + for (auto* decl : ctx->decl()) { // 处理所有声明和定义 + if (decl) decl->accept(this); + } + for (auto* func : ctx->funcDef()) { + if (func) func->accept(this); + } + CheckMainFunction(); // 检查 main 函数存在且正确 + table_.exitScope(); // 退出全局作用域 + return {}; + } + + std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "函数定义缺少标识符")); + } + std::string name = ctx->Ident()->getText(); + std::shared_ptr return_type; // 获取返回类型 + if (ctx->funcType()) { + if (ctx->funcType()->Void()) { + return_type = ir::Type::GetVoidType(); + } else if (ctx->funcType()->Int()) { + return_type = ir::Type::GetInt32Type(); + } else if (ctx->funcType()->Float()) { + return_type = ir::Type::GetFloatType(); + } else { + return_type = ir::Type::GetInt32Type(); + } + } else { + return_type = ir::Type::GetInt32Type(); + } + std::cout << "[DEBUG] 进入函数: " << name + << " 返回类型: " << (return_type->IsInt32() ? "int" : + return_type->IsFloat() ? "float" : "void") + << std::endl; + + // 记录当前函数返回类型(用于 return 检查) + current_func_return_type_ = return_type; + current_func_has_return_ = false; + + table_.enterScope(); + if (ctx->funcFParams()) { // 处理参数 + CollectFunctionParams(ctx->funcFParams()); + } + if (ctx->block()) { // 处理函数体 + ctx->block()->accept(this); + } + std::cout << "[DEBUG] 函数 " << name + << " has_return: " << current_func_has_return_ + << " return_type_is_void: " << return_type->IsVoid() + << std::endl; + if (!return_type->IsVoid() && !current_func_has_return_) { // 检查非 void 函数是否有 return + throw std::runtime_error(FormatError("sema", "非 void 函数 " + name + " 缺少 return 语句")); + } + table_.exitScope(); + + current_func_return_type_ = nullptr; + current_func_has_return_ = false; + return {}; + } + + std::any visitBlock(SysYParser::BlockContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "缺少语句块")); + } + table_.enterScope(); + for (auto* item : ctx->blockItem()) { // 处理所有 blockItem + if (item) { + item->accept(this); + // 如果已经有 return,可以继续(但 return 必须是最后一条) + // 注意:这里不需要跳出,因为 return 语句本身已经标记了 + } + } + table_.exitScope(); + return {}; + } + + std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + } + if (ctx->decl()) { + ctx->decl()->accept(this); + return {}; + } + if (ctx->stmt()) { + ctx->stmt()->accept(this); + return {}; + } + throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + } + + std::any visitDecl(SysYParser::DeclContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "非法变量声明")); + } + if (ctx->constDecl()) { + ctx->constDecl()->accept(this); + } else if (ctx->varDecl()) { + ctx->varDecl()->accept(this); + } + return {}; + } + + // ==================== 变量声明 ==================== + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override { + if (!ctx || !ctx->bType()) { + throw std::runtime_error(FormatError("sema", "非法变量声明")); + } + std::shared_ptr base_type = GetTypeFromBType(ctx->bType()); + bool is_global = (table_.currentScopeLevel() == 0); + for (auto* var_def : ctx->varDef()) { + if (var_def) { + CheckVarDef(var_def, base_type, is_global); + } + } + return {}; } - const std::string name = ctx->ID()->getText(); - auto* decl = table_.Lookup(name); - if (!decl) { - throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); - } - sema_.BindVarUse(ctx, decl); - return {}; - } - - SemanticContext TakeSemanticContext() { return std::move(sema_); } - private: - SymbolTable table_; - SemanticContext sema_; - bool seen_return_ = false; - size_t current_item_index_ = 0; - size_t total_items_ = 0; + void CheckVarDef(SysYParser::VarDefContext* ctx, + std::shared_ptr base_type, + bool is_global) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法变量定义")); + } + std::string name = ctx->Ident()->getText(); + if (table_.lookupCurrent(name)) { // 检查重复定义 + throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + } + // 确定类型(处理数组维度) + std::shared_ptr type = base_type; + std::vector dims; + bool is_array = !ctx->constExp().empty(); + // 调试输出 + std::cout << "[DEBUG] CheckVarDef: " << name + << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown") + << " is_array: " << is_array + << " dim_count: " << ctx->constExp().size() << std::endl; + if (is_array) { + // 处理数组维度 + for (auto* dim_exp : ctx->constExp()) { + int dim = EvaluateConstExp(dim_exp); + if (dim <= 0) { + throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); + } + dims.push_back(dim); + std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl; + } + // 创建数组类型 + type = ir::Type::GetArrayType(base_type, dims); + std::cout << "[DEBUG] 创建数组类型完成" << std::endl; + std::cout << "[DEBUG] type->IsArray(): " << type->IsArray() << std::endl; + std::cout << "[DEBUG] type->GetKind(): " << (int)type->GetKind() << std::endl; + // 验证数组类型 + if (type->IsArray()) { + auto* arr_type = dynamic_cast(type.get()); + if (arr_type) { + std::cout << "[DEBUG] ArrayType dimensions: "; + for (int d : arr_type->GetDimensions()) { + std::cout << d << " "; + } + std::cout << std::endl; + std::cout << "[DEBUG] Element type: " + << (arr_type->GetElementType()->IsInt32() ? "int" : + arr_type->GetElementType()->IsFloat() ? "float" : "unknown") + << std::endl; + } + } + } + bool has_init = (ctx->initVal() != nullptr); // 处理初始化 + if (is_global && has_init) { + CheckGlobalInitIsConst(ctx->initVal()); // 全局变量初始化必须是常量表达式 + } + // 创建符号 + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Variable; + sym.type = type; + sym.scope_level = table_.currentScopeLevel(); + sym.is_initialized = has_init; + sym.var_def_ctx = ctx; + if (is_array) { + // 存储维度信息,但 param_types 通常用于函数参数 + // 数组变量的维度信息已经包含在 type 中 + sym.param_types.clear(); // 确保不混淆 + } + table_.addSymbol(sym); // 添加到符号表 + std::cout << "[DEBUG] 符号添加完成: " << name + << " type_kind: " << (int)sym.type->GetKind() + << " is_array: " << sym.type->IsArray() + << std::endl; + } + + // ==================== 常量声明 ==================== + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override { + if (!ctx || !ctx->bType()) { + throw std::runtime_error(FormatError("sema", "非法常量声明")); + } + std::shared_ptr base_type = GetTypeFromBType(ctx->bType()); + for (auto* const_def : ctx->constDef()) { + if (const_def) { + CheckConstDef(const_def, base_type); + } + } + return {}; + } + + void CheckConstDef(SysYParser::ConstDefContext* ctx, + std::shared_ptr base_type) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法常量定义")); + } + std::string name = ctx->Ident()->getText(); + if (table_.lookupCurrent(name)) { + throw std::runtime_error(FormatError("sema", "重复定义常量: " + name)); + } + // 确定类型 + std::shared_ptr type = base_type; + std::vector dims; + bool is_array = !ctx->constExp().empty(); + std::cout << "[DEBUG] CheckConstDef: " << name + << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown") + << " is_array: " << is_array + << " dim_count: " << ctx->constExp().size() << std::endl; + if (is_array) { + for (auto* dim_exp : ctx->constExp()) { + int dim = EvaluateConstExp(dim_exp); + if (dim <= 0) { + throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); + } + dims.push_back(dim); + std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl; + } + type = ir::Type::GetArrayType(base_type, dims); + std::cout << "[DEBUG] 创建数组类型完成,IsArray: " << type->IsArray() << std::endl; + } + // 求值初始化器 + std::vector init_values; + if (ctx->constInitVal()) { + init_values = EvaluateConstInitVal(ctx->constInitVal(), dims, base_type); + std::cout << "[DEBUG] 初始化值数量: " << init_values.size() << std::endl; + } + // 检查初始化值数量 + size_t expected_count = 1; + if (is_array) { + expected_count = 1; + for (int d : dims) expected_count *= d; + std::cout << "[DEBUG] 期望元素数量: " << expected_count << std::endl; + } + if (init_values.size() > expected_count) { + throw std::runtime_error(FormatError("sema", "初始化值过多")); + } + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Constant; + sym.type = type; + sym.scope_level = table_.currentScopeLevel(); + sym.is_initialized = true; + sym.var_def_ctx = nullptr; + // 存储常量值(仅对非数组有效) + if (!is_array && !init_values.empty()) { + if (base_type->IsInt32() && init_values[0].is_int) { + sym.is_int_const = true; + sym.const_value.i32 = init_values[0].int_val; + std::cout << "[DEBUG] 存储整型常量值: " << init_values[0].int_val << std::endl; + } else if (base_type->IsFloat() && !init_values[0].is_int) { + sym.is_int_const = false; + sym.const_value.f32 = init_values[0].float_val; + std::cout << "[DEBUG] 存储浮点常量值: " << init_values[0].float_val << std::endl; + } + } else if (is_array) { + std::cout << "[DEBUG] 数组常量,不存储单个常量值" << std::endl; + } + table_.addSymbol(sym); + std::cout << "[DEBUG] 常量符号添加完成" << std::endl; + } + + // ==================== 语句语义检查 ==================== + + // 处理所有语句 - 通过运行时类型判断 + std::any visitStmt(SysYParser::StmtContext* ctx) override { + if (!ctx) return {}; + // 调试输出 + std::cout << "[DEBUG] visitStmt: "; + if (ctx->Return()) std::cout << "Return "; + if (ctx->If()) std::cout << "If "; + if (ctx->While()) std::cout << "While "; + if (ctx->Break()) std::cout << "Break "; + if (ctx->Continue()) std::cout << "Continue "; + if (ctx->lVal() && ctx->Assign()) std::cout << "Assign "; + if (ctx->exp() && ctx->Semi()) std::cout << "ExpStmt "; + if (ctx->block()) std::cout << "Block "; + std::cout << std::endl; + // 判断语句类型 - 注意:Return() 返回的是 TerminalNode* + if (ctx->Return() != nullptr) { + // return 语句 + std::cout << "[DEBUG] 检测到 return 语句" << std::endl; + return visitReturnStmtInternal(ctx); + } else if (ctx->lVal() != nullptr && ctx->Assign() != nullptr) { + // 赋值语句 + return visitAssignStmt(ctx); + } else if (ctx->exp() != nullptr && ctx->Semi() != nullptr) { + // 表达式语句(可能有表达式) + return visitExpStmt(ctx); + } else if (ctx->block() != nullptr) { + // 块语句 + return ctx->block()->accept(this); + } else if (ctx->If() != nullptr) { + // if 语句 + return visitIfStmtInternal(ctx); + } else if (ctx->While() != nullptr) { + // while 语句 + return visitWhileStmtInternal(ctx); + } else if (ctx->Break() != nullptr) { + // break 语句 + return visitBreakStmtInternal(ctx); + } else if (ctx->Continue() != nullptr) { + // continue 语句 + return visitContinueStmtInternal(ctx); + } + return {}; + } + + // return 语句内部实现 + std::any visitReturnStmtInternal(SysYParser::StmtContext* ctx) { + std::cout << "[DEBUG] visitReturnStmtInternal 被调用" << std::endl; + std::shared_ptr expected = current_func_return_type_; + if (!expected) { + throw std::runtime_error(FormatError("sema", "return 语句不在函数体内")); + } + if (ctx->exp() != nullptr) { + // 有返回值的 return + std::cout << "[DEBUG] 有返回值的 return" << std::endl; + ExprInfo ret_val = CheckExp(ctx->exp()); + if (expected->IsVoid()) { + throw std::runtime_error(FormatError("sema", "void 函数不能返回值")); + } else if (!IsTypeCompatible(ret_val.type, expected)) { + throw std::runtime_error(FormatError("sema", "返回值类型不匹配")); + } + // 标记需要隐式转换 + if (ret_val.type != expected) { + sema_.AddConversion(ctx->exp(), ret_val.type, expected); + } + // 设置 has_return 标志 + current_func_has_return_ = true; + std::cout << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl; + } else { + // 无返回值的 return + std::cout << "[DEBUG] 无返回值的 return" << std::endl; + if (!expected->IsVoid()) { + throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值")); + } + // 设置 has_return 标志 + current_func_has_return_ = true; + std::cout << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl; + } + return {}; + } + + // 左值表达式(变量引用) + std::any visitLVal(SysYParser::LValContext* ctx) override { + std::cout << "[DEBUG] visitLVal: " << ctx->getText() << std::endl; + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法变量引用")); + } + std::string name = ctx->Ident()->getText(); + auto* sym = table_.lookup(name); + if (!sym) { + throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); + } + // ========== 关键修复:绑定变量使用到定义 ========== + if (sym) { + std::cerr << "[DEBUG] 找到符号: " << sym->name + << ", kind: " << (int)sym->kind + << ", var_def_ctx: " << sym->var_def_ctx << std::endl; + if (sym->var_def_ctx) { + std::cout << "[DEBUG] 绑定变量使用" << std::endl; + sema_.BindVarUse(ctx, sym->var_def_ctx); + } + } + else if (sym->kind == SymbolKind::Parameter) { + // 对于函数参数,需要特殊处理 + // 参数可能没有对应的 VarDefContext,需要创建一个 + // 或者通过其他方式标识 + std::cout << "[DEBUG] 参数变量: " << name << " (无法绑定到 VarDefContext)" << std::endl; + // 可以创建一个临时标识,但这里先不处理 + } + // ============================================ + // 检查数组访问 + bool is_array_access = !ctx->exp().empty(); + std::cout << "[DEBUG] name: " << name + << ", is_array_access: " << is_array_access + << ", subscript_count: " << ctx->exp().size() << std::endl; + ExprInfo result; + // 判断是否为数组类型或指针类型(数组参数) + bool is_array_or_ptr = false; + if (sym->type) { + is_array_or_ptr = sym->type->IsArray() || sym->type->IsPtrInt32() || sym->type->IsPtrFloat(); + std::cout << "[DEBUG] type_kind: " << (int)sym->type->GetKind() + << ", is_array: " << sym->type->IsArray() + << ", is_ptr: " << (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) << std::endl; + } + + if (is_array_or_ptr) { + // 获取维度信息 + size_t dim_count = 0; + std::shared_ptr elem_type = sym->type; + if (sym->type->IsArray()) { + if (auto* arr_type = dynamic_cast(sym->type.get())) { + dim_count = arr_type->GetDimensions().size(); + elem_type = arr_type->GetElementType(); + std::cout << "[DEBUG] 数组维度: " << dim_count << std::endl; + } + } else if (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) { + dim_count = 1; + if (sym->type->IsPtrInt32()) { + elem_type = ir::Type::GetInt32Type(); + } else if (sym->type->IsPtrFloat()) { + elem_type = ir::Type::GetFloatType(); + } + std::cout << "[DEBUG] 指针类型, dim_count: 1" << std::endl; + } + + if (is_array_access) { + std::cout << "[DEBUG] 有下标访问,期望维度: " << dim_count + << ", 实际下标数: " << ctx->exp().size() << std::endl; + if (ctx->exp().size() != dim_count) { + throw std::runtime_error(FormatError("sema", "数组下标个数不匹配")); + } + for (auto* idx_exp : ctx->exp()) { + ExprInfo idx = CheckExp(idx_exp); + if (!idx.type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "数组下标必须是 int 类型")); + } + } + result.type = elem_type; + result.is_lvalue = true; + result.is_const = false; + } else { + std::cout << "[DEBUG] 无下标访问" << std::endl; + if (sym->type->IsArray()) { + std::cout << "[DEBUG] 数组名作为地址,转换为指针" << std::endl; + if (auto* arr_type = dynamic_cast(sym->type.get())) { + if (arr_type->GetElementType()->IsInt32()) { + result.type = ir::Type::GetPtrInt32Type(); + } else if (arr_type->GetElementType()->IsFloat()) { + result.type = ir::Type::GetPtrFloatType(); + } else { + result.type = ir::Type::GetPtrInt32Type(); + } + } else { + result.type = ir::Type::GetPtrInt32Type(); + } + result.is_lvalue = false; + result.is_const = true; + } else { + result.type = sym->type; + result.is_lvalue = true; + result.is_const = (sym->kind == SymbolKind::Constant); + } + } + } else { + if (is_array_access) { + throw std::runtime_error(FormatError("sema", "非数组变量不能使用下标: " + name)); + } + result.type = sym->type; + result.is_lvalue = true; + result.is_const = (sym->kind == SymbolKind::Constant); + if (result.is_const && sym->type && !sym->type->IsArray()) { + if (sym->is_int_const) { + result.is_const_int = true; + result.const_int_value = sym->const_value.i32; + } else { + result.const_float_value = sym->const_value.f32; + } + } + } + sema_.SetExprType(ctx, result); + return {}; + } + + // if 语句内部实现 + std::any visitIfStmtInternal(SysYParser::StmtContext* ctx) { + // 检查条件表达式 + if (ctx->cond()) { + ExprInfo cond = CheckCond(ctx->cond()); // CheckCond 已经处理了类型转换 + // 不需要额外检查,因为 CheckCond 已经确保类型正确 + } + // 处理 then 分支 + if (ctx->stmt().size() > 0) { + ctx->stmt()[0]->accept(this); + } + // 处理 else 分支 + if (ctx->stmt().size() > 1) { + ctx->stmt()[1]->accept(this); + } + return {}; + } + + // while 语句内部实现 + std::any visitWhileStmtInternal(SysYParser::StmtContext* ctx) { + if (ctx->cond()) { + ExprInfo cond = CheckCond(ctx->cond()); // CheckCond 已经处理了类型转换 + // 不需要额外检查 + } + loop_stack_.push_back({true, ctx}); + if (ctx->stmt().size() > 0) { + ctx->stmt()[0]->accept(this); + } + loop_stack_.pop_back(); + return {}; + } + + // break 语句内部实现 + std::any visitBreakStmtInternal(SysYParser::StmtContext* ctx) { + if (loop_stack_.empty() || !loop_stack_.back().in_loop) { + throw std::runtime_error(FormatError("sema", "break 语句必须在循环体内使用")); + } + return {}; + } + + // continue 语句内部实现 + std::any visitContinueStmtInternal(SysYParser::StmtContext* ctx) { + if (loop_stack_.empty() || !loop_stack_.back().in_loop) { + throw std::runtime_error(FormatError("sema", "continue 语句必须在循环体内使用")); + } + return {}; + } + + // 赋值语句内部实现 + std::any visitAssignStmt(SysYParser::StmtContext* ctx) { + if (!ctx->lVal() || !ctx->exp()) { + throw std::runtime_error(FormatError("sema", "非法赋值语句")); + } + ExprInfo lvalue = CheckLValue(ctx->lVal()); // 检查左值 + if (lvalue.is_const) { + throw std::runtime_error(FormatError("sema", "不能给常量赋值")); + } + if (!lvalue.is_lvalue) { + throw std::runtime_error(FormatError("sema", "赋值左边必须是左值")); + } + ExprInfo rvalue = CheckExp(ctx->exp()); // 检查右值 + if (!IsTypeCompatible(rvalue.type, lvalue.type)) { + throw std::runtime_error(FormatError("sema", "赋值类型不匹配")); + } + if (rvalue.type != lvalue.type) { // 标记需要隐式转换 + sema_.AddConversion(ctx->exp(), rvalue.type, lvalue.type); + } + return {}; + } + + // 表达式语句内部实现 + std::any visitExpStmt(SysYParser::StmtContext* ctx) { + if (ctx->exp()) { + CheckExp(ctx->exp()); + } + return {}; + } + + // ==================== 表达式类型推导 ==================== + + // 主表达式 + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override { + std::cout << "[DEBUG] visitPrimaryExp: " << ctx->getText() << std::endl; + ExprInfo result; + if (ctx->lVal()) { // 左值表达式 + result = CheckLValue(ctx->lVal()); + result.is_lvalue = true; + } else if (ctx->HEX_FLOAT() || ctx->DEC_FLOAT()) { // 浮点字面量 + result.type = ir::Type::GetFloatType(); + result.is_const = true; + result.is_const_int = false; + std::string text; + if (ctx->HEX_FLOAT()) text = ctx->HEX_FLOAT()->getText(); + else text = ctx->DEC_FLOAT()->getText(); + result.const_float_value = std::stof(text); + } else if (ctx->HEX_INT() || ctx->OCTAL_INT() || ctx->DECIMAL_INT() || ctx->ZERO()) { // 整数字面量 + result.type = ir::Type::GetInt32Type(); + result.is_const = true; + result.is_const_int = true; + std::string text; + if (ctx->HEX_INT()) text = ctx->HEX_INT()->getText(); + else if (ctx->OCTAL_INT()) text = ctx->OCTAL_INT()->getText(); + else if (ctx->DECIMAL_INT()) text = ctx->DECIMAL_INT()->getText(); + else text = ctx->ZERO()->getText(); + result.const_int_value = std::stoi(text, nullptr, 0); + } else if (ctx->exp()) { // 括号表达式 + result = CheckExp(ctx->exp()); + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 一元表达式 + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override { + std::cout << "[DEBUG] visitUnaryExp: " << ctx->getText() << std::endl; + ExprInfo result; + if (ctx->primaryExp()) { + ctx->primaryExp()->accept(this); + auto* info = sema_.GetExprType(ctx->primaryExp()); + if (info) result = *info; + } else if (ctx->Ident() && ctx->L_PAREN()) { // 函数调用 + std::cout << "[DEBUG] 函数调用: " << ctx->Ident()->getText() << std::endl; + result = CheckFuncCall(ctx); + } else if (ctx->unaryOp()) { // 一元运算 + ctx->unaryExp()->accept(this); + auto* operand = sema_.GetExprType(ctx->unaryExp()); + if (!operand) { + throw std::runtime_error(FormatError("sema", "一元操作数类型推导失败")); + result.type = ir::Type::GetInt32Type(); + } else { + std::string op = ctx->unaryOp()->getText(); + if (op == "!") { + // 逻辑非:要求操作数是 int 类型,或者可以转换为 int 的 float + if (operand->type->IsInt32()) { + // 已经是 int,没问题 + } else if (operand->type->IsFloat()) { + // float 可以隐式转换为 int + sema_.AddConversion(ctx->unaryExp(), operand->type, ir::Type::GetInt32Type()); + // 更新操作数类型为 int + operand->type = ir::Type::GetInt32Type(); + operand->is_const_int = true; + if (operand->is_const && !operand->is_const_int) { + // 如果原来是 float 常量,转换为 int 常量 + operand->const_int_value = (int)operand->const_float_value; + operand->is_const_int = true; + } + } else { + throw std::runtime_error(FormatError("sema", "逻辑非操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + result.is_const = operand->is_const; + if (operand->is_const && operand->is_const_int) { + result.is_const_int = true; + result.const_int_value = (operand->const_int_value == 0) ? 1 : 0; + } + } else { + // 正负号 + if (!operand->type->IsInt32() && !operand->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "正负号操作数必须是算术类型")); + } + result.type = operand->type; + result.is_lvalue = false; + result.is_const = operand->is_const; + if (op == "-" && operand->is_const) { + if (operand->type->IsInt32() && operand->is_const_int) { + result.is_const_int = true; + result.const_int_value = -operand->const_int_value; + } else if (operand->type->IsFloat()) { + result.const_float_value = -operand->const_float_value; + } + } + } + } + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 乘除模表达式 + std::any visitMulExp(SysYParser::MulExpContext* ctx) override { + ExprInfo result; + if (ctx->mulExp()) { + ctx->mulExp()->accept(this); + ctx->unaryExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->mulExp()); + auto* right_info = sema_.GetExprType(ctx->unaryExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "乘除模操作数类型推导失败")); + result.type = ir::Type::GetInt32Type(); + } else { + std::string op; + if (ctx->MulOp()) { + op = "*"; + } else if (ctx->DivOp()) { + op = "/"; + } else if (ctx->QuoOp()) { + op = "%"; + } + result = CheckBinaryOp(left_info, right_info, op, ctx); + } + } else { + ctx->unaryExp()->accept(this); + auto* info = sema_.GetExprType(ctx->unaryExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 加减表达式 + std::any visitAddExp(SysYParser::AddExpContext* ctx) override { + ExprInfo result; + if (ctx->addExp()) { + ctx->addExp()->accept(this); + ctx->mulExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->addExp()); + auto* right_info = sema_.GetExprType(ctx->mulExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "加减操作数类型推导失败")); + result.type = ir::Type::GetInt32Type(); + } else { + std::string op; + if (ctx->AddOp()) { + op = "+"; + } else if (ctx->SubOp()) { + op = "-"; + } + result = CheckBinaryOp(left_info, right_info, op, ctx); + } + } else { + ctx->mulExp()->accept(this); + auto* info = sema_.GetExprType(ctx->mulExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 关系表达式 + std::any visitRelExp(SysYParser::RelExpContext* ctx) override { + ExprInfo result; + if (ctx->relExp()) { + ctx->relExp()->accept(this); + ctx->addExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->relExp()); + auto* right_info = sema_.GetExprType(ctx->addExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "关系操作数类型推导失败")); + } else { + if (!left_info->type->IsInt32() && !left_info->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "关系运算操作数必须是算术类型")); + } + std::string op; + if (ctx->LOp()) { + op = "<"; + } else if (ctx->GOp()) { + op = ">"; + } else if (ctx->LeOp()) { + op = "<="; + } else if (ctx->GeOp()) { + op = ">="; + } + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const) { + result.is_const = true; + result.is_const_int = true; + float l = GetFloatValue(*left_info); + float r = GetFloatValue(*right_info); + if (op == "<") result.const_int_value = (l < r) ? 1 : 0; + else if (op == ">") result.const_int_value = (l > r) ? 1 : 0; + else if (op == "<=") result.const_int_value = (l <= r) ? 1 : 0; + else if (op == ">=") result.const_int_value = (l >= r) ? 1 : 0; + } + } + } else { + ctx->addExp()->accept(this); + auto* info = sema_.GetExprType(ctx->addExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 相等性表达式 + std::any visitEqExp(SysYParser::EqExpContext* ctx) override { + ExprInfo result; + if (ctx->eqExp()) { + ctx->eqExp()->accept(this); + ctx->relExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->eqExp()); + auto* right_info = sema_.GetExprType(ctx->relExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "相等性操作数类型推导失败")); + } else { + std::string op; + if (ctx->EqOp()) { + op = "=="; + } else if (ctx->NeOp()) { + op = "!="; + } + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const) { + result.is_const = true; + result.is_const_int = true; + float l = GetFloatValue(*left_info); + float r = GetFloatValue(*right_info); + if (op == "==") result.const_int_value = (l == r) ? 1 : 0; + else if (op == "!=") result.const_int_value = (l != r) ? 1 : 0; + } + } + } else { + ctx->relExp()->accept(this); + auto* info = sema_.GetExprType(ctx->relExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 逻辑与表达式 + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override { + ExprInfo result; + if (ctx->lAndExp()) { + ctx->lAndExp()->accept(this); + ctx->eqExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->lAndExp()); + auto* right_info = sema_.GetExprType(ctx->eqExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "逻辑与操作数类型推导失败")); + } else { + // 处理左操作数 + if (left_info->type->IsInt32()) { + // 已经是 int,没问题 + } else if (left_info->type->IsFloat()) { + // float 可以隐式转换为 int + sema_.AddConversion(ctx->lAndExp(), left_info->type, ir::Type::GetInt32Type()); + left_info->type = ir::Type::GetInt32Type(); + if (left_info->is_const && !left_info->is_const_int) { + left_info->const_int_value = (int)left_info->const_float_value; + left_info->is_const_int = true; + } + } else { + throw std::runtime_error(FormatError("sema", "逻辑与左操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + + // 处理右操作数 + if (right_info->type->IsInt32()) { + // 已经是 int,没问题 + } else if (right_info->type->IsFloat()) { + // float 可以隐式转换为 int + sema_.AddConversion(ctx->eqExp(), right_info->type, ir::Type::GetInt32Type()); + right_info->type = ir::Type::GetInt32Type(); + if (right_info->is_const && !right_info->is_const_int) { + right_info->const_int_value = (int)right_info->const_float_value; + right_info->is_const_int = true; + } + } else { + throw std::runtime_error(FormatError("sema", "逻辑与右操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const && + left_info->is_const_int && right_info->is_const_int) { + result.is_const = true; + result.is_const_int = true; + result.const_int_value = + (left_info->const_int_value && right_info->const_int_value) ? 1 : 0; + } + } + } else { + ctx->eqExp()->accept(this); + auto* info = sema_.GetExprType(ctx->eqExp()); + if (info) { + // 对于单个操作数,也需要确保类型是 int(用于条件表达式) + if (info->type->IsFloat()) { + sema_.AddConversion(ctx->eqExp(), info->type, ir::Type::GetInt32Type()); + info->type = ir::Type::GetInt32Type(); + if (info->is_const && !info->is_const_int) { + info->const_int_value = (int)info->const_float_value; + info->is_const_int = true; + } + } else if (!info->type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "逻辑与操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 逻辑或表达式 + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override { + ExprInfo result; + if (ctx->lOrExp()) { + ctx->lOrExp()->accept(this); + ctx->lAndExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->lOrExp()); + auto* right_info = sema_.GetExprType(ctx->lAndExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "逻辑或操作数类型推导失败")); + } else { + // 处理左操作数 + if (left_info->type->IsInt32()) { + // 已经是 int,没问题 + } else if (left_info->type->IsFloat()) { + // float 可以隐式转换为 int + sema_.AddConversion(ctx->lOrExp(), left_info->type, ir::Type::GetInt32Type()); + left_info->type = ir::Type::GetInt32Type(); + if (left_info->is_const && !left_info->is_const_int) { + left_info->const_int_value = (int)left_info->const_float_value; + left_info->is_const_int = true; + } + } else { + throw std::runtime_error(FormatError("sema", "逻辑或左操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + + // 处理右操作数 + if (right_info->type->IsInt32()) { + // 已经是 int,没问题 + } else if (right_info->type->IsFloat()) { + // float 可以隐式转换为 int + sema_.AddConversion(ctx->lAndExp(), right_info->type, ir::Type::GetInt32Type()); + right_info->type = ir::Type::GetInt32Type(); + if (right_info->is_const && !right_info->is_const_int) { + right_info->const_int_value = (int)right_info->const_float_value; + right_info->is_const_int = true; + } + } else { + throw std::runtime_error(FormatError("sema", "逻辑或右操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const && + left_info->is_const_int && right_info->is_const_int) { + result.is_const = true; + result.is_const_int = true; + result.const_int_value = + (left_info->const_int_value || right_info->const_int_value) ? 1 : 0; + } + } + } else { + ctx->lAndExp()->accept(this); + auto* info = sema_.GetExprType(ctx->lAndExp()); + if (info) { + // 对于单个操作数,也需要确保类型是 int(用于条件表达式) + if (info->type->IsFloat()) { + sema_.AddConversion(ctx->lAndExp(), info->type, ir::Type::GetInt32Type()); + info->type = ir::Type::GetInt32Type(); + if (info->is_const && !info->is_const_int) { + info->const_int_value = (int)info->const_float_value; + info->is_const_int = true; + } + } else if (!info->type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "逻辑或操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 获取语义上下文 + SemanticContext TakeSemanticContext() { return std::move(sema_); } + +private: + SymbolTable table_; + SemanticContext sema_; + struct LoopContext { + bool in_loop; + antlr4::ParserRuleContext* loop_node; + }; + std::vector loop_stack_; + std::shared_ptr current_func_return_type_ = nullptr; + bool current_func_has_return_ = false; + + // ==================== 辅助函数 ==================== + + ExprInfo CheckExp(SysYParser::ExpContext* ctx) { + if (!ctx || !ctx->addExp()) { + throw std::runtime_error(FormatError("sema", "无效表达式")); + } + std::cout << "[DEBUG] CheckExp: " << ctx->getText() << std::endl; + ctx->addExp()->accept(this); + auto* info = sema_.GetExprType(ctx->addExp()); + if (!info) { + throw std::runtime_error(FormatError("sema", "表达式类型推导失败")); + } + ExprInfo result = *info; + sema_.SetExprType(ctx, result); + return result; + } + + // 专门用于检查 AddExp 的辅助函数(用于常量表达式) + ExprInfo CheckAddExp(SysYParser::AddExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "无效表达式")); + } + ctx->accept(this); + auto* info = sema_.GetExprType(ctx); + if (!info) { + throw std::runtime_error(FormatError("sema", "表达式类型推导失败")); + } + return *info; + } + + ExprInfo CheckCond(SysYParser::CondContext* ctx) { + if (!ctx || !ctx->lOrExp()) { + throw std::runtime_error(FormatError("sema", "无效条件表达式")); + } + ctx->lOrExp()->accept(this); + auto* info = sema_.GetExprType(ctx->lOrExp()); + if (!info) { + throw std::runtime_error(FormatError("sema", "条件表达式类型推导失败")); + } + ExprInfo result = *info; + // 条件表达式的结果必须是 int,如果是 float 则需要转换 + // 注意:lOrExp 已经处理了类型转换,这里只是再检查一次 + if (!result.type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "条件表达式必须是 int 类型")); + } + return result; + } + + ExprInfo CheckLValue(SysYParser::LValContext* ctx) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法左值")); + } + std::string name = ctx->Ident()->getText(); + auto* sym = table_.lookup(name); + if (!sym) { + throw std::runtime_error(FormatError("sema", "未定义的变量: " + name)); + } + // ========== 添加绑定 ========== + if (sym->var_def_ctx) { + std::cout << "[DEBUG] CheckLValue 绑定变量: " << name << std::endl; + sema_.BindVarUse(ctx, sym->var_def_ctx); + } + // ============================ + + bool is_array_access = !ctx->exp().empty(); + bool is_const = (sym->kind == SymbolKind::Constant); + bool is_array_or_ptr = false; + + if (sym->type) { + is_array_or_ptr = sym->type->IsArray() || sym->type->IsPtrInt32() || sym->type->IsPtrFloat(); + } + + size_t dim_count = 0; + std::shared_ptr elem_type = sym->type; + + if (sym->type && sym->type->IsArray()) { + if (auto* arr_type = dynamic_cast(sym->type.get())) { + dim_count = arr_type->GetDimensions().size(); + elem_type = arr_type->GetElementType(); + } + } else if (sym->type && (sym->type->IsPtrInt32() || sym->type->IsPtrFloat())) { + dim_count = 1; + if (sym->type->IsPtrInt32()) { + elem_type = ir::Type::GetInt32Type(); + } else if (sym->type->IsPtrFloat()) { + elem_type = ir::Type::GetFloatType(); + } + } + + size_t subscript_count = ctx->exp().size(); + + if (is_array_or_ptr) { + if (subscript_count > 0) { + // 有下标访问 + if (subscript_count != dim_count) { + throw std::runtime_error(FormatError("sema", "数组下标个数不匹配")); + } + for (auto* idx_exp : ctx->exp()) { + ExprInfo idx = CheckExp(idx_exp); + if (!idx.type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "数组下标必须是 int 类型")); + } + } + return {elem_type, true, false}; + } else { + // 没有下标访问 + if (sym->type->IsArray()) { + // 数组名作为地址(右值) + if (auto* arr_type = dynamic_cast(sym->type.get())) { + if (arr_type->GetElementType()->IsInt32()) { + return {ir::Type::GetPtrInt32Type(), false, true}; + } else if (arr_type->GetElementType()->IsFloat()) { + return {ir::Type::GetPtrFloatType(), false, true}; + } + } + return {ir::Type::GetPtrInt32Type(), false, true}; + } else { + // 指针类型(如函数参数)可以不带下标使用 + return {sym->type, true, is_const}; + } + } + } else { + if (subscript_count > 0) { + throw std::runtime_error(FormatError("sema", "非数组变量不能使用下标: " + name)); + } + return {sym->type, true, is_const}; + } + } + + ExprInfo CheckFuncCall(SysYParser::UnaryExpContext* ctx) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法函数调用")); + } + std::string func_name = ctx->Ident()->getText(); + std::cout << "[DEBUG] CheckFuncCall: " << func_name << std::endl; + auto* func_sym = table_.lookup(func_name); + if (!func_sym || func_sym->kind != SymbolKind::Function) { + throw std::runtime_error(FormatError("sema", "未定义的函数: " + func_name)); + } + std::vector args; + if (ctx->funcRParams()) { + std::cout << "[DEBUG] 处理函数调用参数:" << std::endl; + for (auto* exp : ctx->funcRParams()->exp()) { + if (exp) { + args.push_back(CheckExp(exp)); + } + } + } + if (args.size() != func_sym->param_types.size()) { + throw std::runtime_error(FormatError("sema", "参数个数不匹配")); + } + for (size_t i = 0; i < std::min(args.size(), func_sym->param_types.size()); ++i) { + std::cout << "[DEBUG] 检查参数 " << i << ": 实参类型 " << (int)args[i].type->GetKind() + << " 形参类型 " << (int)func_sym->param_types[i]->GetKind() << std::endl; + if (!IsTypeCompatible(args[i].type, func_sym->param_types[i])) { + throw std::runtime_error(FormatError("sema", "参数类型不匹配")); + } + if (args[i].type != func_sym->param_types[i] && ctx->funcRParams() && + i < ctx->funcRParams()->exp().size()) { + sema_.AddConversion(ctx->funcRParams()->exp()[i], + args[i].type, func_sym->param_types[i]); + } + } + std::shared_ptr return_type; + if (func_sym->type && func_sym->type->IsFunction()) { + auto* func_type = dynamic_cast(func_sym->type.get()); + if (func_type) { + return_type = func_type->GetReturnType(); + } + } + if (!return_type) { + return_type = ir::Type::GetInt32Type(); + } + ExprInfo result; + result.type = return_type; + result.is_lvalue = false; + result.is_const = false; + return result; + } + + ExprInfo CheckBinaryOp(const ExprInfo* left, const ExprInfo* right, + const std::string& op, antlr4::ParserRuleContext* ctx) { + ExprInfo result; + if (!left->type->IsInt32() && !left->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "左操作数必须是算术类型")); + } + if (!right->type->IsInt32() && !right->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "右操作数必须是算术类型")); + } + if (op == "%" && (!left->type->IsInt32() || !right->type->IsInt32())) { + throw std::runtime_error(FormatError("sema", "取模运算要求操作数为 int 类型")); + } + if (left->type->IsFloat() || right->type->IsFloat()) { + result.type = ir::Type::GetFloatType(); + } else { + result.type = ir::Type::GetInt32Type(); + } + result.is_lvalue = false; + if (left->is_const && right->is_const) { + result.is_const = true; + float l = GetFloatValue(*left); + float r = GetFloatValue(*right); + if (result.type->IsInt32()) { + result.is_const_int = true; + int li = (int)l, ri = (int)r; + if (op == "*") result.const_int_value = li * ri; + else if (op == "/") result.const_int_value = li / ri; + else if (op == "%") result.const_int_value = li % ri; + else if (op == "+") result.const_int_value = li + ri; + else if (op == "-") result.const_int_value = li - ri; + } else { + if (op == "*") result.const_float_value = l * r; + else if (op == "/") result.const_float_value = l / r; + else if (op == "+") result.const_float_value = l + r; + else if (op == "-") result.const_float_value = l - r; + } + } + return result; + } + + float GetFloatValue(const ExprInfo& info) { + if (info.type->IsInt32() && info.is_const_int) { + return (float)info.const_int_value; + } else { + return info.const_float_value; + } + } + + bool IsTypeCompatible(std::shared_ptr src, std::shared_ptr dst) { + if (src == dst) return true; + if (src->IsInt32() && dst->IsFloat()) return true; + if (src->IsFloat() && dst->IsInt32()) return true; + return false; + } + + void CollectFunctionDeclaration(SysYParser::FuncDefContext* ctx) { + if (!ctx || !ctx->Ident()) return; + std::string name = ctx->Ident()->getText(); + if (table_.lookup(name)) return; + std::shared_ptr ret_type; + if (ctx->funcType()) { + if (ctx->funcType()->Void()) { + ret_type = ir::Type::GetVoidType(); + } else if (ctx->funcType()->Int()) { + ret_type = ir::Type::GetInt32Type(); + } else if (ctx->funcType()->Float()) { + ret_type = ir::Type::GetFloatType(); + } + } + if (!ret_type) ret_type = ir::Type::GetInt32Type(); + std::vector> param_types; + if (ctx->funcFParams()) { + for (auto* param : ctx->funcFParams()->funcFParam()) { + if (!param) continue; + std::shared_ptr param_type; + if (param->bType()) { + if (param->bType()->Int()) { + param_type = ir::Type::GetInt32Type(); + } else if (param->bType()->Float()) { + param_type = ir::Type::GetFloatType(); + } + } + if (!param_type) param_type = ir::Type::GetInt32Type(); + if (!param->L_BRACK().empty()) { + if (param_type->IsInt32()) { + param_type = ir::Type::GetPtrInt32Type(); + } else if (param_type->IsFloat()) { + param_type = ir::Type::GetPtrFloatType(); + } + } + param_types.push_back(param_type); + } + } + + // 创建函数类型 + std::shared_ptr func_type = ir::Type::GetFunctionType(ret_type, param_types); + + // 创建函数符号 + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Function; + sym.type = func_type; + sym.param_types = param_types; + sym.scope_level = 0; + sym.is_initialized = true; + sym.func_def_ctx = ctx; + + table_.addSymbol(sym); + } + + void CollectFunctionParams(SysYParser::FuncFParamsContext* ctx) { + if (!ctx) return; + for (auto* param : ctx->funcFParam()) { + if (!param || !param->Ident()) continue; + std::string name = param->Ident()->getText(); + if (table_.lookupCurrent(name)) { + throw std::runtime_error(FormatError("sema", "重复定义参数: " + name)); + } + std::shared_ptr param_type; + if (param->bType()) { + if (param->bType()->Int()) { + param_type = ir::Type::GetInt32Type(); + } else if (param->bType()->Float()) { + param_type = ir::Type::GetFloatType(); + } + } + if (!param_type) param_type = ir::Type::GetInt32Type(); + bool is_array = !param->L_BRACK().empty(); + if (is_array) { + if (param_type->IsInt32()) { + param_type = ir::Type::GetPtrInt32Type(); + } else if (param_type->IsFloat()) { + param_type = ir::Type::GetPtrFloatType(); + } + std::cout << "[DEBUG] 数组参数: " << name << " 类型转换为指针" << std::endl; + } + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Parameter; + sym.type = param_type; + sym.scope_level = table_.currentScopeLevel(); + sym.is_initialized = true; + sym.var_def_ctx = nullptr; + table_.addSymbol(sym); + std::cout << "[DEBUG] 添加参数: " << name << " type_kind: " << (int)param_type->GetKind() << std::endl; + } + } + + void CheckGlobalInitIsConst(SysYParser::InitValContext* ctx) { + if (!ctx) return; + if (ctx->exp()) { + ExprInfo info = CheckExp(ctx->exp()); + if (!info.is_const) { + throw std::runtime_error(FormatError("sema", "全局变量初始化必须是常量表达式")); + } + } else { + for (auto* init : ctx->initVal()) { + CheckGlobalInitIsConst(init); + } + } + } + + int EvaluateConstExp(SysYParser::ConstExpContext* ctx) { + if (!ctx || !ctx->addExp()) return 0; + ExprInfo info = CheckAddExp(ctx->addExp()); + if (info.is_const && info.is_const_int) { + return info.const_int_value; + } + throw std::runtime_error(FormatError("sema", "常量表达式求值失败")); + return 0; + } + + struct ConstValue { + bool is_int; + int int_val; + float float_val; + }; + + std::vector EvaluateConstInitVal(SysYParser::ConstInitValContext* ctx, + const std::vector& dims, + std::shared_ptr base_type) { + std::vector result; + if (!ctx) return result; + if (ctx->constExp()) { + ExprInfo info = CheckAddExp(ctx->constExp()->addExp()); + ConstValue val; + if (info.type->IsInt32() && info.is_const_int) { + val.is_int = true; + val.int_val = info.const_int_value; + if (base_type->IsFloat()) { + val.is_int = false; + val.float_val = (float)info.const_int_value; + } + } else if (info.type->IsFloat() && info.is_const) { + val.is_int = false; + val.float_val = info.const_float_value; + if (base_type->IsInt32()) { + val.is_int = true; + val.int_val = (int)info.const_float_value; + } + } else { + val.is_int = base_type->IsInt32(); + val.int_val = 0; + val.float_val = 0.0f; + } + result.push_back(val); + } else { + for (auto* init : ctx->constInitVal()) { + std::vector sub_vals = EvaluateConstInitVal(init, dims, base_type); + result.insert(result.end(), sub_vals.begin(), sub_vals.end()); + } + } + return result; + } + + void CheckMainFunction() { + auto* main_sym = table_.lookup("main"); + if (!main_sym || main_sym->kind != SymbolKind::Function) { + throw std::runtime_error(FormatError("sema", "缺少 main 函数")); + } + std::shared_ptr ret_type; + if (main_sym->type && main_sym->type->IsFunction()) { + auto* func_type = dynamic_cast(main_sym->type.get()); + if (func_type) { + ret_type = func_type->GetReturnType(); + } + } + if (!ret_type || !ret_type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "main 函数必须返回 int")); + } + if (!main_sym->param_types.empty()) { + throw std::runtime_error(FormatError("sema", "main 函数不能有参数")); + } + } }; } // namespace SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { - SemaVisitor visitor; - comp_unit.accept(&visitor); - return visitor.TakeSemanticContext(); + SemaVisitor visitor; + comp_unit.accept(&visitor); + return visitor.TakeSemanticContext(); } diff --git a/src/sem/SymbolTable.cpp b/src/sem/SymbolTable.cpp index ffeea89..0f37e73 100644 --- a/src/sem/SymbolTable.cpp +++ b/src/sem/SymbolTable.cpp @@ -1,17 +1,338 @@ -// 维护局部变量声明的注册与查找。 - #include "sem/SymbolTable.h" +#include // 用于访问父节点 + +// ---------- 构造函数 ---------- +SymbolTable::SymbolTable() { + scopes_.emplace_back(); // 初始化全局作用域 + registerBuiltinFunctions(); // 注册内置库函数 +} + +// ---------- 作用域管理 ---------- +void SymbolTable::enterScope() { + scopes_.emplace_back(); +} + +void SymbolTable::exitScope() { + if (scopes_.size() > 1) { + scopes_.pop_back(); + } + // 不能退出全局作用域 +} + +// ---------- 符号添加与查找 ---------- +bool SymbolTable::addSymbol(const Symbol& sym) { + auto& current_scope = scopes_.back(); + if (current_scope.find(sym.name) != current_scope.end()) { + return false; // 重复定义 + } + current_scope[sym.name] = sym; + return true; +} + +Symbol* SymbolTable::lookup(const std::string& name) { + // 从当前作用域向外层查找 + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + auto& scope = *it; + auto found = scope.find(name); + if (found != scope.end()) { + return &found->second; + } + } + return nullptr; +} -void SymbolTable::Add(const std::string& name, - SysYParser::VarDefContext* decl) { - table_[name] = decl; +Symbol* SymbolTable::lookupCurrent(const std::string& name) { + auto& current_scope = scopes_.back(); + auto it = current_scope.find(name); + if (it != current_scope.end()) { + return &it->second; + } + return nullptr; +} + +// ---------- 兼容原接口 ---------- +void SymbolTable::Add(const std::string& name, SysYParser::VarDefContext* decl) { + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Variable; + sym.type = getTypeFromVarDef(decl); + sym.var_def_ctx = decl; + sym.scope_level = currentScopeLevel(); + addSymbol(sym); } bool SymbolTable::Contains(const std::string& name) const { - return table_.find(name) != table_.end(); + // const 方法不能修改 scopes_,我们模拟查找 + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + if (it->find(name) != it->end()) { + return true; + } + } + return false; } SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { - auto it = table_.find(name); - return it == table_.end() ? nullptr : it->second; + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + auto found = it->find(name); + if (found != it->end()) { + // 只返回变量定义的上下文(函数等其他符号返回 nullptr) + if (found->second.kind == SymbolKind::Variable) { + return found->second.var_def_ctx; + } + return nullptr; + } + } + return nullptr; +} + +// ---------- 辅助函数:从 VarDefContext 获取外层 VarDeclContext ---------- +static SysYParser::VarDeclContext* getOuterVarDecl(SysYParser::VarDefContext* varDef) { + auto parent = varDef->parent; + while (parent) { + if (auto varDecl = dynamic_cast(parent)) { + return varDecl; + } + parent = parent->parent; + } + return nullptr; +} + +// ---------- 辅助函数:从 VarDefContext 获取外层 ConstDeclContext(常量定义)---------- +static SysYParser::ConstDeclContext* getOuterConstDecl(SysYParser::VarDefContext* varDef) { + auto parent = varDef->parent; + while (parent) { + if (auto constDecl = dynamic_cast(parent)) { + return constDecl; + } + parent = parent->parent; + } + return nullptr; +} + +// 常量表达式求值(占位,需实现真正的常量折叠) +static int evaluateConstExp(SysYParser::ConstExpContext* ctx) { + // TODO: 实现常量折叠,目前返回0 + return 0; } + +// 从 VarDefContext 构造类型 +std::shared_ptr SymbolTable::getTypeFromVarDef(SysYParser::VarDefContext* ctx) { + // 1. 获取基本类型(int/float) + std::shared_ptr base_type = nullptr; + auto varDecl = getOuterVarDecl(ctx); + if (varDecl) { + auto bType = varDecl->bType(); + if (bType->Int()) { + base_type = ir::Type::GetInt32Type(); + } else if (bType->Float()) { + base_type = ir::Type::GetFloatType(); + } + } else { + auto constDecl = getOuterConstDecl(ctx); + if (constDecl) { + auto bType = constDecl->bType(); + if (bType->Int()) { + base_type = ir::Type::GetInt32Type(); + } else if (bType->Float()) { + base_type = ir::Type::GetFloatType(); + } + } + } + + if (!base_type) { + base_type = ir::Type::GetInt32Type(); // 默认 int + } + + // 2. 解析数组维度(从 varDef 的 constExp 列表获取) + std::vector dims; + for (auto constExp : ctx->constExp()) { + int dimVal = evaluateConstExp(constExp); + dims.push_back(dimVal); + } + + if (!dims.empty()) { + return ir::Type::GetArrayType(base_type, dims); + } + return base_type; +} + +// 从 FuncDefContext 构造函数类型 +std::shared_ptr SymbolTable::getTypeFromFuncDef(SysYParser::FuncDefContext* ctx) { + // 1. 返回类型 + std::shared_ptr ret_type; + auto funcType = ctx->funcType(); + if (funcType->Void()) { + ret_type = ir::Type::GetVoidType(); + } else if (funcType->Int()) { + ret_type = ir::Type::GetInt32Type(); + } else if (funcType->Float()) { + ret_type = ir::Type::GetFloatType(); + } else { + ret_type = ir::Type::GetInt32Type(); // fallback + } + + // 2. 参数类型 + std::vector> param_types; + auto fParams = ctx->funcFParams(); + if (fParams) { + for (auto param : fParams->funcFParam()) { + std::shared_ptr param_type; + auto bType = param->bType(); + if (bType->Int()) { + param_type = ir::Type::GetInt32Type(); + } else if (bType->Float()) { + param_type = ir::Type::GetFloatType(); + } else { + param_type = ir::Type::GetInt32Type(); + } + + // 处理数组参数:如果存在 [ ] 或 [ exp ],退化为指针 + if (param->L_BRACK().size() > 0) { + if (param_type->IsInt32()) { + param_type = ir::Type::GetPtrInt32Type(); + } else if (param_type->IsFloat()) { + param_type = ir::Type::GetPtrFloatType(); + } + } + param_types.push_back(param_type); + } + } + + return ir::Type::GetFunctionType(ret_type, param_types); +} + +// ----- 注册内置库函数----- +void SymbolTable::registerBuiltinFunctions() { + // 确保当前处于全局作用域(scopes_ 只有一层) + // 1. getint: int getint() + Symbol getint; + getint.name = "getint"; + getint.kind = SymbolKind::Function; + getint.type = ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {}); // 无参数 + getint.param_types = {}; + getint.scope_level = 0; + getint.is_builtin = true; + addSymbol(getint); + + // 2. getfloat: float getfloat() + Symbol getfloat; + getfloat.name = "getfloat"; + getfloat.kind = SymbolKind::Function; + getfloat.type = ir::Type::GetFunctionType(ir::Type::GetFloatType(), {}); + getfloat.param_types = {}; + getfloat.scope_level = 0; + getfloat.is_builtin = true; + addSymbol(getfloat); + + // 3. getch: int getch() + Symbol getch; + getch.name = "getch"; + getch.kind = SymbolKind::Function; + getch.type = ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {}); + getch.param_types = {}; + getch.scope_level = 0; + getch.is_builtin = true; + addSymbol(getch); + + // 4. putint: void putint(int) + std::vector> putint_params = { ir::Type::GetInt32Type() }; + Symbol putint; + putint.name = "putint"; + putint.kind = SymbolKind::Function; + putint.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putint_params); + putint.param_types = putint_params; + putint.scope_level = 0; + putint.is_builtin = true; + addSymbol(putint); + + // 5. putfloat: void putfloat(float) + std::vector> putfloat_params = { ir::Type::GetFloatType() }; + Symbol putfloat; + putfloat.name = "putfloat"; + putfloat.kind = SymbolKind::Function; + putfloat.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putfloat_params); + putfloat.param_types = putfloat_params; + putfloat.scope_level = 0; + putfloat.is_builtin = true; + addSymbol(putfloat); + + // 6. putch: void putch(int) + std::vector> putch_params = { ir::Type::GetInt32Type() }; + Symbol putch; + putch.name = "putch"; + putch.kind = SymbolKind::Function; + putch.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putch_params); + putch.param_types = putch_params; + putch.scope_level = 0; + putch.is_builtin = true; + addSymbol(putch); + + // 7. getarray: int getarray(int a[]) + // 参数类型: int a[] 退化为 int* 即 PtrInt32 + std::vector> getarray_params = { ir::Type::GetPtrInt32Type() }; + Symbol getarray; + getarray.name = "getarray"; + getarray.kind = SymbolKind::Function; + getarray.type = ir::Type::GetFunctionType(ir::Type::GetInt32Type(), getarray_params); + getarray.param_types = getarray_params; + getarray.scope_level = 0; + getarray.is_builtin = true; + addSymbol(getarray); + + // 8. putarray: void putarray(int n, int a[]) + // 参数: int n, int a[] -> 实际类型: int, int* + std::vector> putarray_params = { ir::Type::GetInt32Type(), ir::Type::GetPtrInt32Type() }; + Symbol putarray; + putarray.name = "putarray"; + putarray.kind = SymbolKind::Function; + putarray.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putarray_params); + putarray.param_types = putarray_params; + putarray.scope_level = 0; + putarray.is_builtin = true; + addSymbol(putarray); + + // starttime: void starttime() + Symbol starttime; + starttime.name = "starttime"; + starttime.kind = SymbolKind::Function; + starttime.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), {}); // 无参数,返回 void + starttime.param_types = {}; + starttime.scope_level = 0; + starttime.is_builtin = true; + addSymbol(starttime); + + // stoptime: void stoptime() + Symbol stoptime; + stoptime.name = "stoptime"; + stoptime.kind = SymbolKind::Function; + stoptime.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), {}); // 无参数,返回 void + stoptime.param_types = {}; + stoptime.scope_level = 0; + stoptime.is_builtin = true; + addSymbol(stoptime); + + // getfarray: int getfarray(float arr[]) + std::vector> getfarray_params = { ir::Type::GetPtrFloatType() }; + Symbol getfarray; + getfarray.name = "getfarray"; + getfarray.kind = SymbolKind::Function; + getfarray.type = ir::Type::GetFunctionType(ir::Type::GetInt32Type(), getfarray_params); + getfarray.param_types = getfarray_params; + getfarray.scope_level = 0; + getfarray.is_builtin = true; + addSymbol(getfarray); + + // putfarray: void putfarray(int len, float arr[]) + std::vector> putfarray_params = { + ir::Type::GetInt32Type(), + ir::Type::GetPtrFloatType() + }; + Symbol putfarray; + putfarray.name = "putfarray"; + putfarray.kind = SymbolKind::Function; + putfarray.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putfarray_params); + putfarray.param_types = putfarray_params; + putfarray.scope_level = 0; + putfarray.is_builtin = true; + addSymbol(putfarray); +} \ No newline at end of file