From 7d4d60c5462f480f36f67cee667044b87c327f14 Mon Sep 17 00:00:00 2001 From: jing <3030349106@qq.com> Date: Wed, 18 Mar 2026 01:53:54 +0800 Subject: [PATCH] =?UTF-8?q?refactor(ir):=20ir=E6=94=B9=E4=B8=BA=E6=9B=B4?= =?UTF-8?q?=E6=A0=87=E5=87=86=E7=9A=84=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/ir/IR.h | 156 +++++++++++++++++++------ src/ir/BasicBlock.cpp | 17 ++- src/ir/CMakeLists.txt | 1 + src/ir/Context.cpp | 23 +--- src/ir/GlobalValue.cpp | 11 ++ src/ir/IRBuilder.cpp | 8 +- src/ir/Type.cpp | 15 +++ src/ir/Value.cpp | 65 ++++++++++- src/sem/Sema.cpp | 259 +++++++++++++++++++++++++---------------- 9 files changed, 390 insertions(+), 165 deletions(-) create mode 100644 src/ir/GlobalValue.cpp diff --git a/include/ir/IR.h b/include/ir/IR.h index c04268d..b961192 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -1,5 +1,33 @@ -// 极简 IR 定义:当前只支撑 i32 和加法,演示用。 -// 可在此基础上扩展更多类型/指令 +// 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。 +// +// 当前已经实现: +// 1. 基础类型系统:void / i32 / i32* +// 2. Value 体系:Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction +// 3. 最小指令集:Add / Alloca / Load / Store / Ret +// 4. BasicBlock / Function / Module 三层组织结构 +// 5. IRBuilder:便捷创建常量和最小指令 +// 6. def-use 关系的轻量实现: +// - Instruction 保存 operand 列表 +// - Value 保存 uses +// - 支持 ReplaceAllUsesWith 的简化实现 +// +// 当前尚未实现或只做了最小占位: +// 1. 完整类型系统:数组、函数类型、label 类型等 +// 2. 更完整的指令系统:br / condbr / call / phi / gep 等 +// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构) +// 4. 更完整的 IR verifier 和优化基础设施 +// +// 当前需要特别说明的两个简化点: +// 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位, +// 后续如果补 label type,可以再改成更合理的块标签类型。 +// 2. ConstantValue 体系目前只实现了 ConstantInt,后续可以继续补 ConstantFloat、 +// ConstantArray等更完整的常量种类。 +// +// 建议的扩展顺序: +// 1. 先补更多指令和类型 +// 2. 再补控制流相关 IR +// 3. 最后再考虑把 Value/User/Use 进一步抽象成更完整的框架 + #pragma once #include @@ -13,28 +41,52 @@ namespace ir { class Type; +class Value; +class User; +class ConstantValue; class ConstantInt; +class GlobalValue; class Instruction; class BasicBlock; class Function; +// Use 表示一个 Value 的一次使用记录。 +// 当前实现设计: +// - value:被使用的值 +// - user:使用该值的 User +// - operand_index:该值在 user 操作数列表中的位置 + +class Use { + public: + Use() = default; + Use(Value* value, User* user, size_t operand_index) + : value_(value), user_(user), operand_index_(operand_index) {} + + Value* GetValue() const { return value_; } + User* GetUser() const { return user_; } + size_t GetOperandIndex() const { return operand_index_; } + + void SetValue(Value* value) { value_ = value; } + void SetUser(User* user) { user_ = user; } + void SetOperandIndex(size_t operand_index) { operand_index_ = operand_index; } + + private: + Value* value_ = nullptr; + User* user_ = nullptr; + size_t operand_index_ = 0; +}; + // IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。 class Context { public: Context() = default; ~Context(); - const std::shared_ptr& Void(); - const std::shared_ptr& Int32(); - const std::shared_ptr& PtrInt32(); // 去重创建 i32 常量。 ConstantInt* GetConstInt(int v); std::string NextTemp(); private: - std::shared_ptr void_; - std::shared_ptr int32_; - std::shared_ptr ptr_i32_; std::unordered_map> const_ints_; int temp_index_ = -1; }; @@ -43,6 +95,12 @@ class Type { public: enum class Kind { Void, Int32, PtrInt32 }; explicit Type(Kind k); + // 使用静态共享对象获取类型。 + // 同一类型可直接比较返回值是否相等,例如: + // Type::GetInt32Type() == Type::GetInt32Type() + static const std::shared_ptr& GetVoidType(); + static const std::shared_ptr& GetInt32Type(); + static const std::shared_ptr& GetPtrInt32Type(); Kind GetKind() const; bool IsVoid() const; bool IsInt32() const; @@ -59,16 +117,32 @@ class Value { const std::shared_ptr& GetType() const; const std::string& GetName() const; void SetName(std::string n); - void AddUser(Instruction* user); - const std::vector& GetUsers() const; + bool IsVoid() const; + bool IsInt32() const; + bool IsPtrInt32() const; + bool IsConstant() const; + bool IsInstruction() const; + bool IsUser() const; + bool IsFunction() const; + void AddUse(User* user, size_t operand_index); + void RemoveUse(User* user, size_t operand_index); + const std::vector& GetUses() const; + void ReplaceAllUsesWith(Value* new_value); protected: std::shared_ptr type_; std::string name_; - std::vector users_; + std::vector uses_; }; -class ConstantInt : public Value { +// ConstantValue 是常量体系的基类。 +// 当前只实现了 ConstantInt,后续可继续扩展更多常量种类。 +class ConstantValue : public Value { + public: + ConstantValue(std::shared_ptr ty, std::string name = ""); +}; + +class ConstantInt : public ConstantValue { public: ConstantInt(std::shared_ptr ty, int v); int GetValue() const { return value_; } @@ -80,7 +154,31 @@ class ConstantInt : public Value { // 后续还需要扩展更多指令类型。 enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; -class Instruction : public Value { +// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 +// 当前实现中只有 Instruction 继承自 User。 +class User : public Value { + public: + User(std::shared_ptr ty, std::string name); + size_t GetNumOperands() const; + Value* GetOperand(size_t index) const; + void SetOperand(size_t index, Value* value); + + protected: + // 统一的 operand 入口。 + void AddOperand(Value* value); + + private: + std::vector operands_; +}; + +// GlobalValue 是全局值/全局变量体系的空壳占位类。 +// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。 +class GlobalValue : public User { + public: + GlobalValue(std::shared_ptr ty, std::string name); +}; + +class Instruction : public User { public: Instruction(Opcode op, std::shared_ptr ty, std::string name = ""); Opcode GetOpcode() const; @@ -98,20 +196,13 @@ class BinaryInst : public Instruction { BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name); Value* GetLhs() const; - Value* GetRhs() const; - - private: - Value* lhs_; - Value* rhs_; + Value* GetRhs() const; }; class ReturnInst : public Instruction { public: ReturnInst(std::shared_ptr void_ty, Value* val); Value* GetValue() const; - - private: - Value* value_; }; class AllocaInst : public Instruction { @@ -123,9 +214,6 @@ class LoadInst : public Instruction { public: LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name); Value* GetPtr() const; - - private: - Value* ptr_; }; class StoreInst : public Instruction { @@ -133,16 +221,13 @@ class StoreInst : public Instruction { StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr); Value* GetValue() const; Value* GetPtr() const; - - private: - Value* value_; - Value* ptr_; }; -class BasicBlock { +// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。 +// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。 +class BasicBlock : public Value { public: explicit BasicBlock(std::string name); - const std::string& GetName() const; Function* GetParent() const; void SetParent(Function* parent); bool HasTerminator() const; @@ -163,16 +248,21 @@ class BasicBlock { } private: - std::string name_; Function* parent_ = nullptr; std::vector> instructions_; std::vector predecessors_; std::vector successors_; }; +// Function 当前也采用了最小实现。 +// 需要特别注意:由于项目里还没有单独的 FunctionType, +// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”, +// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。 +// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、 +// 形参和调用,通常需要引入专门的函数类型表示。 class Function : public Value { public: - // 允许显式指定返回类型,便于后续扩展多种函数签名。 + // 当前构造函数接收的也是返回类型,而不是完整函数类型。 Function(std::string name, std::shared_ptr ret_type); BasicBlock* CreateBlock(const std::string& name); BasicBlock* GetEntry(); @@ -189,7 +279,7 @@ class Module { Module() = default; Context& GetContext(); const Context& GetContext() const; - // 创建函数时显式传入返回类型,便于在 IRGen 中根据语法树信息选择类型。 + // 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。 Function* CreateFunction(const std::string& name, std::shared_ptr ret_type); const std::vector>& GetFunctions() const; diff --git a/src/ir/BasicBlock.cpp b/src/ir/BasicBlock.cpp index 4c1b19d..b18502c 100644 --- a/src/ir/BasicBlock.cpp +++ b/src/ir/BasicBlock.cpp @@ -1,6 +1,11 @@ // IR 基本块: // - 保存指令序列 -// - 维护或可计算前驱/后继关系,用于 CFG 分析与优化 +// - 为后续 CFG 分析预留前驱/后继接口 +// +// 当前仍是最小实现: +// - BasicBlock 已纳入 Value 体系,但类型先用 void 占位; +// - 指令追加与 terminator 约束主要在头文件中的 Append 模板里处理; +// - 前驱/后继容器已经预留,但当前项目里还没有分支指令与自动维护逻辑。 #include "ir/IR.h" @@ -8,23 +13,27 @@ namespace ir { -BasicBlock::BasicBlock(std::string name) : name_(std::move(name)) {} - -const std::string& BasicBlock::GetName() const { return name_; } +// 当前 BasicBlock 还没有专门的 label type,因此先用 void 作为占位类型。 +BasicBlock::BasicBlock(std::string name) + : Value(Type::GetVoidType(), std::move(name)) {} Function* BasicBlock::GetParent() const { return parent_; } void BasicBlock::SetParent(Function* parent) { parent_ = parent; } + bool BasicBlock::HasTerminator() const { return !instructions_.empty() && instructions_.back()->IsTerminator(); } +// 按插入顺序返回块内指令序列。 const std::vector>& BasicBlock::GetInstructions() const { return instructions_; } +// 前驱/后继接口先保留给后续 CFG 扩展使用。 +// 当前最小 IR 中还没有 branch 指令,因此这些列表通常为空。 const std::vector& BasicBlock::GetPredecessors() const { return predecessors_; } diff --git a/src/ir/CMakeLists.txt b/src/ir/CMakeLists.txt index c3b6e7b..99987ed 100644 --- a/src/ir/CMakeLists.txt +++ b/src/ir/CMakeLists.txt @@ -3,6 +3,7 @@ add_library(ir_core STATIC Module.cpp Function.cpp BasicBlock.cpp + GlobalValue.cpp Type.cpp Value.cpp Instruction.cpp diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index 6f95676..16c982c 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -7,32 +7,11 @@ namespace ir { Context::~Context() = default; -const std::shared_ptr& Context::Void() { - if (!void_) { - void_ = std::make_shared(Type::Kind::Void); - } - return void_; -} - -const std::shared_ptr& Context::Int32() { - if (!int32_) { - int32_ = std::make_shared(Type::Kind::Int32); - } - return int32_; -} - -const std::shared_ptr& Context::PtrInt32() { - if (!ptr_i32_) { - ptr_i32_ = std::make_shared(Type::Kind::PtrInt32); - } - return ptr_i32_; -} - ConstantInt* Context::GetConstInt(int v) { auto it = const_ints_.find(v); if (it != const_ints_.end()) return it->second.get(); auto inserted = - const_ints_.emplace(v, std::make_unique(Int32(), v)).first; + const_ints_.emplace(v, std::make_unique(Type::GetInt32Type(), v)).first; return inserted->second.get(); } diff --git a/src/ir/GlobalValue.cpp b/src/ir/GlobalValue.cpp new file mode 100644 index 0000000..7c2abe1 --- /dev/null +++ b/src/ir/GlobalValue.cpp @@ -0,0 +1,11 @@ +// GlobalValue 占位实现: +// - 具体的全局初始化器、打印和链接语义需要自行补全 + +#include "ir/IR.h" + +namespace ir { + +GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) + : User(std::move(ty), std::move(name)) {} + +} // namespace ir diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 4a2b502..90f03c4 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -46,7 +46,7 @@ AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - return insert_block_->Append(ctx_.PtrInt32(), name); + return insert_block_->Append(Type::GetPtrInt32Type(), name); } LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { @@ -57,7 +57,7 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { throw std::runtime_error( FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); } - return insert_block_->Append(ctx_.Int32(), ptr, name); + return insert_block_->Append(Type::GetInt32Type(), ptr, name); } StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { @@ -72,7 +72,7 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { throw std::runtime_error( FormatError("ir", "IRBuilder::CreateStore 缺少 ptr")); } - return insert_block_->Append(ctx_.Void(), val, ptr); + return insert_block_->Append(Type::GetVoidType(), val, ptr); } ReturnInst* IRBuilder::CreateRet(Value* v) { @@ -83,7 +83,7 @@ ReturnInst* IRBuilder::CreateRet(Value* v) { throw std::runtime_error( FormatError("ir", "IRBuilder::CreateRet 缺少返回值")); } - return insert_block_->Append(ctx_.Void(), v); + return insert_block_->Append(Type::GetVoidType(), v); } } // namespace ir diff --git a/src/ir/Type.cpp b/src/ir/Type.cpp index cbbc7cc..3e1684d 100644 --- a/src/ir/Type.cpp +++ b/src/ir/Type.cpp @@ -5,6 +5,21 @@ namespace ir { Type::Type(Kind k) : kind_(k) {} +const std::shared_ptr& Type::GetVoidType() { + static const std::shared_ptr type = std::make_shared(Kind::Void); + return type; +} + +const std::shared_ptr& Type::GetInt32Type() { + static const std::shared_ptr type = std::make_shared(Kind::Int32); + return type; +} + +const std::shared_ptr& Type::GetPtrInt32Type() { + static const std::shared_ptr type = std::make_shared(Kind::PtrInt32); + return type; +} + Type::Kind Type::GetKind() const { return kind_; } bool Type::IsVoid() const { return kind_ == Kind::Void; } diff --git a/src/ir/Value.cpp b/src/ir/Value.cpp index 091eb3b..2e9f4c1 100644 --- a/src/ir/Value.cpp +++ b/src/ir/Value.cpp @@ -3,6 +3,8 @@ // - 提供类型信息与使用/被使用关系(按需要实现) #include "ir/IR.h" +#include + namespace ir { Value::Value(std::shared_ptr ty, std::string name) @@ -14,11 +16,68 @@ const std::string& Value::GetName() const { return name_; } void Value::SetName(std::string n) { name_ = std::move(n); } -void Value::AddUser(Instruction* user) { users_.push_back(user); } +bool Value::IsVoid() const { return type_ && type_->IsVoid(); } + +bool Value::IsInt32() const { return type_ && type_->IsInt32(); } + +bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } + +bool Value::IsConstant() const { + return dynamic_cast(this) != nullptr; +} + +bool Value::IsInstruction() const { + return dynamic_cast(this) != nullptr; +} + +bool Value::IsUser() const { + return dynamic_cast(this) != nullptr; +} + +bool Value::IsFunction() const { + return dynamic_cast(this) != nullptr; +} + +void Value::AddUse(User* user, size_t operand_index) { + if (!user) return; + uses_.push_back(Use(this, user, operand_index)); +} + +void Value::RemoveUse(User* user, size_t operand_index) { + uses_.erase( + std::remove_if(uses_.begin(), uses_.end(), + [&](const Use& use) { + return use.GetUser() == user && + use.GetOperandIndex() == operand_index; + }), + uses_.end()); +} + +const std::vector& Value::GetUses() const { return uses_; } + +void Value::ReplaceAllUsesWith(Value* new_value) { + if (!new_value) { + throw std::runtime_error("ReplaceAllUsesWith 缺少 new_value"); + } + if (new_value == this) { + return; + } + + auto uses = uses_; + for (const auto& use : uses) { + auto* user = use.GetUser(); + if (!user) continue; + size_t operand_index = use.GetOperandIndex(); + if (user->GetOperand(operand_index) == this) { + user->SetOperand(operand_index, new_value); + } + } +} -const std::vector& Value::GetUsers() const { return users_; } +ConstantValue::ConstantValue(std::shared_ptr ty, std::string name) + : Value(std::move(ty), std::move(name)) {} ConstantInt::ConstantInt(std::shared_ptr ty, int v) - : Value(std::move(ty), ""), value_(v) {} + : ConstantValue(std::move(ty), ""), value_(v) {} } // namespace ir diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 32b7a75..745374c 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -1,16 +1,15 @@ #include "sem/Sema.h" +#include #include #include +#include "SysYBaseVisitor.h" #include "sem/SymbolTable.h" #include "utils/Log.h" namespace { -void CheckExpr(SysYParser::ExpContext& exp, const SymbolTable& table, - SemanticContext& sema); - std::string GetLValueName(SysYParser::LValueContext& lvalue) { if (!lvalue.ID()) { throw std::runtime_error(FormatError("sema", "非法左值")); @@ -18,122 +17,184 @@ std::string GetLValueName(SysYParser::LValueContext& lvalue) { return lvalue.ID()->getText(); } -void CheckVar(SysYParser::VarContext& var, const SymbolTable& table, - SemanticContext& sema) { - if (!var.ID()) { - throw std::runtime_error(FormatError("sema", "非法变量引用")); - } - const std::string name = var.ID()->getText(); - auto* decl = table.Lookup(name); - if (!decl) { - throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); +class SemaVisitor final : public SysYBaseVisitor { + public: + std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "缺少编译单元")); + } + auto* func = ctx->funcDef(); + if (!func || !func->blockStmt()) { + throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + } + if (!func->ID() || func->ID()->getText() != "main") { + throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + } + func->accept(this); + if (!seen_return_) { + throw std::runtime_error( + FormatError("sema", "main 函数必须包含 return 语句")); + } + return {}; } - sema.BindVarUse(&var, decl); -} -void CheckExpr(SysYParser::ExpContext& exp, const SymbolTable& table, - SemanticContext& sema) { - if (auto* paren = dynamic_cast(&exp)) { - CheckExpr(*paren->exp(), table, sema); - return; - } - if (auto* var = dynamic_cast(&exp)) { - if (!var->var()) { - throw std::runtime_error(FormatError("sema", "非法变量表达式")); + std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { + if (!ctx || !ctx->blockStmt()) { + throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); } - CheckVar(*var->var(), table, sema); - return; - } - if (dynamic_cast(&exp)) { - return; - } - if (auto* binary = dynamic_cast(&exp)) { - CheckExpr(*binary->exp(0), table, sema); - CheckExpr(*binary->exp(1), table, sema); - return; + if (!ctx->funcType() || !ctx->funcType()->INT()) { + throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); + } + const auto& items = ctx->blockStmt()->blockItem(); + if (items.empty()) { + throw std::runtime_error( + FormatError("sema", "main 函数不能为空,且必须以 return 结束")); + } + ctx->blockStmt()->accept(this); + return {}; } - throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); -} -SysYParser::FuncDefContext* FindMainFunc(SysYParser::CompUnitContext& comp_unit) { - auto* func = comp_unit.funcDef(); - if (func && func->ID() && func->ID()->getText() == "main") { - return func; + std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "缺少语句块")); + } + const auto& items = ctx->blockItem(); + for (size_t i = 0; i < items.size(); ++i) { + auto* item = items[i]; + if (!item) { + continue; + } + if (seen_return_) { + throw std::runtime_error( + FormatError("sema", "return 必须是 main 函数中的最后一条语句")); + } + current_item_index_ = i; + total_items_ = items.size(); + item->accept(this); + } + return {}; } - return nullptr; -} - -} // namespace -SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { - auto* func = FindMainFunc(comp_unit); - if (!func || !func->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - if (!func->funcType() || !func->funcType()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); + std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + } + if (ctx->decl()) { + ctx->decl()->accept(this); + return {}; + } + if (ctx->stmt()) { + ctx->stmt()->accept(this); + return {}; + } + throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); } - SymbolTable table; - SemanticContext sema; - bool seen_return = false; + std::any visitDecl(SysYParser::DeclContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "非法变量声明")); + } + if (!ctx->btype() || !ctx->btype()->INT()) { + throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); + } + auto* var_def = ctx->varDef(); + if (!var_def || !var_def->lValue()) { + throw std::runtime_error(FormatError("sema", "非法变量声明")); + } + const std::string name = GetLValueName(*var_def->lValue()); + if (table_.Contains(name)) { + throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + } + if (auto* init = var_def->initValue()) { + if (!init->exp()) { + throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化")); + } + init->exp()->accept(this); + } + table_.Add(name, var_def); + return {}; + } - const auto& items = func->blockStmt()->blockItem(); - if (items.empty()) { - throw std::runtime_error( - FormatError("sema", "main 函数不能为空,且必须以 return 结束")); + std::any visitStmt(SysYParser::StmtContext* ctx) override { + if (!ctx || !ctx->returnStmt()) { + throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + } + ctx->returnStmt()->accept(this); + return {}; } - for (size_t i = 0; i < items.size(); ++i) { - auto* item = items[i]; - if (!item) { - continue; + std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { + if (!ctx || !ctx->exp()) { + throw std::runtime_error(FormatError("sema", "return 缺少表达式")); } - if (seen_return) { + ctx->exp()->accept(this); + seen_return_ = true; + if (current_item_index_ + 1 != total_items_) { throw std::runtime_error( FormatError("sema", "return 必须是 main 函数中的最后一条语句")); } - if (auto* decl = item->decl()) { - if (!decl->btype() || !decl->btype()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); - } - auto* var_def = decl->varDef(); - if (!var_def || !var_def->lValue()) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - const std::string name = GetLValueName(*var_def->lValue()); - if (table.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); - } - if (auto* init = var_def->initValue()) { - if (!init->exp()) { - throw std::runtime_error( - FormatError("sema", "当前不支持聚合初始化")); - } - CheckExpr(*init->exp(), table, sema); - } - table.Add(name, var_def); - continue; + return {}; + } + + std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { + if (!ctx || !ctx->exp()) { + throw std::runtime_error(FormatError("sema", "非法括号表达式")); } - if (auto* stmt = item->stmt(); stmt && stmt->returnStmt()) { - auto* ret = stmt->returnStmt(); - if (!ret->exp()) { - throw std::runtime_error(FormatError("sema", "return 缺少表达式")); - } - CheckExpr(*ret->exp(), table, sema); - seen_return = true; - if (i + 1 != items.size()) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } - continue; + ctx->exp()->accept(this); + return {}; + } + + std::any visitVarExp(SysYParser::VarExpContext* ctx) override { + if (!ctx || !ctx->var()) { + throw std::runtime_error(FormatError("sema", "非法变量表达式")); } - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + ctx->var()->accept(this); + return {}; + } + + std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { + if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { + throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); + } + return {}; + } + + std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { + if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { + throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); + } + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; } - if (!seen_return) { - throw std::runtime_error(FormatError("sema", "main 函数必须包含 return 语句")); + std::any visitVar(SysYParser::VarContext* ctx) override { + if (!ctx || !ctx->ID()) { + throw std::runtime_error(FormatError("sema", "非法变量引用")); + } + const std::string name = ctx->ID()->getText(); + auto* decl = table_.Lookup(name); + if (!decl) { + throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); + } + sema_.BindVarUse(ctx, decl); + return {}; } - return sema; + SemanticContext TakeSemanticContext() { return std::move(sema_); } + + private: + SymbolTable table_; + SemanticContext sema_; + bool seen_return_ = false; + size_t current_item_index_ = 0; + size_t total_items_ = 0; +}; + +} // namespace + +SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { + SemaVisitor visitor; + comp_unit.accept(&visitor); + return visitor.TakeSemanticContext(); }