From 09ce09b433bceddd3c795c2715c8fb31f8f0934b Mon Sep 17 00:00:00 2001 From: Junhe Wu <2561075610@qq.com> Date: Thu, 26 Mar 2026 12:46:04 +0800 Subject: [PATCH] =?UTF-8?q?feat(ir):=20=E5=88=9D=E6=AD=A5=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E4=BA=86=E4=B8=AD=E9=97=B4=E4=BB=A3=E7=A0=81=E7=94=9F?= =?UTF-8?q?=E6=88=90=E7=9A=84=E8=BF=87=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/ir/IR.h | 350 +++++++++++++++++----- include/irgen/IRGen.h | 107 +++++-- include/sem/Sema.h | 3 - scripts/run_ir_test.sh | 67 +++++ scripts/verify_ir.sh | 2 +- src/ir/Context.cpp | 12 +- src/ir/Function.cpp | 21 +- src/ir/IRBuilder.cpp | 209 +++++++++++-- src/ir/IRPrinter.cpp | 302 +++++++++++++++---- src/ir/Instruction.cpp | 255 +++++++++++----- src/ir/Module.cpp | 55 +++- src/ir/Type.cpp | 26 +- src/ir/Value.cpp | 17 +- src/irgen/CMakeLists.txt | 1 + src/irgen/IRGenDecl.cpp | 273 +++++++++++++---- src/irgen/IRGenExp.cpp | 620 ++++++++++++++++++++++++++++++++++++--- src/irgen/IRGenFunc.cpp | 145 +++++---- src/irgen/IRGenStmt.cpp | 195 ++++++++++-- src/sem/CMakeLists.txt | 3 +- src/sem/Sema.cpp | 387 ++++++++++++++---------- test_const_float.sy | 9 + test_float.sy | 8 + test_float_full.sy | 30 ++ test_float_simple.sy | 18 ++ 24 files changed, 2506 insertions(+), 609 deletions(-) create mode 100755 scripts/run_ir_test.sh create mode 100644 test_const_float.sy create mode 100644 test_float.sy create mode 100644 test_float_full.sy create mode 100644 test_float_simple.sy diff --git a/include/ir/IR.h b/include/ir/IR.h index b961192..0315e28 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -1,32 +1,12 @@ -// 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。 +// IR 中间表示:类型系统、Value 体系、指令集、基本块、函数、模块、IRBuilder。 // -// 当前已经实现: -// 1. 基础类型系统:void / i32 / i32* -// 2. Value 体系:Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction -// 3. 最小指令集:Add / Alloca / Load / Store / Ret -// 4. BasicBlock / Function / Module 三层组织结构 -// 5. IRBuilder:便捷创建常量和最小指令 -// 6. def-use 关系的轻量实现: -// - Instruction 保存 operand 列表 -// - Value 保存 uses -// - 支持 ReplaceAllUsesWith 的简化实现 -// -// 当前尚未实现或只做了最小占位: -// 1. 完整类型系统:数组、函数类型、label 类型等 -// 2. 更完整的指令系统:br / condbr / call / phi / gep 等 -// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构) -// 4. 更完整的 IR verifier 和优化基础设施 -// -// 当前需要特别说明的两个简化点: -// 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位, -// 后续如果补 label type,可以再改成更合理的块标签类型。 -// 2. ConstantValue 体系目前只实现了 ConstantInt,后续可以继续补 ConstantFloat、 -// ConstantArray等更完整的常量种类。 -// -// 建议的扩展顺序: -// 1. 先补更多指令和类型 -// 2. 再补控制流相关 IR -// 3. 最后再考虑把 Value/User/Use 进一步抽象成更完整的框架 +// 已实现: +// 1. 类型系统:void / i1 / i32 / i32* +// 2. Value 体系:ConstantInt / Function / BasicBlock / GlobalVariable / Instruction +// 3. 指令集:Add/Sub/Mul/Div/Mod/ICmp/Alloca/Load/Store/Ret/Br/CondBr/Call/ZExt +// 4. 全局变量 / 外部函数声明 +// 5. IRBuilder 便捷接口 +// 6. use-def 关系 #pragma once @@ -45,17 +25,14 @@ class Value; class User; class ConstantValue; class ConstantInt; -class GlobalValue; +class ConstantFloat; +class GlobalVariable; class Instruction; class BasicBlock; class Function; +class Module; -// Use 表示一个 Value 的一次使用记录。 -// 当前实现设计: -// - value:被使用的值 -// - user:使用该值的 User -// - operand_index:该值在 user 操作数列表中的位置 - +// ─── Use ────────────────────────────────────────────────────────────────────── class Use { public: Use() = default; @@ -76,40 +53,45 @@ class Use { size_t operand_index_ = 0; }; -// IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。 +// ─── Context ────────────────────────────────────────────────────────────────── class Context { public: Context() = default; ~Context(); - // 去重创建 i32 常量。 ConstantInt* GetConstInt(int v); - + ConstantFloat* GetConstFloat(float v); std::string NextTemp(); private: std::unordered_map> const_ints_; + std::unordered_map> const_floats_; int temp_index_ = -1; }; +// ─── Type ───────────────────────────────────────────────────────────────────── class Type { public: - enum class Kind { Void, Int32, PtrInt32 }; + enum class Kind { Void, Int1, Int32, Float32, PtrInt32, PtrFloat32 }; explicit Type(Kind k); - // 使用静态共享对象获取类型。 - // 同一类型可直接比较返回值是否相等,例如: - // Type::GetInt32Type() == Type::GetInt32Type() static const std::shared_ptr& GetVoidType(); + static const std::shared_ptr& GetInt1Type(); static const std::shared_ptr& GetInt32Type(); + static const std::shared_ptr& GetFloat32Type(); static const std::shared_ptr& GetPtrInt32Type(); + static const std::shared_ptr& GetPtrFloat32Type(); Kind GetKind() const; bool IsVoid() const; + bool IsInt1() const; bool IsInt32() const; + bool IsFloat32() const; bool IsPtrInt32() const; + bool IsPtrFloat32() const; private: Kind kind_; }; +// ─── Value ──────────────────────────────────────────────────────────────────── class Value { public: Value(std::shared_ptr ty, std::string name); @@ -118,12 +100,17 @@ class Value { const std::string& GetName() const; void SetName(std::string n); bool IsVoid() const; + bool IsInt1() const; bool IsInt32() const; + bool IsFloat32() const; bool IsPtrInt32() const; + bool IsPtrFloat32() const; bool IsConstant() const; bool IsInstruction() const; bool IsUser() const; bool IsFunction() const; + bool IsBasicBlock() const; + bool IsGlobalVariable() const; void AddUse(User* user, size_t operand_index); void RemoveUse(User* user, size_t operand_index); const std::vector& GetUses() const; @@ -135,8 +122,7 @@ class Value { std::vector uses_; }; -// ConstantValue 是常量体系的基类。 -// 当前只实现了 ConstantInt,后续可继续扩展更多常量种类。 +// ─── ConstantValue ───────────────────────────────────────────────────────────── class ConstantValue : public Value { public: ConstantValue(std::shared_ptr ty, std::string name = ""); @@ -151,11 +137,42 @@ class ConstantInt : public ConstantValue { int value_{}; }; -// 后续还需要扩展更多指令类型。 -enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; +class ConstantFloat : public ConstantValue { + public: + ConstantFloat(std::shared_ptr ty, float v); + float GetValue() const { return value_; } + + private: + float value_{}; +}; -// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 -// 当前实现中只有 Instruction 继承自 User。 +// ─── Opcode ─────────────────────────────────────────────────────────────────── +enum class Opcode { + // 整数算术 + Add, Sub, Mul, Div, Mod, + // 浮点算术 + FAdd, FSub, FMul, FDiv, + // 比较(结果为 i1) + ICmp, FCmp, + // 内存 + Alloca, Load, Store, + // 地址计算 + Gep, + // 控制流 + Ret, Br, CondBr, + // 函数调用 + Call, + // 类型转换 + ZExt, SIToFP, FPToSI, +}; + +// ICmp 谓词 +enum class ICmpPredicate { EQ, NE, SLT, SLE, SGT, SGE }; + +// FCmp 谓词 +enum class FCmpPredicate { OEQ, ONE, OLT, OLE, OGT, OGE }; + +// ─── User ───────────────────────────────────────────────────────────────────── class User : public Value { public: User(std::shared_ptr ty, std::string name); @@ -164,20 +181,36 @@ class User : public Value { void SetOperand(size_t index, Value* value); protected: - // 统一的 operand 入口。 void AddOperand(Value* value); private: std::vector operands_; }; -// GlobalValue 是全局值/全局变量体系的空壳占位类。 -// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。 +// ─── GlobalValue (占位) ──────────────────────────────────────────────────────── class GlobalValue : public User { public: GlobalValue(std::shared_ptr ty, std::string name); }; +// ─── GlobalVariable ──────────────────────────────────────────────────────────── +// 表示全局整型变量或常量。类型为 i32*(可直接用于 load/store)。 +class GlobalVariable : public Value { + public: + GlobalVariable(std::string name, bool is_const, int init_val, + int num_elements = 1); + bool IsConst() const { return is_const_; } + int GetInitVal() const { return init_val_; } + int GetNumElements() const { return num_elements_; } + bool IsArray() const { return num_elements_ > 1; } + // GlobalVariable 的"指针类型"是 i32*,访问时使用 load/store + private: + bool is_const_; + int init_val_; + int num_elements_; +}; + +// ─── Instruction ────────────────────────────────────────────────────────────── class Instruction : public User { public: Instruction(Opcode op, std::shared_ptr ty, std::string name = ""); @@ -191,23 +224,125 @@ class Instruction : public User { BasicBlock* parent_ = nullptr; }; +// 二元算术指令(i32 × i32 → i32) class BinaryInst : public Instruction { public: BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name); Value* GetLhs() const; - Value* GetRhs() const; + Value* GetRhs() const; }; +// 整数比较指令(i32 × i32 → i1) +class ICmpInst : public Instruction { + public: + ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name); + ICmpPredicate GetPredicate() const { return pred_; } + Value* GetLhs() const; + Value* GetRhs() const; + + private: + ICmpPredicate pred_; +}; + +// 无条件跳转 +class BrInst : public Instruction { + public: + explicit BrInst(BasicBlock* target); + BasicBlock* GetTarget() const; +}; + +// 条件跳转 +class CondBrInst : public Instruction { + public: + CondBrInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb); + Value* GetCond() const; + BasicBlock* GetTrueBB() const; + BasicBlock* GetFalseBB() const; +}; + +// 函数调用 +// callee 为 nullptr 时表示外部函数,使用 callee_name_ +class CallInst : public Instruction { + public: + // 调用已知 Function(模块内定义) + CallInst(Function* callee, std::vector args, std::string name); + // 调用外部声明函数(名称 + 返回类型) + CallInst(std::string callee_name, std::shared_ptr ret_type, + std::vector args, std::string name); + bool IsExternal() const { return callee_ == nullptr; } + Function* GetCallee() const { return callee_; } + const std::string& GetCalleeName() const { return callee_name_; } + size_t GetNumArgs() const { return GetNumOperands(); } + Value* GetArg(size_t i) const { return GetOperand(i); } + + private: + Function* callee_ = nullptr; + std::string callee_name_; +}; + +// 零扩展:i1 → i32 +class ZExtInst : public Instruction { + public: + ZExtInst(Value* val, std::string name); + Value* GetSrc() const; +}; + +// 浮点比较指令(f32 × f32 → i1) +class FCmpInst : public Instruction { + public: + FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name); + FCmpPredicate GetPredicate() const { return pred_; } + Value* GetLhs() const; + Value* GetRhs() const; + + private: + FCmpPredicate pred_; +}; + +// 有符号整数转浮点:i32 → f32 +class SIToFPInst : public Instruction { + public: + SIToFPInst(Value* val, std::string name); + Value* GetSrc() const; +}; + +// 浮点转有符号整数:f32 → i32 +class FPToSIInst : public Instruction { + public: + FPToSIInst(Value* val, std::string name); + Value* GetSrc() const; +}; + +// return 语句(val 为 nullptr 表示 void return) class ReturnInst : public Instruction { public: - ReturnInst(std::shared_ptr void_ty, Value* val); - Value* GetValue() const; + // 有返回值 + explicit ReturnInst(Value* val); + // void 返回 + ReturnInst(); + bool HasValue() const { return GetNumOperands() > 0; } + Value* GetValue() const; // 可能为 nullptr }; class AllocaInst : public Instruction { public: + // 标量 alloca(num_elements == 1) AllocaInst(std::shared_ptr ptr_ty, std::string name); + // 数组 alloca(num_elements > 1) + AllocaInst(std::shared_ptr ptr_ty, int num_elements, std::string name); + int GetNumElements() const { return num_elements_; } + bool IsArray() const { return num_elements_ > 1; } + private: + int num_elements_ = 1; +}; + +// GetElementPtr: ptr + index → i32* +class GepInst : public Instruction { + public: + GepInst(Value* base_ptr, Value* index, std::string name); + Value* GetBasePtr() const; + Value* GetIndex() const; }; class LoadInst : public Instruction { @@ -218,13 +353,12 @@ class LoadInst : public Instruction { class StoreInst : public Instruction { public: - StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr); + StoreInst(Value* val, Value* ptr); Value* GetValue() const; Value* GetPtr() const; }; -// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。 -// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。 +// ─── BasicBlock ─────────────────────────────────────────────────────────────── class BasicBlock : public Value { public: explicit BasicBlock(std::string name); @@ -234,11 +368,11 @@ class BasicBlock : public Value { const std::vector>& GetInstructions() const; const std::vector& GetPredecessors() const; const std::vector& GetSuccessors() const; + template T* Append(Args&&... args) { if (HasTerminator()) { - throw std::runtime_error("BasicBlock 已有 terminator,不能继续追加指令: " + - name_); + throw std::runtime_error("BasicBlock 已有 terminator: " + name_); } auto inst = std::make_unique(std::forward(args)...); auto* ptr = inst.get(); @@ -254,62 +388,142 @@ class BasicBlock : public Value { std::vector successors_; }; -// Function 当前也采用了最小实现。 -// 需要特别注意:由于项目里还没有单独的 FunctionType, -// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”, -// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。 -// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、 -// 形参和调用,通常需要引入专门的函数类型表示。 +// ─── Argument ───────────────────────────────────────────────────────────────── +// 函数形式参数,作为 SSA 值可用于 store 等指令 +class Argument : public Value { + public: + Argument(std::shared_ptr ty, std::string name) + : Value(std::move(ty), std::move(name)) {} +}; + +// ─── Function ───────────────────────────────────────────────────────────────── class Function : public Value { public: - // 当前构造函数接收的也是返回类型,而不是完整函数类型。 Function(std::string name, std::shared_ptr ret_type); BasicBlock* CreateBlock(const std::string& name); BasicBlock* GetEntry(); const BasicBlock* GetEntry() const; const std::vector>& GetBlocks() const; + // 参数值(按顺序,与 GetParamNames/GetParamTypes 一一对应) + Argument* AddArgument(std::shared_ptr ty, const std::string& name); + Argument* GetArgument(size_t i) const; + size_t GetNumArgs() const { return args_.size(); } + bool IsVoidReturn() const { return type_->IsVoid(); } private: BasicBlock* entry_ = nullptr; std::vector> blocks_; + std::vector> args_; +}; + +// ─── ExternalFuncDecl ───────────────────────────────────────────────────────── +struct ExternalFuncDecl { + std::string name; + std::shared_ptr ret_type; + std::vector> param_types; + bool is_variadic = false; }; +// ─── Module ─────────────────────────────────────────────────────────────────── class Module { public: Module() = default; Context& GetContext(); const Context& GetContext() const; - // 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。 + Function* CreateFunction(const std::string& name, std::shared_ptr ret_type); + Function* GetFunction(const std::string& name) const; const std::vector>& GetFunctions() const; + GlobalVariable* CreateGlobalVariable(const std::string& name, bool is_const, + int init_val, int num_elements = 1); + GlobalVariable* GetGlobalVariable(const std::string& name) const; + const std::vector>& GetGlobalVariables() const; + + void DeclareExternalFunc(const std::string& name, + std::shared_ptr ret_type, + std::vector> param_types, + bool is_variadic = false); + bool HasExternalDecl(const std::string& name) const; + const std::vector& GetExternalDecls() const; + private: Context context_; std::vector> functions_; + std::unordered_map func_map_; + std::vector> globals_; + std::unordered_map global_map_; + std::vector external_decls_; + std::unordered_map external_decl_index_; }; +// ─── IRBuilder ──────────────────────────────────────────────────────────────── class IRBuilder { public: IRBuilder(Context& ctx, BasicBlock* bb); void SetInsertPoint(BasicBlock* bb); BasicBlock* GetInsertBlock() const; - // 构造常量、二元运算、返回指令的最小集合。 + // 常量 ConstantInt* CreateConstInt(int v); + ConstantFloat* CreateConstFloat(float v); + + // 整数算术 BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateMod(Value* lhs, Value* rhs, const std::string& name); + + // 浮点算术 + BinaryInst* CreateFAdd(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFSub(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFMul(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFDiv(Value* lhs, Value* rhs, const std::string& name); + + // 比较(返回 i1) + ICmpInst* CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs, + const std::string& name); + FCmpInst* CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs, + const std::string& name); + + // 内存 AllocaInst* CreateAllocaI32(const std::string& name); + AllocaInst* CreateAllocaF32(const std::string& name); + AllocaInst* CreateAllocaArray(int num_elements, const std::string& name); + GepInst* CreateGep(Value* base_ptr, Value* index, const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name); StoreInst* CreateStore(Value* val, Value* ptr); + + // 控制流 ReturnInst* CreateRet(Value* v); + ReturnInst* CreateRetVoid(); + BrInst* CreateBr(BasicBlock* target); + CondBrInst* CreateCondBr(Value* cond, BasicBlock* true_bb, + BasicBlock* false_bb); + + // 调用 + CallInst* CreateCall(Function* callee, std::vector args, + const std::string& name = ""); + CallInst* CreateCallExternal(const std::string& callee_name, + std::shared_ptr ret_type, + std::vector args, + const std::string& name = ""); + + // 类型转换 + ZExtInst* CreateZExt(Value* val, const std::string& name); + SIToFPInst* CreateSIToFP(Value* val, const std::string& name); + FPToSIInst* CreateFPToSI(Value* val, const std::string& name); private: Context& ctx_; BasicBlock* insert_block_; }; +// ─── IRPrinter ──────────────────────────────────────────────────────────────── class IRPrinter { public: void Print(const Module& module, std::ostream& os); diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 231ba90..e509d53 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -1,5 +1,5 @@ -// 将语法树翻译为 IR。 -// 实现拆分在 IRGenFunc/IRGenStmt/IRGenExp/IRGenDecl。 +// IRGen:语法树 → IR +// 按语法树节点类型分发到 IRGenFunc/IRGenStmt/IRGenExp/IRGenDecl #pragma once @@ -7,51 +7,110 @@ #include #include #include +#include #include "SysYBaseVisitor.h" #include "SysYParser.h" #include "ir/IR.h" #include "sem/Sema.h" -namespace ir { -class Module; -class Function; -class IRBuilder; -class Value; -} - class IRGenImpl final : public SysYBaseVisitor { public: IRGenImpl(ir::Module& module, const SemanticContext& sema); + // ── 顶层 ────────────────────────────────────────────────────────────────── std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; + + // ── 函数 ────────────────────────────────────────────────────────────────── std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + + // ── 块与块内项 ───────────────────────────────────────────────────────────── + std::any visitBlock(SysYParser::BlockContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; + + // ── 声明 ────────────────────────────────────────────────────────────────── std::any visitDecl(SysYParser::DeclContext* ctx) override; - std::any visitStmt(SysYParser::StmtContext* ctx) override; + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override; - std::any visitVarExp(SysYParser::VarExpContext* ctx) override; - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override; + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; + + // ── 语句 ────────────────────────────────────────────────────────────────── + std::any visitStmt(SysYParser::StmtContext* ctx) override; + + // ── 表达式(返回 ir::Value*,i32) ───────────────────────────────────────── + std::any visitExp(SysYParser::ExpContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; + std::any visitLVar(SysYParser::LVarContext* ctx) override; + std::any visitNumber(SysYParser::NumberContext* ctx) override; + + // ── 条件(返回 ir::Value*,i1) ──────────────────────────────────────────── + std::any visitCond(SysYParser::CondContext* ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; private: - enum class BlockFlow { - Continue, - Terminated, - }; + // BlockFlow 用于通知调用者当前块是否已终结 + enum class BlockFlow { Continue, Terminated }; - BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); + // ── 辅助函数 ────────────────────────────────────────────────────────────── + // 求值表达式,返回 i32 值 ir::Value* EvalExpr(SysYParser::ExpContext& expr); + ir::Value* EvalExprAdd(SysYParser::AddExpContext& expr); + // 求值条件,返回 i1 值(供 if/while 使用) + ir::Value* EvalCond(SysYParser::CondContext& cond); + // 把整型值转为 i1 条件(icmp ne i32 val, 0) + ir::Value* ToI1(ir::Value* v); + // 把 i1 值零扩展为 i32 + ir::Value* ToI32(ir::Value* v); + // 隐式类型转换:确保两个操作数类型一致(int 转 float) + void ImplicitConvert(ir::Value*& lhs, ir::Value*& rhs); + // 转换为 float(如果是 int) + ir::Value* ToFloat(ir::Value* v); + // 转换为 int(如果是 float) + ir::Value* ToInt(ir::Value* v); + + // 访问一条块内项,返回流控状态 + BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); + // 访问 stmt,返回流控状态 + BlockFlow VisitStmt(SysYParser::StmtContext& stmt); + // 向外部函数声明注册(幂等) + void EnsureExternalDecl(const std::string& name); + + // ── 状态 ────────────────────────────────────────────────────────────────── ir::Module& module_; const SemanticContext& sema_; - ir::Function* func_; + ir::Function* func_ = nullptr; ir::IRBuilder builder_; - // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 - std::unordered_map storage_map_; + + // 局部变量/参数的存储槽位:声明上下文 → alloca Value* + std::unordered_map storage_map_; + // 全局变量/常量的存储槽位(不随函数切换清空) + std::unordered_map global_storage_map_; + + // 数组维度信息:声明上下文 → 各维度大小(从外到内) + std::unordered_map> array_dims_; + std::unordered_map> global_array_dims_; + + // 辅助:求常量表达式的整数值(失败返回 0) + int EvalConstExprInt(SysYParser::ConstExpContext* ctx); + // 辅助:计算一个 lVar 的元素地址(支持多维数组) + ir::Value* EvalLVarAddr(SysYParser::LVarContext* ctx); + + // 是否处于全局作用域(visitCompUnit 中处理 decl 时) + bool in_global_scope_ = false; + + // 循环上下文(break/continue 的目标块) + struct LoopCtx { + ir::BasicBlock* cond_bb; // continue 跳到这里 + ir::BasicBlock* after_bb; // break 跳到这里 + }; + std::vector loop_stack_; }; std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 8f3524b..efdfea7 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -9,9 +9,6 @@ class SemanticContext { public: void BindVarUse(antlr4::tree::TerminalNode* use, antlr4::ParserRuleContext* decl) { - if(ResolveVarUse(use)){ - throw std::runtime_error(FormatError("sema", "变量名重定义")); - } var_uses_[use] = decl; } diff --git a/scripts/run_ir_test.sh b/scripts/run_ir_test.sh new file mode 100755 index 0000000..cb5d120 --- /dev/null +++ b/scripts/run_ir_test.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd) +COMPILER="$PROJECT_ROOT/build/bin/compiler" +TEST_CASE_DIR="$PROJECT_ROOT/test/test_case" +TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/ir" + +if [ ! -x "$COMPILER" ]; then + echo "错误:编译器不存在或不可执行: $COMPILER" + echo "请先构建项目:cmake --build build -j\$(nproc)" + exit 1 +fi + +mkdir -p "$TEST_RESULT_DIR" + +pass_count=0 +fail_count=0 +failed_cases=() + +echo "=== 开始测试 IR 生成 ===" +echo "" + +while IFS= read -r test_file; do + relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file") + output_file="$TEST_RESULT_DIR/${relative_path%.sy}.ll" + + mkdir -p "$(dirname "$output_file")" + + echo -n "测试: $relative_path ... " + + # stderr 单独捕获,stdout 写到输出文件 + err_output=$("$COMPILER" --emit-ir "$test_file" > "$output_file" 2>&1) + exit_code=$? + + if [ $exit_code -eq 0 ]; then + # 确认输出文件非空且不含 [error] + if [ -s "$output_file" ] && ! grep -q '\[error\]' "$output_file"; then + echo "通过" + pass_count=$((pass_count + 1)) + else + echo "失败 (生成内容含错误)" + fail_count=$((fail_count + 1)) + failed_cases+=("$relative_path") + echo " 错误信息已保存到: $output_file" + fi + else + echo "失败" + fail_count=$((fail_count + 1)) + failed_cases+=("$relative_path") + echo " 错误信息已保存到: $output_file" + fi +done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort) + +echo "" +echo "=== 测试完成 ===" +echo "通过: $pass_count" +echo "失败: $fail_count" +echo "结果保存在: $TEST_RESULT_DIR" + +if [ ${#failed_cases[@]} -gt 0 ]; then + echo "" + echo "=== 失败的用例 ===" + for f in "${failed_cases[@]}"; do + echo " - $f" + done + exit 1 +fi diff --git a/scripts/verify_ir.sh b/scripts/verify_ir.sh index f41f6b3..049b725 100755 --- a/scripts/verify_ir.sh +++ b/scripts/verify_ir.sh @@ -91,4 +91,4 @@ if [[ "$run_exec" == true ]]; then else echo "未找到预期输出文件,跳过比对: $expected_file" fi -fi +fi \ No newline at end of file diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index 16c982c..a6a6ac5 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -1,4 +1,4 @@ -// 管理基础类型、整型常量池和临时名生成。 +// 管理基础类型、常量池和临时名生成。 #include "ir/IR.h" #include @@ -15,9 +15,17 @@ ConstantInt* Context::GetConstInt(int v) { return inserted->second.get(); } +ConstantFloat* Context::GetConstFloat(float v) { + auto it = const_floats_.find(v); + if (it != const_floats_.end()) return it->second.get(); + auto inserted = + const_floats_.emplace(v, std::make_unique(Type::GetFloat32Type(), v)).first; + return inserted->second.get(); +} + std::string Context::NextTemp() { std::ostringstream oss; - oss << "%" << ++temp_index_; + oss << ++temp_index_; return oss.str(); } diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index cf14d48..4cc7067 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -1,6 +1,4 @@ -// IR Function: -// - 保存参数列表、基本块列表 -// - 记录函数属性/元信息(按需要扩展) +// IR Function #include "ir/IR.h" namespace ir { @@ -15,18 +13,25 @@ BasicBlock* Function::CreateBlock(const std::string& name) { auto* ptr = block.get(); ptr->SetParent(this); blocks_.push_back(std::move(block)); - if (!entry_) { - entry_ = ptr; - } + if (!entry_) entry_ = ptr; return ptr; } BasicBlock* Function::GetEntry() { return entry_; } - const BasicBlock* Function::GetEntry() const { return entry_; } - const std::vector>& Function::GetBlocks() const { return blocks_; } +Argument* Function::AddArgument(std::shared_ptr ty, + const std::string& name) { + args_.push_back(std::make_unique(std::move(ty), name)); + return args_.back().get(); +} + +Argument* Function::GetArgument(size_t i) const { + if (i >= args_.size()) return nullptr; + return args_[i].get(); +} + } // namespace ir diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 90f03c4..c11568f 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -1,7 +1,3 @@ -// IR 构建工具: -// - 管理插入点(当前基本块/位置) -// - 提供创建各类指令的便捷接口,降低 IRGen 复杂度 - #include "ir/IR.h" #include @@ -9,32 +5,32 @@ #include "utils/Log.h" namespace ir { + IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) : ctx_(ctx), insert_block_(bb) {} void IRBuilder::SetInsertPoint(BasicBlock* bb) { insert_block_ = bb; } - BasicBlock* IRBuilder::GetInsertBlock() const { return insert_block_; } ConstantInt* IRBuilder::CreateConstInt(int v) { - // 常量不需要挂在基本块里,由 Context 负责去重与生命周期。 return ctx_.GetConstInt(v); } +ConstantFloat* IRBuilder::CreateConstFloat(float v) { + return ctx_.GetConstFloat(v); +} + +// ─── 算术 ───────────────────────────────────────────────────────────────────── BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - if (!lhs) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateBinary 缺少 lhs")); - } - if (!rhs) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateBinary 缺少 rhs")); + if (!lhs || !rhs) { + throw std::runtime_error(FormatError("ir", "CreateBinary 缺少操作数")); } - return insert_block_->Append(op, lhs->GetType(), lhs, rhs, name); + return insert_block_->Append(op, Type::GetInt32Type(), lhs, rhs, + name); } BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, @@ -42,6 +38,81 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, return CreateBinary(Opcode::Add, lhs, rhs, name); } +BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::Sub, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::Mul, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateDiv(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::Div, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateMod(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::Mod, lhs, rhs, name); +} + +// ─── 浮点算术 ───────────────────────────────────────────────────────────────── +BinaryInst* IRBuilder::CreateFAdd(Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Opcode::FAdd, Type::GetFloat32Type(), + lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFSub(Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Opcode::FSub, Type::GetFloat32Type(), + lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFMul(Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Opcode::FMul, Type::GetFloat32Type(), + lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFDiv(Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Opcode::FDiv, Type::GetFloat32Type(), + lhs, rhs, name); +} + +// ─── 比较 ───────────────────────────────────────────────────────────────────── +ICmpInst* IRBuilder::CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(pred, lhs, rhs, name); +} + +FCmpInst* IRBuilder::CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(pred, lhs, rhs, name); +} + +// ─── 内存 ───────────────────────────────────────────────────────────────────── AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); @@ -49,41 +120,121 @@ AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { return insert_block_->Append(Type::GetPtrInt32Type(), name); } +AllocaInst* IRBuilder::CreateAllocaF32(const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Type::GetPtrFloat32Type(), name); +} + +AllocaInst* IRBuilder::CreateAllocaArray(int num_elements, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Type::GetPtrInt32Type(), + num_elements, name); +} + +GepInst* IRBuilder::CreateGep(Value* base_ptr, Value* index, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(base_ptr, index, name); +} + LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } if (!ptr) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); + throw std::runtime_error(FormatError("ir", "CreateLoad 缺少 ptr")); } - return insert_block_->Append(Type::GetInt32Type(), ptr, name); + auto val_type = ptr->GetType()->IsPtrFloat32() ? Type::GetFloat32Type() : Type::GetInt32Type(); + return insert_block_->Append(val_type, ptr, name); } StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - if (!val) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateStore 缺少 val")); - } - if (!ptr) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateStore 缺少 ptr")); + if (!val || !ptr) { + throw std::runtime_error(FormatError("ir", "CreateStore 缺少操作数")); } - return insert_block_->Append(Type::GetVoidType(), val, ptr); + return insert_block_->Append(val, ptr); } +// ─── 控制流 ─────────────────────────────────────────────────────────────────── ReturnInst* IRBuilder::CreateRet(Value* v) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - if (!v) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateRet 缺少返回值")); + return insert_block_->Append(v); +} + +ReturnInst* IRBuilder::CreateRetVoid() { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(); +} + +BrInst* IRBuilder::CreateBr(BasicBlock* target) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(target); +} + +CondBrInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* true_bb, + BasicBlock* false_bb) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(cond, true_bb, false_bb); +} + +// ─── 调用 ───────────────────────────────────────────────────────────────────── +CallInst* IRBuilder::CreateCall(Function* callee, std::vector args, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(callee, std::move(args), name); +} + +CallInst* IRBuilder::CreateCallExternal(const std::string& callee_name, + std::shared_ptr ret_type, + std::vector args, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(callee_name, std::move(ret_type), + std::move(args), name); +} + +// ─── 类型转换 ───────────────────────────────────────────────────────────────── +ZExtInst* IRBuilder::CreateZExt(Value* val, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(val, name); +} + +SIToFPInst* IRBuilder::CreateSIToFP(Value* val, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(val, name); +} + +FPToSIInst* IRBuilder::CreateFPToSI(Value* val, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - return insert_block_->Append(Type::GetVoidType(), v); + return insert_block_->Append(val, name); } } // namespace ir diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 30efbb6..6e0a827 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -1,7 +1,3 @@ -// IR 文本输出: -// - 将 IR 打印为 .ll 风格的文本 -// - 支撑调试与测试对比(diff) - #include "ir/IR.h" #include @@ -12,95 +8,299 @@ namespace ir { -static const char* TypeToString(const Type& ty) { +static const char* TypeToStr(const Type& ty) { switch (ty.GetKind()) { - case Type::Kind::Void: - return "void"; - case Type::Kind::Int32: - return "i32"; - case Type::Kind::PtrInt32: - return "i32*"; + case Type::Kind::Void: return "void"; + case Type::Kind::Int1: return "i1"; + case Type::Kind::Int32: return "i32"; + case Type::Kind::Float32: return "float"; + case Type::Kind::PtrInt32: return "i32*"; + case Type::Kind::PtrFloat32: return "float*"; } throw std::runtime_error(FormatError("ir", "未知类型")); } -static const char* OpcodeToString(Opcode op) { - switch (op) { - case Opcode::Add: - return "add"; - case Opcode::Sub: - return "sub"; - case Opcode::Mul: - return "mul"; - case Opcode::Alloca: - return "alloca"; - case Opcode::Load: - return "load"; - case Opcode::Store: - return "store"; - case Opcode::Ret: - return "ret"; +static const char* PredToStr(ICmpPredicate pred) { + switch (pred) { + case ICmpPredicate::EQ: return "eq"; + case ICmpPredicate::NE: return "ne"; + case ICmpPredicate::SLT: return "slt"; + case ICmpPredicate::SLE: return "sle"; + case ICmpPredicate::SGT: return "sgt"; + case ICmpPredicate::SGE: return "sge"; + } + return "?"; +} + +static const char* FPredToStr(FCmpPredicate pred) { + switch (pred) { + case FCmpPredicate::OEQ: return "oeq"; + case FCmpPredicate::ONE: return "one"; + case FCmpPredicate::OLT: return "olt"; + case FCmpPredicate::OLE: return "ole"; + case FCmpPredicate::OGT: return "ogt"; + case FCmpPredicate::OGE: return "oge"; } return "?"; } -static std::string ValueToString(const Value* v) { +static std::string ValStr(const Value* v) { + if (!v) return ""; if (auto* ci = dynamic_cast(v)) { return std::to_string(ci->GetValue()); } - return v ? v->GetName() : ""; + if (auto* cf = dynamic_cast(v)) { + return std::to_string(cf->GetValue()); + } + // BasicBlock: 打印为 label %name + if (dynamic_cast(v)) { + return "%" + v->GetName(); + } + // GlobalVariable: 打印为 @name + if (auto* gv = dynamic_cast(v)) { + if (gv->IsArray()) { + // 数组全局变量的指针:getelementptr [N x i32], [N x i32]* @name, i32 0, i32 0 + return "getelementptr ([" + std::to_string(gv->GetNumElements()) + + " x i32], [" + std::to_string(gv->GetNumElements()) + + " x i32]* @" + gv->GetName() + ", i32 0, i32 0)"; + } + return "@" + v->GetName(); + } + return "%" + v->GetName(); +} + +static std::string TypeVal(const Value* v) { + if (!v) return "void"; + if (auto* ci = dynamic_cast(v)) { + return std::string(TypeToStr(*ci->GetType())) + " " + + std::to_string(ci->GetValue()); + } + if (auto* cf = dynamic_cast(v)) { + return std::string(TypeToStr(*cf->GetType())) + " " + + std::to_string(cf->GetValue()); + } + return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v); } void IRPrinter::Print(const Module& module, std::ostream& os) { + // 1. 全局变量/常量 + for (const auto& g : module.GetGlobalVariables()) { + if (g->IsArray()) { + // 全局数组:zeroinitializer + if (g->IsConst()) { + os << "@" << g->GetName() << " = constant [" << g->GetNumElements() + << " x i32] zeroinitializer\n"; + } else { + os << "@" << g->GetName() << " = global [" << g->GetNumElements() + << " x i32] zeroinitializer\n"; + } + } else { + if (g->IsConst()) { + os << "@" << g->GetName() << " = constant i32 " << g->GetInitVal() + << "\n"; + } else { + os << "@" << g->GetName() << " = global i32 " << g->GetInitVal() + << "\n"; + } + } + } + if (!module.GetGlobalVariables().empty()) os << "\n"; + + // 2. 外部函数声明 + for (const auto& decl : module.GetExternalDecls()) { + os << "declare " << TypeToStr(*decl.ret_type) << " @" << decl.name << "("; + for (size_t i = 0; i < decl.param_types.size(); ++i) { + if (i > 0) os << ", "; + os << TypeToStr(*decl.param_types[i]); + } + if (decl.is_variadic) { + if (!decl.param_types.empty()) os << ", "; + os << "..."; + } + os << ")\n"; + } + if (!module.GetExternalDecls().empty()) os << "\n"; + + // 3. 函数定义 for (const auto& func : module.GetFunctions()) { - os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() - << "() {\n"; + os << "define " << TypeToStr(*func->GetType()) << " @" << func->GetName() + << "("; + for (size_t i = 0; i < func->GetNumArgs(); ++i) { + if (i > 0) os << ", "; + auto* arg = func->GetArgument(i); + os << TypeToStr(*arg->GetType()) << " %" << arg->GetName(); + } + os << ") {\n"; + for (const auto& bb : func->GetBlocks()) { - if (!bb) { - continue; - } + if (!bb) continue; os << bb->GetName() << ":\n"; for (const auto& instPtr : bb->GetInstructions()) { const auto* inst = instPtr.get(); switch (inst->GetOpcode()) { + // ── 算术 ────────────────────────────────────────────────────────── case Opcode::Add: case Opcode::Sub: - case Opcode::Mul: { + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: { + auto* bin = static_cast(inst); + const char* op_str = nullptr; + switch (bin->GetOpcode()) { + case Opcode::Add: op_str = "add"; break; + case Opcode::Sub: op_str = "sub"; break; + case Opcode::Mul: op_str = "mul"; break; + case Opcode::Div: op_str = "sdiv"; break; + case Opcode::Mod: op_str = "srem"; break; + default: op_str = "?"; break; + } + os << " %" << bin->GetName() << " = " << op_str << " i32 " + << ValStr(bin->GetLhs()) << ", " << ValStr(bin->GetRhs()) + << "\n"; + break; + } + // ── 浮点算术 ────────────────────────────────────────────────────── + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: { auto* bin = static_cast(inst); - os << " " << bin->GetName() << " = " - << OpcodeToString(bin->GetOpcode()) << " " - << TypeToString(*bin->GetLhs()->GetType()) << " " - << ValueToString(bin->GetLhs()) << ", " - << ValueToString(bin->GetRhs()) << "\n"; + const char* op_str = nullptr; + switch (bin->GetOpcode()) { + case Opcode::FAdd: op_str = "fadd"; break; + case Opcode::FSub: op_str = "fsub"; break; + case Opcode::FMul: op_str = "fmul"; break; + case Opcode::FDiv: op_str = "fdiv"; break; + default: op_str = "?"; break; + } + os << " %" << bin->GetName() << " = " << op_str << " float " + << ValStr(bin->GetLhs()) << ", " << ValStr(bin->GetRhs()) + << "\n"; + break; + } + // ── 比较 ────────────────────────────────────────────────────────── + case Opcode::ICmp: { + auto* cmp = static_cast(inst); + os << " %" << cmp->GetName() << " = icmp " + << PredToStr(cmp->GetPredicate()) << " i32 " + << ValStr(cmp->GetLhs()) << ", " << ValStr(cmp->GetRhs()) + << "\n"; + break; + } + case Opcode::FCmp: { + auto* cmp = static_cast(inst); + os << " %" << cmp->GetName() << " = fcmp " + << FPredToStr(cmp->GetPredicate()) << " float " + << ValStr(cmp->GetLhs()) << ", " << ValStr(cmp->GetRhs()) + << "\n"; break; } + // ── 内存 ────────────────────────────────────────────────────────── case Opcode::Alloca: { - auto* alloca = static_cast(inst); - os << " " << alloca->GetName() << " = alloca i32\n"; + auto* al = static_cast(inst); + const char* elem_type = al->GetType()->IsPtrFloat32() ? "float" : "i32"; + if (al->IsArray()) { + os << " %" << al->GetName() << " = alloca " << elem_type << ", i32 " + << al->GetNumElements() << "\n"; + } else { + os << " %" << al->GetName() << " = alloca " << elem_type << "\n"; + } + break; + } + case Opcode::Gep: { + auto* gep = static_cast(inst); + os << " %" << gep->GetName() + << " = getelementptr i32, i32* " + << ValStr(gep->GetBasePtr()) << ", i32 " + << ValStr(gep->GetIndex()) << "\n"; break; } case Opcode::Load: { - auto* load = static_cast(inst); - os << " " << load->GetName() << " = load i32, i32* " - << ValueToString(load->GetPtr()) << "\n"; + auto* ld = static_cast(inst); + const char* val_type = ld->GetType()->IsFloat32() ? "float" : "i32"; + const char* ptr_type = ld->GetPtr()->GetType()->IsPtrFloat32() ? "float*" : "i32*"; + os << " %" << ld->GetName() << " = load " << val_type << ", " << ptr_type << " " + << ValStr(ld->GetPtr()) << "\n"; break; } case Opcode::Store: { - auto* store = static_cast(inst); - os << " store i32 " << ValueToString(store->GetValue()) - << ", i32* " << ValueToString(store->GetPtr()) << "\n"; + auto* st = static_cast(inst); + os << " store " << TypeVal(st->GetValue()) << ", " + << TypeToStr(*st->GetPtr()->GetType()) << " " + << ValStr(st->GetPtr()) << "\n"; break; } + // ── 控制流 ──────────────────────────────────────────────────────── case Opcode::Ret: { auto* ret = static_cast(inst); - os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " - << ValueToString(ret->GetValue()) << "\n"; + if (!ret->HasValue()) { + os << " ret void\n"; + } else { + auto* v = ret->GetValue(); + os << " ret " << TypeVal(v) << "\n"; + } + break; + } + case Opcode::Br: { + auto* br = static_cast(inst); + os << " br label %" << br->GetTarget()->GetName() << "\n"; + break; + } + case Opcode::CondBr: { + auto* cbr = static_cast(inst); + os << " br i1 " << ValStr(cbr->GetCond()) << ", label %" + << cbr->GetTrueBB()->GetName() << ", label %" + << cbr->GetFalseBB()->GetName() << "\n"; + break; + } + // ── 调用 ────────────────────────────────────────────────────────── + case Opcode::Call: { + auto* call = static_cast(inst); + std::string ret_type_str; + if (call->IsVoid()) { + ret_type_str = "void"; + } else { + ret_type_str = TypeToStr(*call->GetType()); + } + // 打印赋值部分(仅当有返回值时) + if (!call->IsVoid() && !call->GetName().empty()) { + os << " %" << call->GetName() << " = "; + } else { + os << " "; + } + os << "call " << ret_type_str << " @" << call->GetCalleeName() + << "("; + for (size_t i = 0; i < call->GetNumArgs(); ++i) { + if (i > 0) os << ", "; + auto* arg = call->GetArg(i); + os << TypeVal(arg); + } + os << ")\n"; + break; + } + // ── 类型转换 ────────────────────────────────────────────────────── + case Opcode::ZExt: { + auto* ze = static_cast(inst); + os << " %" << ze->GetName() << " = zext i1 " + << ValStr(ze->GetSrc()) << " to i32\n"; + break; + } + case Opcode::SIToFP: { + auto* si = static_cast(inst); + os << " %" << si->GetName() << " = sitofp i32 " + << ValStr(si->GetSrc()) << " to float\n"; + break; + } + case Opcode::FPToSI: { + auto* fp = static_cast(inst); + os << " %" << fp->GetName() << " = fptosi float " + << ValStr(fp->GetSrc()) << " to i32\n"; break; } } } } - os << "}\n"; + os << "}\n\n"; } } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 7928716..7830c0e 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -1,13 +1,13 @@ -// IR 指令体系: -// - 二元运算/比较、load/store、call、br/condbr、ret、phi、alloca 等 -// - 指令操作数与结果类型管理,支持打印与优化 #include "ir/IR.h" #include +#include #include "utils/Log.h" namespace ir { + +// ─── User ───────────────────────────────────────────────────────────────────── User::User(std::shared_ptr ty, std::string name) : Value(std::move(ty), std::move(name)) {} @@ -24,128 +24,251 @@ void User::SetOperand(size_t index, Value* value) { if (index >= operands_.size()) { throw std::out_of_range("User operand index out of range"); } - if (!value) { - throw std::runtime_error(FormatError("ir", "User operand 不能为空")); - } auto* old = operands_[index]; - if (old == value) { - return; - } - if (old) { - old->RemoveUse(this, index); - } + if (old == value) return; + if (old) old->RemoveUse(this, index); operands_[index] = value; - value->AddUse(this, index); + if (value) value->AddUse(this, index); } void User::AddOperand(Value* value) { - if (!value) { - throw std::runtime_error(FormatError("ir", "User operand 不能为空")); - } - size_t operand_index = operands_.size(); + size_t idx = operands_.size(); operands_.push_back(value); - value->AddUse(this, operand_index); + if (value) value->AddUse(this, idx); } +// ─── Instruction ────────────────────────────────────────────────────────────── Instruction::Instruction(Opcode op, std::shared_ptr ty, std::string name) : User(std::move(ty), std::move(name)), opcode_(op) {} Opcode Instruction::GetOpcode() const { return opcode_; } -bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; } +bool Instruction::IsTerminator() const { + return opcode_ == Opcode::Ret || opcode_ == Opcode::Br || + opcode_ == Opcode::CondBr; +} BasicBlock* Instruction::GetParent() const { return parent_; } - void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } +// ─── BinaryInst ─────────────────────────────────────────────────────────────── BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name) : Instruction(op, std::move(ty), std::move(name)) { - if (op != Opcode::Add) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add")); + if (op != Opcode::Add && op != Opcode::Sub && op != Opcode::Mul && + op != Opcode::Div && op != Opcode::Mod && op != Opcode::FAdd && + op != Opcode::FSub && op != Opcode::FMul && op != Opcode::FDiv) { + throw std::runtime_error( + FormatError("ir", "BinaryInst: 不支持的操作码")); } if (!lhs || !rhs) { throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); } - if (!type_ || !lhs->GetType() || !rhs->GetType()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息")); - } - if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() || - type_->GetKind() != lhs->GetType()->GetKind()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); + AddOperand(lhs); + AddOperand(rhs); +} + +Value* BinaryInst::GetLhs() const { return GetOperand(0); } +Value* BinaryInst::GetRhs() const { return GetOperand(1); } + +// ─── ICmpInst ───────────────────────────────────────────────────────────────── +ICmpInst::ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name) + : Instruction(Opcode::ICmp, Type::GetInt1Type(), std::move(name)), + pred_(pred) { + if (!lhs || !rhs) { + throw std::runtime_error(FormatError("ir", "ICmpInst 缺少操作数")); } - if (!type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32")); + AddOperand(lhs); + AddOperand(rhs); +} + +Value* ICmpInst::GetLhs() const { return GetOperand(0); } +Value* ICmpInst::GetRhs() const { return GetOperand(1); } + +// ─── FCmpInst ───────────────────────────────────────────────────────────────── +FCmpInst::FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name) + : Instruction(Opcode::FCmp, Type::GetInt1Type(), std::move(name)), + pred_(pred) { + if (!lhs || !rhs) { + throw std::runtime_error(FormatError("ir", "FCmpInst 缺少操作数")); } AddOperand(lhs); AddOperand(rhs); } -Value* BinaryInst::GetLhs() const { return GetOperand(0); } +Value* FCmpInst::GetLhs() const { return GetOperand(0); } +Value* FCmpInst::GetRhs() const { return GetOperand(1); } -Value* BinaryInst::GetRhs() const { return GetOperand(1); } +// ─── BrInst ─────────────────────────────────────────────────────────────────── +BrInst::BrInst(BasicBlock* target) + : Instruction(Opcode::Br, Type::GetVoidType(), "") { + if (!target) { + throw std::runtime_error(FormatError("ir", "BrInst 缺少目标")); + } + AddOperand(target); +} + +BasicBlock* BrInst::GetTarget() const { + return static_cast(GetOperand(0)); +} + +// ─── CondBrInst ─────────────────────────────────────────────────────────────── +CondBrInst::CondBrInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb) + : Instruction(Opcode::CondBr, Type::GetVoidType(), "") { + if (!cond || !true_bb || !false_bb) { + throw std::runtime_error(FormatError("ir", "CondBrInst 缺少操作数")); + } + AddOperand(cond); + AddOperand(true_bb); + AddOperand(false_bb); +} + +Value* CondBrInst::GetCond() const { return GetOperand(0); } +BasicBlock* CondBrInst::GetTrueBB() const { + return static_cast(GetOperand(1)); +} +BasicBlock* CondBrInst::GetFalseBB() const { + return static_cast(GetOperand(2)); +} -ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) - : Instruction(Opcode::Ret, std::move(void_ty), "") { +// ─── CallInst ───────────────────────────────────────────────────────────────── +CallInst::CallInst(Function* callee, std::vector args, std::string name) + : Instruction(Opcode::Call, + callee ? callee->GetType() : Type::GetVoidType(), + std::move(name)), + callee_(callee) { + if (!callee) { + throw std::runtime_error(FormatError("ir", "CallInst: callee 为空")); + } + callee_name_ = callee->GetName(); + for (auto* arg : args) { + AddOperand(arg); + } +} + +CallInst::CallInst(std::string callee_name, std::shared_ptr ret_type, + std::vector args, std::string name) + : Instruction(Opcode::Call, std::move(ret_type), std::move(name)), + callee_(nullptr), + callee_name_(std::move(callee_name)) { + for (auto* arg : args) { + AddOperand(arg); + } +} + +// ─── ZExtInst ───────────────────────────────────────────────────────────────── +ZExtInst::ZExtInst(Value* val, std::string name) + : Instruction(Opcode::ZExt, Type::GetInt32Type(), std::move(name)) { + if (!val) { + throw std::runtime_error(FormatError("ir", "ZExtInst 缺少操作数")); + } + AddOperand(val); +} + +Value* ZExtInst::GetSrc() const { return GetOperand(0); } + +// ─── SIToFPInst ─────────────────────────────────────────────────────────────── +SIToFPInst::SIToFPInst(Value* val, std::string name) + : Instruction(Opcode::SIToFP, Type::GetFloat32Type(), std::move(name)) { if (!val) { - throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值")); + throw std::runtime_error(FormatError("ir", "SIToFPInst 缺少操作数")); } - if (!type_ || !type_->IsVoid()) { - throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void")); + AddOperand(val); +} + +Value* SIToFPInst::GetSrc() const { return GetOperand(0); } + +// ─── FPToSIInst ─────────────────────────────────────────────────────────────── +FPToSIInst::FPToSIInst(Value* val, std::string name) + : Instruction(Opcode::FPToSI, Type::GetInt32Type(), std::move(name)) { + if (!val) { + throw std::runtime_error(FormatError("ir", "FPToSIInst 缺少操作数")); } AddOperand(val); } -Value* ReturnInst::GetValue() const { return GetOperand(0); } +Value* FPToSIInst::GetSrc() const { return GetOperand(0); } + +// ─── ReturnInst ─────────────────────────────────────────────────────────────── +ReturnInst::ReturnInst(Value* val) + : Instruction(Opcode::Ret, Type::GetVoidType(), "") { + if (val) { + AddOperand(val); + } +} + +ReturnInst::ReturnInst() + : Instruction(Opcode::Ret, Type::GetVoidType(), "") {} +Value* ReturnInst::GetValue() const { + return HasValue() ? GetOperand(0) : nullptr; +} + +// ─── AllocaInst ─────────────────────────────────────────────────────────────── AllocaInst::AllocaInst(std::shared_ptr ptr_ty, std::string name) - : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { - if (!type_ || !type_->IsPtrInt32()) { - throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*")); + : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)), + num_elements_(1) { + if (!type_->IsPtrInt32() && !type_->IsPtrFloat32()) { + throw std::runtime_error(FormatError("ir", "AllocaInst 只支持 i32* 和 f32*")); } } +AllocaInst::AllocaInst(std::shared_ptr ptr_ty, int num_elements, + std::string name) + : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)), + num_elements_(num_elements) { + if (!type_->IsPtrInt32() && !type_->IsPtrFloat32()) { + throw std::runtime_error(FormatError("ir", "AllocaInst 只支持 i32* 和 f32*")); + } +} + +// ─── GepInst ────────────────────────────────────────────────────────────────── +GepInst::GepInst(Value* base_ptr, Value* index, std::string name) + : Instruction(Opcode::Gep, Type::GetPtrInt32Type(), std::move(name)) { + if (!base_ptr || !index) { + throw std::runtime_error(FormatError("ir", "GepInst 缺少操作数")); + } + AddOperand(base_ptr); + AddOperand(index); +} + +Value* GepInst::GetBasePtr() const { return GetOperand(0); } +Value* GepInst::GetIndex() const { return GetOperand(1); } + +// ─── LoadInst ───────────────────────────────────────────────────────────────── LoadInst::LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name) : Instruction(Opcode::Load, std::move(val_ty), std::move(name)) { if (!ptr) { throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr")); } - if (!type_ || !type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32")); - } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { - throw std::runtime_error( - FormatError("ir", "LoadInst 当前只支持从 i32* 加载")); - } AddOperand(ptr); } Value* LoadInst::GetPtr() const { return GetOperand(0); } -StoreInst::StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr) - : Instruction(Opcode::Store, std::move(void_ty), "") { - if (!val) { - throw std::runtime_error(FormatError("ir", "StoreInst 缺少 value")); - } - if (!ptr) { - throw std::runtime_error(FormatError("ir", "StoreInst 缺少 ptr")); - } - if (!type_ || !type_->IsVoid()) { - throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void")); - } - if (!val->GetType() || !val->GetType()->IsInt32()) { - throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32")); - } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { - throw std::runtime_error( - FormatError("ir", "StoreInst 当前只支持写入 i32*")); +// ─── StoreInst ──────────────────────────────────────────────────────────────── +StoreInst::StoreInst(Value* val, Value* ptr) + : Instruction(Opcode::Store, Type::GetVoidType(), "") { + if (!val || !ptr) { + throw std::runtime_error(FormatError("ir", "StoreInst 缺少操作数")); } AddOperand(val); AddOperand(ptr); } Value* StoreInst::GetValue() const { return GetOperand(0); } - Value* StoreInst::GetPtr() const { return GetOperand(1); } +// ─── GlobalValue (占位) ──────────────────────────────────────────────────────── +GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) + : User(std::move(ty), std::move(name)) {} + +// ─── GlobalVariable ──────────────────────────────────────────────────────────── +GlobalVariable::GlobalVariable(std::string name, bool is_const, int init_val, + int num_elements) + : Value(Type::GetPtrInt32Type(), std::move(name)), + is_const_(is_const), + init_val_(init_val), + num_elements_(num_elements) {} + } // namespace ir diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index 928efdc..5d46d90 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -1,21 +1,68 @@ -// 保存函数列表并提供模块级上下文访问。 - #include "ir/IR.h" +#include + namespace ir { Context& Module::GetContext() { return context_; } - const Context& Module::GetContext() const { return context_; } +// ─── 函数管理 ───────────────────────────────────────────────────────────────── Function* Module::CreateFunction(const std::string& name, std::shared_ptr ret_type) { functions_.push_back(std::make_unique(name, std::move(ret_type))); - return functions_.back().get(); + Function* f = functions_.back().get(); + func_map_[name] = f; + return f; +} + +Function* Module::GetFunction(const std::string& name) const { + auto it = func_map_.find(name); + return it == func_map_.end() ? nullptr : it->second; } const std::vector>& Module::GetFunctions() const { return functions_; } +// ─── 全局变量管理 ───────────────────────────────────────────────────────────── +GlobalVariable* Module::CreateGlobalVariable(const std::string& name, + bool is_const, int init_val, + int num_elements) { + globals_.push_back( + std::make_unique(name, is_const, init_val, num_elements)); + GlobalVariable* g = globals_.back().get(); + global_map_[name] = g; + return g; +} + +GlobalVariable* Module::GetGlobalVariable(const std::string& name) const { + auto it = global_map_.find(name); + return it == global_map_.end() ? nullptr : it->second; +} + +const std::vector>& +Module::GetGlobalVariables() const { + return globals_; +} + +// ─── 外部函数声明 ───────────────────────────────────────────────────────────── +void Module::DeclareExternalFunc(const std::string& name, + std::shared_ptr ret_type, + std::vector> param_types, + bool is_variadic) { + if (external_decl_index_.count(name)) return; // 已声明,幂等 + external_decl_index_[name] = external_decls_.size(); + external_decls_.push_back( + {name, std::move(ret_type), std::move(param_types), is_variadic}); +} + +bool Module::HasExternalDecl(const std::string& name) const { + return external_decl_index_.count(name) > 0; +} + +const std::vector& Module::GetExternalDecls() const { + return external_decls_; +} + } // namespace ir diff --git a/src/ir/Type.cpp b/src/ir/Type.cpp index 3e1684d..48970a8 100644 --- a/src/ir/Type.cpp +++ b/src/ir/Type.cpp @@ -1,4 +1,4 @@ -// 当前仅支持 void、i32 和 i32*。 +// 支持 void, i1, i32, f32, i32*, f32* #include "ir/IR.h" namespace ir { @@ -10,22 +10,40 @@ const std::shared_ptr& Type::GetVoidType() { return type; } +const std::shared_ptr& Type::GetInt1Type() { + static const std::shared_ptr type = std::make_shared(Kind::Int1); + return type; +} + const std::shared_ptr& Type::GetInt32Type() { static const std::shared_ptr type = std::make_shared(Kind::Int32); return type; } +const std::shared_ptr& Type::GetFloat32Type() { + static const std::shared_ptr type = std::make_shared(Kind::Float32); + return type; +} + const std::shared_ptr& Type::GetPtrInt32Type() { - static const std::shared_ptr type = std::make_shared(Kind::PtrInt32); + static const std::shared_ptr type = + std::make_shared(Kind::PtrInt32); + return type; +} + +const std::shared_ptr& Type::GetPtrFloat32Type() { + static const std::shared_ptr type = + std::make_shared(Kind::PtrFloat32); return type; } Type::Kind Type::GetKind() const { return kind_; } bool Type::IsVoid() const { return kind_ == Kind::Void; } - +bool Type::IsInt1() const { return kind_ == Kind::Int1; } bool Type::IsInt32() const { return kind_ == Kind::Int32; } - +bool Type::IsFloat32() const { return kind_ == Kind::Float32; } bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } +bool Type::IsPtrFloat32() const { return kind_ == Kind::PtrFloat32; } } // namespace ir diff --git a/src/ir/Value.cpp b/src/ir/Value.cpp index 2e9f4c1..2e52be0 100644 --- a/src/ir/Value.cpp +++ b/src/ir/Value.cpp @@ -17,26 +17,30 @@ const std::string& Value::GetName() const { return name_; } void Value::SetName(std::string n) { name_ = std::move(n); } bool Value::IsVoid() const { return type_ && type_->IsVoid(); } - +bool Value::IsInt1() const { return type_ && type_->IsInt1(); } bool Value::IsInt32() const { return type_ && type_->IsInt32(); } - +bool Value::IsFloat32() const { return type_ && type_->IsFloat32(); } bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } +bool Value::IsPtrFloat32() const { return type_ && type_->IsPtrFloat32(); } bool Value::IsConstant() const { return dynamic_cast(this) != nullptr; } - bool Value::IsInstruction() const { return dynamic_cast(this) != nullptr; } - bool Value::IsUser() const { return dynamic_cast(this) != nullptr; } - bool Value::IsFunction() const { return dynamic_cast(this) != nullptr; } +bool Value::IsBasicBlock() const { + return dynamic_cast(this) != nullptr; +} +bool Value::IsGlobalVariable() const { + return dynamic_cast(this) != nullptr; +} void Value::AddUse(User* user, size_t operand_index) { if (!user) return; @@ -80,4 +84,7 @@ ConstantValue::ConstantValue(std::shared_ptr ty, std::string name) ConstantInt::ConstantInt(std::shared_ptr ty, int v) : ConstantValue(std::move(ty), ""), value_(v) {} +ConstantFloat::ConstantFloat(std::shared_ptr ty, float v) + : ConstantValue(std::move(ty), ""), value_(v) {} + } // namespace ir diff --git a/src/irgen/CMakeLists.txt b/src/irgen/CMakeLists.txt index d440bde..ccde815 100644 --- a/src/irgen/CMakeLists.txt +++ b/src/irgen/CMakeLists.txt @@ -10,4 +10,5 @@ target_link_libraries(irgen PUBLIC build_options ${ANTLR4_RUNTIME_TARGET} ir + sem ) diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 0eb62ae..f0c6a50 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -1,46 +1,53 @@ #include "irgen/IRGen.h" +#include #include +#include +#include #include "SysYParser.h" #include "ir/IR.h" +#include "sem/func.h" #include "utils/Log.h" -namespace { - -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("irgen", "非法左值")); +// ─── EvalConstExprInt ───────────────────────────────────────────────────────── +int IRGenImpl::EvalConstExprInt(SysYParser::ConstExpContext* ctx) { + if (!ctx) return 0; + try { + auto cv = sem::EvaluateConstExp(*ctx); + return static_cast(cv.int_val); + } catch (...) { + return 0; } - return lvalue.ID()->getText(); } -} // namespace - -std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { +// ─── visitBlock ─────────────────────────────────────────────────────────────── +std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句块")); } for (auto* item : ctx->blockItem()) { - if (item) { - if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { - // 当前语法要求 return 为块内最后一条语句;命中后可停止生成。 - break; - } + if (!item) continue; + if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { + return BlockFlow::Terminated; } } - return {}; + return BlockFlow::Continue; } IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( SysYParser::BlockItemContext& item) { - return std::any_cast(item.accept(this)); + auto result = item.accept(this); + if (result.has_value()) { + try { + return std::any_cast(result); + } catch (...) {} + } + return BlockFlow::Continue; } std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少块内项")); - } + if (!ctx) return BlockFlow::Continue; if (ctx->decl()) { ctx->decl()->accept(this); return BlockFlow::Continue; @@ -48,60 +55,206 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { if (ctx->stmt()) { return ctx->stmt()->accept(this); } - throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); + return BlockFlow::Continue; } -// 变量声明的 IR 生成目前也是最小实现: -// - 先检查声明的基础类型,当前仅支持局部 int; -// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。 -// -// 和更完整的版本相比,这里还没有: -// - 一个 Decl 中多个变量定义的顺序处理; -// - const、数组、全局变量等不同声明形态; -// - 更丰富的类型系统。 +// ─── visitDecl ──────────────────────────────────────────────────────────────── std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少变量声明")); - } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); - } - auto* var_def = ctx->varDef(); - if (!var_def) { - throw std::runtime_error(FormatError("irgen", "非法变量声明")); - } - var_def->accept(this); + if (!ctx) return {}; + if (ctx->varDecl()) ctx->varDecl()->accept(this); + if (ctx->constDecl()) ctx->constDecl()->accept(this); return {}; } +// ─── visitConstDecl ─────────────────────────────────────────────────────────── +std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) { + if (!ctx) return {}; + bool is_float = ctx->bType() && ctx->bType()->Float(); -// 当前仍是教学用的最小版本,因此这里只支持: -// - 局部 int 变量; -// - 标量初始化; -// - 一个 VarDef 对应一个槽位。 -std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少变量定义")); + for (auto* constDef : ctx->constDef()) { + if (!constDef || !constDef->Ident()) continue; + std::string name = constDef->Ident()->getText(); + + // 计算维度 + std::vector dims; + for (auto* ce : constDef->constExp()) { + dims.push_back(EvalConstExprInt(ce)); + } + bool is_array = !dims.empty(); + + if (!is_array) { + // 标量常量 + if (is_float) { + float val = 0.0f; + if (constDef->constInitVal() && constDef->constInitVal()->constExp()) { + try { + auto cv = sem::EvaluateConstExp(*constDef->constInitVal()->constExp()); + val = static_cast(cv.float_val); + } catch (...) { + val = 0.0f; + } + } + // 全局和局部都直接存储为常量(简化处理) + storage_map_[constDef] = builder_.CreateConstFloat(val); + if (in_global_scope_) { + global_storage_map_[constDef] = builder_.CreateConstFloat(val); + } + } else { + int val = 0; + if (constDef->constInitVal() && constDef->constInitVal()->constExp()) { + val = EvalConstExprInt(constDef->constInitVal()->constExp()); + } + if (in_global_scope_) { + auto* gv = module_.CreateGlobalVariable(name, true, val); + global_storage_map_[constDef] = gv; + } else { + storage_map_[constDef] = builder_.CreateConstInt(val); + } + } + } else { + // 数组常量 + int total = 1; + for (int d : dims) total *= (d > 0 ? d : 1); + + if (in_global_scope_) { + auto* gv = module_.CreateGlobalVariable(name, true, 0, total); + global_storage_map_[constDef] = gv; + global_array_dims_[constDef] = dims; + } else { + auto* slot = builder_.CreateAllocaArray(total, name); + storage_map_[constDef] = slot; + array_dims_[constDef] = dims; + // 扁平化初始化 + if (constDef->constInitVal()) { + std::vector flat; + flat.reserve(total); + std::function flatten = + [&](SysYParser::ConstInitValContext* iv) { + if (!iv) return; + if (iv->constExp()) { + flat.push_back(EvalConstExprInt(iv->constExp())); + } else { + for (auto* sub : iv->constInitVal()) flatten(sub); + } + }; + flatten(constDef->constInitVal()); + for (int i = 0; i < total; ++i) { + int v = (i < (int)flat.size()) ? flat[i] : 0; + auto* ptr = builder_.CreateGep( + slot, builder_.CreateConstInt(i), + module_.GetContext().NextTemp()); + builder_.CreateStore(builder_.CreateConstInt(v), ptr); + } + } + } + } } - if (!ctx->lValue()) { - throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); + return {}; +} + +// ─── visitVarDecl ───────────────────────────────────────────────────────────── +std::any IRGenImpl::visitVarDecl(SysYParser::VarDeclContext* ctx) { + if (!ctx) return {}; + for (auto* varDef : ctx->varDef()) { + if (varDef) varDef->accept(this); } - GetLValueName(*ctx->lValue()); - if (storage_map_.find(ctx) != storage_map_.end()) { - throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); + return {}; +} + +// ─── visitVarDef ────────────────────────────────────────────────────────────── +std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { + if (!ctx || !ctx->Ident()) return {}; + std::string name = ctx->Ident()->getText(); + + // 获取类型(从父节点 VarDecl) + auto* parent = dynamic_cast(ctx->parent); + bool is_float = parent && parent->bType() && parent->bType()->Float(); + + // 计算维度 + std::vector dims; + for (auto* ce : ctx->constExp()) { + dims.push_back(EvalConstExprInt(ce)); } - auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); - storage_map_[ctx] = slot; + bool is_array = !dims.empty(); - ir::Value* init = nullptr; - if (auto* init_value = ctx->initValue()) { - if (!init_value->exp()) { - throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); + if (in_global_scope_) { + if (!is_array) { + // 全局标量:初始化器必须是常量,简化处理为0 + int init_val = 0; + if (ctx->initVal() && ctx->initVal()->exp()) { + try { + auto cv = sem::EvaluateExp(*ctx->initVal()->exp()->addExp()); + init_val = static_cast(cv.int_val); + } catch (...) { + init_val = 0; + } + } + auto* gv = module_.CreateGlobalVariable(name, false, init_val); + global_storage_map_[ctx] = gv; + } else { + int total = 1; + for (int d : dims) total *= (d > 0 ? d : 1); + auto* gv = module_.CreateGlobalVariable(name, false, 0, total); + global_storage_map_[ctx] = gv; + global_array_dims_[ctx] = dims; } - init = EvalExpr(*init_value->exp()); } else { - init = builder_.CreateConstInt(0); + if (storage_map_.count(ctx)) { + throw std::runtime_error( + FormatError("irgen", "变量重复生成存储: " + name)); + } + + if (!is_array) { + auto* slot = is_float ? builder_.CreateAllocaF32(module_.GetContext().NextTemp()) + : builder_.CreateAllocaI32(module_.GetContext().NextTemp()); + storage_map_[ctx] = slot; + ir::Value* init; + if (ctx->initVal() && ctx->initVal()->exp()) { + init = EvalExpr(*ctx->initVal()->exp()); + } else { + init = is_float ? static_cast(builder_.CreateConstFloat(0.0f)) + : static_cast(builder_.CreateConstInt(0)); + } + builder_.CreateStore(init, slot); + } else { + int total = 1; + for (int d : dims) total *= (d > 0 ? d : 1); + auto* slot = builder_.CreateAllocaArray(total, name); + storage_map_[ctx] = slot; + array_dims_[ctx] = dims; + + if (ctx->initVal()) { + // 收集扁平化初始值 + std::vector flat; + flat.reserve(total); + std::function flatten = + [&](SysYParser::InitValContext* iv) { + if (!iv) return; + if (iv->exp()) { + flat.push_back(EvalExpr(*iv->exp())); + } else { + for (auto* sub : iv->initVal()) flatten(sub); + } + }; + flatten(ctx->initVal()); + for (int i = 0; i < total; ++i) { + ir::Value* v = (i < (int)flat.size()) ? flat[i] + : builder_.CreateConstInt(0); + auto* ptr = builder_.CreateGep( + slot, builder_.CreateConstInt(i), + module_.GetContext().NextTemp()); + builder_.CreateStore(v, ptr); + } + } else { + // 零初始化 + for (int i = 0; i < total; ++i) { + auto* ptr = builder_.CreateGep( + slot, builder_.CreateConstInt(i), + module_.GetContext().NextTemp()); + builder_.CreateStore(builder_.CreateConstInt(0), ptr); + } + } + } } - builder_.CreateStore(init, slot); return {}; } diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index cf4797c..39a04f3 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -1,80 +1,604 @@ #include "irgen/IRGen.h" #include +#include #include "SysYParser.h" #include "ir/IR.h" +#include "sem/func.h" #include "utils/Log.h" -// 表达式生成当前也只实现了很小的一个子集。 -// 目前支持: -// - 整数字面量 -// - 普通局部变量读取 -// - 括号表达式 -// - 二元加法 -// -// 还未支持: -// - 减乘除与一元运算 -// - 赋值表达式 -// - 函数调用 -// - 数组、指针、下标访问 -// - 条件与比较表达式 -// - ... +// ─── 辅助 ───────────────────────────────────────────────────────────────────── + +// 把 i32 值转成 i1(icmp ne i32 v, 0) +ir::Value* IRGenImpl::ToI1(ir::Value* v) { + if (!v) throw std::runtime_error(FormatError("irgen", "ToI1: null value")); + if (v->IsInt1()) return v; + return builder_.CreateICmp(ir::ICmpPredicate::NE, v, + builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); +} + +// 把 i1 值零扩展为 i32 +ir::Value* IRGenImpl::ToI32(ir::Value* v) { + if (!v) throw std::runtime_error(FormatError("irgen", "ToI32: null value")); + if (v->IsInt32()) return v; + return builder_.CreateZExt(v, module_.GetContext().NextTemp()); +} + +// 转换为 float(如果是 int) +ir::Value* IRGenImpl::ToFloat(ir::Value* v) { + if (!v) throw std::runtime_error(FormatError("irgen", "ToFloat: null value")); + if (v->IsFloat32()) return v; + if (v->IsInt32()) return builder_.CreateSIToFP(v, module_.GetContext().NextTemp()); + if (v->IsInt1()) { + auto* i32 = ToI32(v); + return builder_.CreateSIToFP(i32, module_.GetContext().NextTemp()); + } + throw std::runtime_error(FormatError("irgen", "ToFloat: 不支持的类型")); +} + +// 转换为 int(如果是 float) +ir::Value* IRGenImpl::ToInt(ir::Value* v) { + if (!v) throw std::runtime_error(FormatError("irgen", "ToInt: null value")); + if (v->IsInt32()) return v; + if (v->IsFloat32()) return builder_.CreateFPToSI(v, module_.GetContext().NextTemp()); + if (v->IsInt1()) return ToI32(v); + throw std::runtime_error(FormatError("irgen", "ToInt: 不支持的类型")); +} + +// 隐式类型转换:确保两个操作数类型一致(int 转 float) +void IRGenImpl::ImplicitConvert(ir::Value*& lhs, ir::Value*& rhs) { + if (!lhs || !rhs) return; + bool lhs_float = lhs->IsFloat32(); + bool rhs_float = rhs->IsFloat32(); + if (lhs_float && !rhs_float) { + rhs = ToFloat(rhs); + } else if (!lhs_float && rhs_float) { + lhs = ToFloat(lhs); + } +} + +// 求值 exp(i32) ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { - return std::any_cast(expr.accept(this)); + auto result = expr.accept(this); + return std::any_cast(result); } +// 求值 addExp(i32) +ir::Value* IRGenImpl::EvalExprAdd(SysYParser::AddExpContext& expr) { + auto result = expr.accept(this); + return std::any_cast(result); +} + +// 求值 cond(i1) +ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) { + auto result = cond.accept(this); + auto* v = std::any_cast(result); + return ToI1(v); +} -std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "非法括号表达式")); +// 注册外部函数声明(幂等) +void IRGenImpl::EnsureExternalDecl(const std::string& name) { + if (module_.HasExternalDecl(name) || module_.GetFunction(name)) return; + // SysY 标准运行库函数签名 + if (name == "getint") { + module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {}); + } else if (name == "getch") { + module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {}); + } else if (name == "getfloat") { + module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); // 近似 + } else if (name == "putint") { + module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), + {ir::Type::GetInt32Type()}); + } else if (name == "putch") { + module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), + {ir::Type::GetInt32Type()}); + } else if (name == "putfloat") { + module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); + } else if (name == "putarray") { + module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), + {ir::Type::GetInt32Type()}); + } else if (name == "starttime" || name == "stoptime") { + module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), + {ir::Type::GetInt32Type()}); + } else { + // 未知外部函数,按 i32 返回声明 + module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {}); } - return EvalExpr(*ctx->exp()); } +// ─── 表达式访问器(返回 ir::Value*,i32) ───────────────────────────────────── -std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); +std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) { + if (!ctx || !ctx->addExp()) { + throw std::runtime_error(FormatError("irgen", "非法表达式")); } - return static_cast( - builder_.CreateConstInt(std::stoi(ctx->number()->getText()))); + return ctx->addExp()->accept(this); } -// 变量使用的处理流程: -// 1. 先通过语义分析结果把变量使用绑定回声明; -// 2. 再通过 storage_map_ 找到该声明对应的栈槽位; -// 3. 最后生成 load,把内存中的值读出来。 -// -// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 -std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) { - if (!ctx || !ctx->var() || !ctx->var()->ID()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); +// addExp : mulExp (AddOp mulExp)* +std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法加减表达式")); } - auto* decl = sema_.ResolveVarUse(ctx->var()); + auto muls = ctx->mulExp(); + if (muls.empty()) { + throw std::runtime_error(FormatError("irgen", "addExp 缺少操作数")); + } + + ir::Value* result = std::any_cast(muls[0]->accept(this)); + auto ops = ctx->AddOp(); + + for (size_t i = 0; i < ops.size(); ++i) { + ir::Value* rhs = std::any_cast(muls[i + 1]->accept(this)); + ImplicitConvert(result, rhs); + std::string tmp = module_.GetContext().NextTemp(); + std::string op = ops[i]->getText(); + if (result->IsFloat32()) { + result = (op == "+") ? builder_.CreateFAdd(result, rhs, tmp) + : builder_.CreateFSub(result, rhs, tmp); + } else { + result = (op == "+") ? builder_.CreateAdd(result, rhs, tmp) + : builder_.CreateSub(result, rhs, tmp); + } + } + return static_cast(result); +} + +// mulExp : unaryExp (MulOp unaryExp)* +std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法乘除表达式")); + } + auto unarys = ctx->unaryExp(); + if (unarys.empty()) { + throw std::runtime_error(FormatError("irgen", "mulExp 缺少操作数")); + } + + ir::Value* result = std::any_cast(unarys[0]->accept(this)); + auto ops = ctx->MulOp(); + + for (size_t i = 0; i < ops.size(); ++i) { + ir::Value* rhs = std::any_cast(unarys[i + 1]->accept(this)); + ImplicitConvert(result, rhs); + std::string tmp = module_.GetContext().NextTemp(); + std::string op = ops[i]->getText(); + if (result->IsFloat32()) { + if (op == "*") result = builder_.CreateFMul(result, rhs, tmp); + else if (op == "/") result = builder_.CreateFDiv(result, rhs, tmp); + else throw std::runtime_error(FormatError("irgen", "float 不支持取模")); + } else { + if (op == "*") result = builder_.CreateMul(result, rhs, tmp); + else if (op == "/") result = builder_.CreateDiv(result, rhs, tmp); + else result = builder_.CreateMod(result, rhs, tmp); + } + } + return static_cast(result); +} + +// unaryExp : primaryExp +// | Ident L_PAREN (funcRParams)? R_PAREN +// | unaryOp unaryExp +std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法一元表达式")); + } + + // ── 一元运算符 ───────────────────────────────────────────────────────────── + if (ctx->unaryOp() && ctx->unaryExp()) { + ir::Value* operand = + std::any_cast(ctx->unaryExp()->accept(this)); + std::string op = ctx->unaryOp()->getText(); + if (op == "-") { + if (operand->IsFloat32()) { + return static_cast( + builder_.CreateFSub(builder_.CreateConstFloat(0.0f), operand, + module_.GetContext().NextTemp())); + } else { + return static_cast( + builder_.CreateSub(builder_.CreateConstInt(0), operand, + module_.GetContext().NextTemp())); + } + } else if (op == "+") { + return static_cast(operand); + } else if (op == "!") { + ir::Value* cmp; + if (operand->IsFloat32()) { + cmp = builder_.CreateFCmp(ir::FCmpPredicate::OEQ, operand, + builder_.CreateConstFloat(0.0f), + module_.GetContext().NextTemp()); + } else { + operand = ToI32(operand); + cmp = builder_.CreateICmp(ir::ICmpPredicate::EQ, operand, + builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); + } + return static_cast(ToI32(cmp)); + } + throw std::runtime_error(FormatError("irgen", "不支持的一元运算符: " + op)); + } + + // ── 函数调用 ────────────────────────────────────────────────────────────── + if (ctx->Ident() && ctx->L_PAREN()) { + std::string callee_name = ctx->Ident()->getText(); + + // 收集实参 + std::vector args; + if (ctx->funcRParams()) { + for (auto* exp : ctx->funcRParams()->exp()) { + args.push_back(EvalExpr(*exp)); + } + } + + // 模块内已知函数? + ir::Function* callee = module_.GetFunction(callee_name); + if (callee) { + std::string ret_name = + callee->IsVoidReturn() ? "" : module_.GetContext().NextTemp(); + auto* call = + builder_.CreateCall(callee, std::move(args), ret_name); + return static_cast( + callee->IsVoidReturn() ? static_cast( + builder_.CreateConstInt(0)) + : call); + } + + // 外部函数 + EnsureExternalDecl(callee_name); + // 获取返回类型 + std::shared_ptr ret_type = ir::Type::GetInt32Type(); + for (const auto& decl : module_.GetExternalDecls()) { + if (decl.name == callee_name) { + ret_type = decl.ret_type; + break; + } + } + bool is_void = ret_type->IsVoid(); + std::string ret_name = is_void ? "" : module_.GetContext().NextTemp(); + auto* call = builder_.CreateCallExternal(callee_name, ret_type, + std::move(args), ret_name); + // void 调用返回 0 占位 + return static_cast( + is_void ? static_cast(builder_.CreateConstInt(0)) : call); + } + + // ── primaryExp ──────────────────────────────────────────────────────────── + if (ctx->primaryExp()) { + return ctx->primaryExp()->accept(this); + } + + throw std::runtime_error(FormatError("irgen", "非法一元表达式结构")); +} + +// primaryExp : L_PAREN exp R_PAREN | lVar | number +std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法基本表达式")); + } + if (ctx->exp()) { + return ctx->exp()->accept(this); + } + if (ctx->lVar()) { + return ctx->lVar()->accept(this); + } + if (ctx->number()) { + return ctx->number()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "primaryExp 结构非法")); +} + +// EvalLVarAddr:计算 lVar 的地址(支持数组索引) +ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("irgen", "非法变量引用")); + } + auto* decl = sema_.ResolveVarUse(ctx->Ident()); if (!decl) { throw std::runtime_error( - FormatError("irgen", - "变量使用缺少语义绑定: " + ctx->var()->ID()->getText())); + FormatError("irgen", "变量未绑定: " + ctx->Ident()->getText())); } + + // 查找存储槽位 + ir::Value* base = nullptr; + std::vector dims; + auto it = storage_map_.find(decl); - if (it == storage_map_.end()) { + if (it != storage_map_.end()) { + base = it->second; + auto dit = array_dims_.find(decl); + if (dit != array_dims_.end()) dims = dit->second; + } else { + auto git = global_storage_map_.find(decl); + if (git == global_storage_map_.end()) { + throw std::runtime_error( + FormatError("irgen", "变量无存储槽位: " + ctx->Ident()->getText())); + } + base = git->second; + auto gdit = global_array_dims_.find(decl); + if (gdit != global_array_dims_.end()) dims = gdit->second; + } + + // 无索引 → 返回基地址 + if (ctx->exp().empty()) return base; + + // 有索引 → 计算扁平化偏移 + auto indices = ctx->exp(); + + // 对于数组参数(第一维为-1),允许索引数等于维度数 + bool is_array_param = !dims.empty() && dims[0] == -1; + if (!is_array_param && indices.size() > dims.size()) { + throw std::runtime_error(FormatError("irgen", "数组索引维度过多")); + } + + ir::Value* offset = builder_.CreateConstInt(0); + + if (is_array_param) { + // 数组参数:dims[0]=-1, dims[1..n]是已知维度 + // 索引:indices[0]对应第一维,indices[1]对应第二维... + for (size_t i = 0; i < indices.size(); ++i) { + ir::Value* idx = EvalExpr(*indices[i]); + if (i == 0) { + // 第一维:stride = dims[1] * dims[2] * ... (如果有的话) + int stride = 1; + for (size_t j = 1; j < dims.size(); ++j) { + stride *= dims[j]; + } + if (stride > 1) { + ir::Value* scaled = builder_.CreateMul( + idx, builder_.CreateConstInt(stride), + module_.GetContext().NextTemp()); + offset = builder_.CreateAdd(offset, scaled, + module_.GetContext().NextTemp()); + } else { + offset = builder_.CreateAdd(offset, idx, + module_.GetContext().NextTemp()); + } + } else { + // 后续维度 + int stride = 1; + for (size_t j = i + 1; j < dims.size(); ++j) { + stride *= dims[j]; + } + ir::Value* scaled = builder_.CreateMul( + idx, builder_.CreateConstInt(stride), + module_.GetContext().NextTemp()); + offset = builder_.CreateAdd(offset, scaled, + module_.GetContext().NextTemp()); + } + } + } else { + // 普通数组:从最后一维开始计算 + int stride = 1; + for (int i = (int)dims.size() - 1; i >= 0; --i) { + stride = (i == (int)dims.size() - 1) ? 1 : stride * dims[i + 1]; + if (i < (int)indices.size()) { + ir::Value* idx = EvalExpr(*indices[i]); + ir::Value* scaled = builder_.CreateMul( + idx, builder_.CreateConstInt(stride), + module_.GetContext().NextTemp()); + offset = builder_.CreateAdd(offset, scaled, + module_.GetContext().NextTemp()); + } + } + } + + return builder_.CreateGep(base, offset, module_.GetContext().NextTemp()); +} + +// lVar : Ident (L_BRAKT exp R_BRAKT)* +// 在表达式语境下:load 变量值(返回 i32) +std::any IRGenImpl::visitLVar(SysYParser::LVarContext* ctx) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("irgen", "非法变量引用")); + } + auto* decl = sema_.ResolveVarUse(ctx->Ident()); + if (!decl) { throw std::runtime_error( - FormatError("irgen", - "变量声明缺少存储槽位: " + ctx->var()->ID()->getText())); + FormatError("irgen", "变量未绑定: " + ctx->Ident()->getText())); + } + + // 标量常量(ConstDefContext 且无索引) + if (auto* const_def = dynamic_cast(decl)) { + if (ctx->exp().empty()) { + // 先查局部 + auto it = storage_map_.find(const_def); + if (it != storage_map_.end()) { + if (auto* ci = dynamic_cast(it->second)) { + return static_cast(ci); + } + if (auto* cf = dynamic_cast(it->second)) { + return static_cast(cf); + } + } + // 再查全局 + auto git = global_storage_map_.find(const_def); + if (git != global_storage_map_.end()) { + if (auto* cf = dynamic_cast(git->second)) { + return static_cast(cf); + } + } + } } + + // 通用路径:计算地址并 load + ir::Value* addr = EvalLVarAddr(ctx); return static_cast( - builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); + builder_.CreateLoad(addr, module_.GetContext().NextTemp())); } +// number : IntConst | FloatConst +std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法数字")); + } + if (ctx->IntConst()) { + std::string text = ctx->IntConst()->getText(); + int val = 0; + try { + val = std::stoi(text, nullptr, 0); + } catch (...) { + throw std::runtime_error( + FormatError("irgen", "整数字面量解析失败: " + text)); + } + return static_cast(builder_.CreateConstInt(val)); + } + if (ctx->FloatConst()) { + std::string text = ctx->FloatConst()->getText(); + float val = 0.0f; + try { + val = std::stof(text); + } catch (...) { + throw std::runtime_error( + FormatError("irgen", "浮点字面量解析失败: " + text)); + } + return static_cast(builder_.CreateConstFloat(val)); + } + throw std::runtime_error(FormatError("irgen", "非法数字节点")); +} + +// ─── 条件表达式访问器(返回 ir::Value*,i1) ────────────────────────────────── -std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("irgen", "非法加法表达式")); +std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { + if (!ctx || !ctx->lOrExp()) { + throw std::runtime_error(FormatError("irgen", "非法条件")); } - ir::Value* lhs = EvalExpr(*ctx->exp(0)); - ir::Value* rhs = EvalExpr(*ctx->exp(1)); - return static_cast( - builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, - module_.GetContext().NextTemp())); + return ctx->lOrExp()->accept(this); +} + +// lOrExp : lAndExp ('||' lAndExp)* +std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法 lOrExp")); + auto ands = ctx->lAndExp(); + if (ands.empty()) throw std::runtime_error(FormatError("irgen", "lOrExp 空")); + + ir::Value* result = std::any_cast(ands[0]->accept(this)); + result = ToI1(result); + + for (size_t i = 1; i < ands.size(); ++i) { + // 检查当前块是否已终结 + if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) { + break; + } + + // 短路:result || rhs + auto* res_slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); + ir::Value* res_ext = ToI32(result); + builder_.CreateStore(res_ext, res_slot); + + ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.rhs"); + ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.end"); + builder_.CreateCondBr(result, end_bb, rhs_bb); + + builder_.SetInsertPoint(rhs_bb); + ir::Value* rhs = std::any_cast(ands[i]->accept(this)); + rhs = ToI32(ToI1(rhs)); + if (!builder_.GetInsertBlock()->HasTerminator()) { + builder_.CreateStore(rhs, res_slot); + builder_.CreateBr(end_bb); + } + + builder_.SetInsertPoint(end_bb); + result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp())); + } + return static_cast(result); +} + +// lAndExp : eqExp ('&&' eqExp)* +std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法 lAndExp")); + auto eqs = ctx->eqExp(); + if (eqs.empty()) throw std::runtime_error(FormatError("irgen", "lAndExp 空")); + + ir::Value* result = std::any_cast(eqs[0]->accept(this)); + result = ToI1(result); + + for (size_t i = 1; i < eqs.size(); ++i) { + // 检查当前块是否已终结 + if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) { + break; + } + + auto* res_slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); + ir::Value* res_ext = ToI32(result); + builder_.CreateStore(res_ext, res_slot); + + ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.rhs"); + ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.end"); + builder_.CreateCondBr(result, rhs_bb, end_bb); + + builder_.SetInsertPoint(rhs_bb); + ir::Value* rhs = std::any_cast(eqs[i]->accept(this)); + rhs = ToI32(ToI1(rhs)); + if (!builder_.GetInsertBlock()->HasTerminator()) { + builder_.CreateStore(rhs, res_slot); + builder_.CreateBr(end_bb); + } + + builder_.SetInsertPoint(end_bb); + result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp())); + } + return static_cast(result); +} + +// eqExp : relExp (EqOp relExp)* +std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法 eqExp")); + auto rels = ctx->relExp(); + if (rels.empty()) throw std::runtime_error(FormatError("irgen", "eqExp 空")); + + ir::Value* result = std::any_cast(rels[0]->accept(this)); + auto ops = ctx->EqOp(); + + for (size_t i = 0; i < ops.size(); ++i) { + ir::Value* rhs = std::any_cast(rels[i + 1]->accept(this)); + ir::Value* lhs = result; + ImplicitConvert(lhs, rhs); + std::string op = ops[i]->getText(); + if (lhs->IsFloat32()) { + ir::FCmpPredicate pred = (op == "==") ? ir::FCmpPredicate::OEQ : ir::FCmpPredicate::ONE; + result = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp()); + } else { + lhs = ToI32(lhs); + rhs = ToI32(rhs); + ir::ICmpPredicate pred = (op == "==") ? ir::ICmpPredicate::EQ : ir::ICmpPredicate::NE; + result = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp()); + } + } + return static_cast(result); +} + +// relExp : addExp (RelOp addExp)* +std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法 relExp")); + auto adds = ctx->addExp(); + if (adds.empty()) throw std::runtime_error(FormatError("irgen", "relExp 空")); + + ir::Value* result = std::any_cast(adds[0]->accept(this)); + auto ops = ctx->RelOp(); + + for (size_t i = 0; i < ops.size(); ++i) { + ir::Value* rhs = std::any_cast(adds[i + 1]->accept(this)); + ir::Value* lhs = result; + ImplicitConvert(lhs, rhs); + std::string op = ops[i]->getText(); + if (lhs->IsFloat32()) { + ir::FCmpPredicate pred; + if (op == "<") pred = ir::FCmpPredicate::OLT; + else if (op == ">") pred = ir::FCmpPredicate::OGT; + else if (op == "<=") pred = ir::FCmpPredicate::OLE; + else pred = ir::FCmpPredicate::OGE; + result = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp()); + } else { + lhs = ToI32(lhs); + rhs = ToI32(rhs); + ir::ICmpPredicate pred; + if (op == "<") pred = ir::ICmpPredicate::SLT; + else if (op == ">") pred = ir::ICmpPredicate::SGT; + else if (op == "<=") pred = ir::ICmpPredicate::SLE; + else pred = ir::ICmpPredicate::SGE; + result = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp()); + } + } + return static_cast(result); } diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 4912d03..ff5f831 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -1,87 +1,130 @@ #include "irgen/IRGen.h" #include +#include #include "SysYParser.h" #include "ir/IR.h" +#include "sem/func.h" #include "utils/Log.h" -namespace { - -void VerifyFunctionStructure(const ir::Function& func) { - // 当前 IRGen 仍是单入口、顺序生成;这里在生成结束后补一层块终结校验。 - for (const auto& bb : func.GetBlocks()) { - if (!bb || !bb->HasTerminator()) { - throw std::runtime_error( - FormatError("irgen", "基本块未正确终结: " + - (bb ? bb->GetName() : std::string("")))); - } +// 辅助:求值表达式为整数(用于数组维度) +static int EvalExprInt(SysYParser::ExpContext* ctx) { + if (!ctx) return 0; + try { + auto cv = sem::EvaluateExp(*ctx->addExp()); + return static_cast(cv.int_val); + } catch (...) { + return 0; } } -} // namespace - +// ─── 构造函数 ───────────────────────────────────────────────────────────────── IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) : module_(module), sema_(sema), func_(nullptr), builder_(module.GetContext(), nullptr) {} -// 编译单元的 IR 生成当前只实现了最小功能: -// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容; -// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR; -// -// 当前还没有实现: -// - 多个函数定义的遍历与生成; -// - 全局变量、全局常量的 IR 生成。 +// ─── visitCompUnit ──────────────────────────────────────────────────────────── std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少编译单元")); } - auto* func = ctx->funcDef(); - if (!func) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); + // 先处理全局声明 + in_global_scope_ = true; + for (auto* decl : ctx->decl()) { + if (decl) decl->accept(this); + } + in_global_scope_ = false; + // 再生成函数 + for (auto* funcDef : ctx->funcDef()) { + if (funcDef) funcDef->accept(this); } - func->accept(this); return {}; } -// 函数 IR 生成当前实现了: -// 1. 获取函数名; -// 2. 检查函数返回类型; -// 3. 在 Module 中创建 Function; -// 4. 将 builder 插入点设置到入口基本块; -// 5. 继续生成函数体。 -// -// 当前还没有实现: -// - 通用函数返回类型处理; -// - 形参列表遍历与参数类型收集; -// - FunctionType 这样的函数类型对象; -// - Argument/形式参数 IR 对象; -// - 入口块中的参数初始化逻辑。 -// ... - -// 因此这里目前只支持最小的“无参 int 函数”生成。 +// ─── visitFuncDef ───────────────────────────────────────────────────────────── std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("irgen", "缺少函数定义或函数名")); } - if (!ctx->blockStmt()) { + if (!ctx->block()) { throw std::runtime_error(FormatError("irgen", "函数体为空")); } - if (!ctx->ID()) { - throw std::runtime_error(FormatError("irgen", "缺少函数名")); + + std::string func_name = ctx->Ident()->getText(); + + // 确定返回类型 + std::shared_ptr ret_type; + if (!ctx->funcType()) { + throw std::runtime_error(FormatError("irgen", "缺少函数返回类型")); } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数")); + if (ctx->funcType()->Void()) { + ret_type = ir::Type::GetVoidType(); + } else if (ctx->funcType()->Int()) { + ret_type = ir::Type::GetInt32Type(); + } else if (ctx->funcType()->Float()) { + ret_type = ir::Type::GetFloat32Type(); + } else { + throw std::runtime_error( + FormatError("irgen", "函数 " + func_name + " 返回类型不支持")); } - func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); - builder_.SetInsertPoint(func_->GetEntry()); + func_ = module_.CreateFunction(func_name, ret_type); storage_map_.clear(); - ctx->blockStmt()->accept(this); - // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 - VerifyFunctionStructure(*func_); + // 设置插入点到入口块 + builder_.SetInsertPoint(func_->GetEntry()); + + // 处理参数 + if (ctx->funcFParams()) { + for (auto* param : ctx->funcFParams()->funcFParam()) { + if (!param || !param->Ident()) continue; + std::string pname = param->Ident()->getText(); + + bool is_float = param->bType() && param->bType()->Float(); + bool is_array_param = !param->L_BRAKT().empty(); + + if (is_array_param) { + // 数组参数:传递为指针 + auto ptr_type = is_float ? ir::Type::GetPtrFloat32Type() : ir::Type::GetPtrInt32Type(); + ir::Argument* arg = func_->AddArgument(ptr_type, pname); + storage_map_[param] = arg; + + // 记录维度信息(第一维未知,后续维度从 exp 中获取) + std::vector dims; + dims.push_back(-1); // 第一维未知 + for (auto* exp_ctx : param->exp()) { + dims.push_back(EvalExprInt(exp_ctx)); + } + array_dims_[param] = dims; // 始终记录,包括一维数组参数 + } else { + // 标量参数 + auto arg_type = is_float ? ir::Type::GetFloat32Type() : ir::Type::GetInt32Type(); + ir::Argument* arg = func_->AddArgument(arg_type, pname); + auto* slot = is_float ? builder_.CreateAllocaF32(pname + ".addr") + : builder_.CreateAllocaI32(pname + ".addr"); + storage_map_[param] = slot; + builder_.CreateStore(arg, slot); + } + } + } + + // 生成函数体 + ctx->block()->accept(this); + + // 若最后一个基本块没有 terminator,自动补 ret + auto* last_bb = builder_.GetInsertBlock(); + if (last_bb && !last_bb->HasTerminator()) { + if (ret_type->IsVoid()) { + builder_.CreateRetVoid(); + } else if (ret_type->IsFloat32()) { + builder_.CreateRet(builder_.CreateConstFloat(0.0f)); + } else { + builder_.CreateRet(builder_.CreateConstInt(0)); + } + } + return {}; } diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 751550c..75eb78a 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -1,39 +1,188 @@ #include "irgen/IRGen.h" #include +#include #include "SysYParser.h" #include "ir/IR.h" #include "utils/Log.h" -// 语句生成当前只实现了最小子集。 -// 目前支持: -// - return ; -// -// 还未支持: -// - 赋值语句 -// - if / while 等控制流 -// - 空语句、块语句嵌套分发之外的更多语句形态 - +// ─── visitStmt ──────────────────────────────────────────────────────────────── std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少语句")); + return VisitStmt(*ctx); +} + +IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) { + auto* ctx = &s; + + // 若当前块已经终结,跳过死代码 + { + auto* cur = builder_.GetInsertBlock(); + if (cur && cur->HasTerminator()) return BlockFlow::Terminated; } - if (ctx->returnStmt()) { - return ctx->returnStmt()->accept(this); + + // ── return ────────────────────────────────────────────────────────────────── + if (ctx->Return()) { + if (ctx->exp()) { + ir::Value* v = EvalExpr(*ctx->exp()); + builder_.CreateRet(v); + } else { + builder_.CreateRetVoid(); + } + return BlockFlow::Terminated; + } + + // ── break ──────────────────────────────────────────────────────────────────── + if (ctx->Break()) { + if (loop_stack_.empty()) { + throw std::runtime_error(FormatError("irgen", "break 在循环外")); + } + builder_.CreateBr(loop_stack_.back().after_bb); + return BlockFlow::Terminated; } - throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); -} + // ── continue ───────────────────────────────────────────────────────────────── + if (ctx->Continue()) { + if (loop_stack_.empty()) { + throw std::runtime_error(FormatError("irgen", "continue 在循环外")); + } + builder_.CreateBr(loop_stack_.back().cond_bb); + return BlockFlow::Terminated; + } -std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); + // ── 赋值 lVar = exp ────────────────────────────────────────────────────────── + if (ctx->lVar() && ctx->Assign()) { + ir::Value* rhs = EvalExpr(*ctx->exp()); + ir::Value* addr = EvalLVarAddr(ctx->lVar()); + builder_.CreateStore(rhs, addr); + return BlockFlow::Continue; } - if (!ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); + + // ── 纯表达式语句 (exp)? ; ──────────────────────────────────────────────────── + if (ctx->Semi() && !ctx->Return() && !ctx->If() && !ctx->While() && + !ctx->Break() && !ctx->Continue() && !ctx->Assign()) { + if (ctx->exp()) { + EvalExpr(*ctx->exp()); + } + return BlockFlow::Continue; } - ir::Value* v = EvalExpr(*ctx->exp()); - builder_.CreateRet(v); - return BlockFlow::Terminated; + + // ── if ─────────────────────────────────────────────────────────────────────── + if (ctx->If()) { + if (!ctx->cond()) { + throw std::runtime_error(FormatError("irgen", "if 缺少条件")); + } + + auto stmts = ctx->stmt(); + ir::BasicBlock* then_bb = func_->CreateBlock( + module_.GetContext().NextTemp() + ".if.then"); + ir::BasicBlock* else_bb = nullptr; + ir::BasicBlock* merge_bb = func_->CreateBlock( + module_.GetContext().NextTemp() + ".if.end"); + + // 求值条件(可能创建短路求值块) + ir::Value* cond_val = EvalCond(*ctx->cond()); + + // 检查当前块是否已终结(短路求值可能导致) + if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) { + // 条件求值已经终结了当前块,无法继续 + // 这种情况下,我们需要在merge_bb继续 + builder_.SetInsertPoint(merge_bb); + return BlockFlow::Continue; + } + + if (stmts.size() >= 2) { + // if-else + else_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".if.else"); + builder_.CreateCondBr(cond_val, then_bb, else_bb); + } else { + builder_.CreateCondBr(cond_val, then_bb, merge_bb); + } + + // then 分支 + builder_.SetInsertPoint(then_bb); + auto then_flow = VisitStmt(*stmts[0]); + if (then_flow != BlockFlow::Terminated) { + builder_.CreateBr(merge_bb); + } + + // else 分支 + if (else_bb) { + builder_.SetInsertPoint(else_bb); + auto else_flow = VisitStmt(*stmts[1]); + if (else_flow != BlockFlow::Terminated) { + builder_.CreateBr(merge_bb); + } + } + + builder_.SetInsertPoint(merge_bb); + return BlockFlow::Continue; + } + + // ── while ──────────────────────────────────────────────────────────────────── + if (ctx->While()) { + if (!ctx->cond()) { + throw std::runtime_error(FormatError("irgen", "while 缺少条件")); + } + ir::BasicBlock* cond_bb = func_->CreateBlock( + module_.GetContext().NextTemp() + ".while.cond"); + ir::BasicBlock* body_bb = func_->CreateBlock( + module_.GetContext().NextTemp() + ".while.body"); + ir::BasicBlock* after_bb = func_->CreateBlock( + module_.GetContext().NextTemp() + ".while.end"); + + // 跳转到条件块 + if (!builder_.GetInsertBlock()->HasTerminator()) { + builder_.CreateBr(cond_bb); + } + + // 条件块 + builder_.SetInsertPoint(cond_bb); + ir::Value* cond_val = EvalCond(*ctx->cond()); + + // 检查条件求值后是否已终结 + if (!builder_.GetInsertBlock()->HasTerminator()) { + builder_.CreateCondBr(cond_val, body_bb, after_bb); + } + + // 循环体(压入循环栈) + loop_stack_.push_back({cond_bb, after_bb}); + builder_.SetInsertPoint(body_bb); + auto stmts = ctx->stmt(); + if (!stmts.empty()) { + auto body_flow = VisitStmt(*stmts[0]); + if (body_flow != BlockFlow::Terminated) { + builder_.CreateBr(cond_bb); // 循环回跳 + } + } else { + builder_.CreateBr(cond_bb); + } + loop_stack_.pop_back(); + + builder_.SetInsertPoint(after_bb); + return BlockFlow::Continue; + } + + // ── 块语句 ─────────────────────────────────────────────────────────────────── + if (ctx->block()) { + auto result = ctx->block()->accept(this); + if (result.has_value()) { + try { + auto flow = std::any_cast(result); + if (flow == BlockFlow::Terminated) return BlockFlow::Terminated; + } catch (...) {} + } + // 检查 builder 的实际状态 + if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) { + return BlockFlow::Terminated; + } + return BlockFlow::Continue; + } + + // ── 空语句 ; ───────────────────────────────────────────────────────────────── + if (ctx->Semi()) { + return BlockFlow::Continue; + } + + throw std::runtime_error(FormatError("irgen", "不支持的语句类型")); } diff --git a/src/sem/CMakeLists.txt b/src/sem/CMakeLists.txt index b3bc011..2787608 100644 --- a/src/sem/CMakeLists.txt +++ b/src/sem/CMakeLists.txt @@ -2,9 +2,10 @@ add_library(sem STATIC Sema.cpp SymbolTable.cpp ConstEval.cpp + func.cpp ) target_link_libraries(sem PUBLIC build_options ${ANTLR4_RUNTIME_TARGET} -) +) \ No newline at end of file diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 3a97be6..058825c 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -1,14 +1,12 @@ #include "sem/Sema.h" #include -#include -#include #include #include +#include #include #include "SysYBaseVisitor.h" -#include "sem/SymbolTable.h" #include "sem/func.h" #include "utils/Log.h" @@ -16,200 +14,267 @@ using namespace sem; namespace { -// 编译时求值常量表达式 +// ─── 作用域栈 ───────────────────────────────────────────────────────────────── +class ScopeStack { + public: + void PushScope() { scopes_.emplace_back(); } + void PopScope() { + if (!scopes_.empty()) scopes_.pop_back(); + } + // 在当前作用域定义符号 + void Define(const std::string& name, antlr4::ParserRuleContext* decl) { + if (scopes_.empty()) return; + scopes_.back()[name] = decl; + } + // 检查当前作用域是否已定义(用于重复定义检查) + bool ContainsInCurrent(const std::string& name) const { + return !scopes_.empty() && scopes_.back().count(name) > 0; + } + // 向上查找 + antlr4::ParserRuleContext* Lookup(const std::string& name) const { + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + auto found = it->find(name); + if (found != it->end()) return found->second; + } + return nullptr; + } + + private: + std::vector> + scopes_; +}; +// ─── 语义分析访问者 ──────────────────────────────────────────────────────────── class SemaVisitor final : public SysYBaseVisitor { public: + // 顶层:global scope → decl (global var/const) → funcDef std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { if (!ctx) { throw std::runtime_error(FormatError("sema", "缺少编译单元")); } - - // 先处理声明(包括常量声明) - auto decls = ctx->decl(); - for (auto* decl : decls) { - decl->accept(this); + scope_.PushScope(); // global scope + // 先处理全局声明 + for (auto* decl : ctx->decl()) { + if (decl) decl->accept(this); } - // 再处理函数定义 - auto funcs = ctx->funcDef(); - int count = 0; - int len = funcs.size(); - if(len == 0){ - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - for(int i = 0; i < len; ++i){ - auto* func = ctx->funcDef(i); - func->accept(this); - if(!func->Ident() || func->Ident()->getText() == "main"){ - count ++; - if(count > 1){ - throw std::runtime_error(FormatError("sema", "有多个 main 函数定义")); - } - if(func->funcFParams()){ - throw std::runtime_error(FormatError("sema", "main 函数不该有参数")); - } - if(!func->funcType() || !func->funcType()->Int()){ - throw std::runtime_error(FormatError("sema", "main 函数的返回值必须是 Int")); - } + bool has_main = false; + for (auto* func : ctx->funcDef()) { + if (!func) continue; + if (func->Ident() && func->Ident()->getText() == "main") { + has_main = true; } + func->accept(this); + } + if (!has_main && ctx->funcDef().empty()) { + throw std::runtime_error(FormatError("sema", "缺少函数定义")); } + scope_.PopScope(); return {}; } - std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) { - if (!ctx || !ctx->bType()) { - throw std::runtime_error(FormatError("sema", "非法常量声明")); - } - - // 获取类型信息 - bool is_int = ctx->bType()->Int() != nullptr; - bool is_float = ctx->bType()->Float() != nullptr; - - // 处理所有常量定义 - auto const_defs = ctx->constDef(); - for (auto* const_def : const_defs) { - // 检查标识符 - if (!const_def->Ident()) { - throw std::runtime_error(FormatError("sema", "常量声明缺少标识符")); - } - - std::string name = const_def->Ident()->getText(); - - // 检查是否重复定义 - if (table_.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义常量: " + name)); - } - - // 处理数组维度 - auto const_exps = const_def->constExp(); - std::vector dimensions; - for (auto* const_exp : const_exps) { - ConstValue value = EvaluateConstExp(*const_exp); - if (!value.is_int) { - throw std::runtime_error(FormatError("sema", "数组维度必须是整数")); + // 函数定义:参数入作用域 → 函数体 + std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { + if (!ctx) return {}; + scope_.PushScope(); + if (auto* params = ctx->funcFParams()) { + for (auto* param : params->funcFParam()) { + if (!param || !param->Ident()) continue; + std::string name = param->Ident()->getText(); + if (scope_.ContainsInCurrent(name)) { + throw std::runtime_error( + FormatError("sema", "参数重复定义: " + name)); } - if (value.int_val < 0) { - throw std::runtime_error(FormatError("sema", "数组维度必须是非负整数")); - } - dimensions.push_back(static_cast(value.int_val)); - } - - // 处理初始化器 - auto* init_val = const_def->constInitVal(); - if (init_val) { - // 检查标量常量的初始化器 - if (dimensions.empty()) { - // 标量常量,ConstInitVal必须是单个初始数值,不能是花括号列表 - if (init_val->L_BRACE()) { - throw std::runtime_error(FormatError("sema", "单个常量只能赋单个值,不能使用花括号列表")); - } - if (!init_val->constExp()) { - throw std::runtime_error(FormatError("sema", "单个常量缺少初始值")); - } - } - // 数组常量 - // 计算数组总元素个数 - size_t total_elements = 1; - for (auto dim : dimensions) { - total_elements *= dim; - } - - // 检查初始化器,传递总元素个数进行检查 - CheckConstInitVal(*init_val, dimensions, is_int, total_elements); + scope_.Define(name, param); } - - // 添加到符号表 - table_.AddConst(name, const_def); } - + if (ctx->block()) ctx->block()->accept(this); + scope_.PopScope(); + return {}; + } + + // 块:新作用域 + std::any visitBlock(SysYParser::BlockContext* ctx) override { + if (!ctx) return {}; + scope_.PushScope(); + for (auto* item : ctx->blockItem()) { + if (item) item->accept(this); + } + scope_.PopScope(); return {}; } + // 块内项:分发 + std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { + if (!ctx) return {}; + if (ctx->decl()) ctx->decl()->accept(this); + if (ctx->stmt()) ctx->stmt()->accept(this); + return {}; + } + // 声明:分发到 varDecl / constDecl + std::any visitDecl(SysYParser::DeclContext* ctx) override { + if (!ctx) return {}; + if (ctx->varDecl()) ctx->varDecl()->accept(this); + if (ctx->constDecl()) ctx->constDecl()->accept(this); + return {}; + } - std::any visitVarDecl(SysYParser::VarDeclContext* ctx) { - if (!ctx || !ctx->bType()) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - - // 获取类型信息 - bool is_int = ctx->bType()->Int() != nullptr; - bool is_float = ctx->bType()->Float() != nullptr; - - // 处理所有变量定义 - auto var_defs = ctx->varDef(); - for (auto* var_def : var_defs) { - // 检查标识符 - if (!var_def->Ident()) { - throw std::runtime_error(FormatError("sema", "变量声明缺少标识符")); + // 变量声明 + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override { + if (!ctx) return {}; + for (auto* varDef : ctx->varDef()) { + if (!varDef || !varDef->Ident()) continue; + std::string name = varDef->Ident()->getText(); + if (scope_.ContainsInCurrent(name)) { + throw std::runtime_error( + FormatError("sema", "变量重复定义: " + name)); } - - std::string name = var_def->Ident()->getText(); - - // 检查是否重复定义 - if (table_.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + scope_.Define(name, varDef); + // 访问初始化表达式中的变量引用 + if (varDef->initVal()) varDef->initVal()->accept(this); + } + return {}; + } + + // 常量声明 + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override { + if (!ctx) return {}; + for (auto* constDef : ctx->constDef()) { + if (!constDef || !constDef->Ident()) continue; + std::string name = constDef->Ident()->getText(); + if (scope_.ContainsInCurrent(name)) { + throw std::runtime_error( + FormatError("sema", "常量重复定义: " + name)); } - - // 处理数组维度 - auto const_exps = var_def->constExp(); - std::vector dimensions; - for (auto* const_exp : const_exps) { - ConstValue value = EvaluateConstExp(*const_exp); - if (!value.is_int) { - throw std::runtime_error(FormatError("sema", "数组维度必须是整数")); - } - if (value.int_val < 0) { - throw std::runtime_error(FormatError("sema", "数组维度必须是非负整数")); - } - dimensions.push_back(static_cast(value.int_val)); + scope_.Define(name, constDef); + // 初始化表达式中的引用 + if (constDef->constInitVal()) constDef->constInitVal()->accept(this); + } + return {}; + } + + // 语句:分发各种形式 + std::any visitStmt(SysYParser::StmtContext* ctx) override { + if (!ctx) return {}; + // 赋值语句:lVar = exp; + if (ctx->lVar() && ctx->Assign()) { + ctx->lVar()->accept(this); + if (ctx->exp()) ctx->exp()->accept(this); + return {}; + } + // 表达式语句:(exp)?; + if (ctx->exp() && !ctx->Return() && !ctx->If() && !ctx->While()) { + ctx->exp()->accept(this); + return {}; + } + // 块语句 + if (ctx->block()) { + ctx->block()->accept(this); + return {}; + } + // if 语句 + if (ctx->If()) { + if (ctx->cond()) ctx->cond()->accept(this); + auto stmts = ctx->stmt(); + for (auto* s : stmts) { + if (s) s->accept(this); } - - // 处理初始化器 - auto* init_val = var_def->initVal(); - if (init_val) { - // 检查标量变量的初始化器 - if (dimensions.empty()) { - // 标量变量,InitVal必须是单个表达式,不能是花括号列表 - if (init_val->L_BRACE()) { - throw std::runtime_error(FormatError("sema", "单个变量只能赋单个值,不能使用花括号列表")); - } - if (!init_val->exp()) { - throw std::runtime_error(FormatError("sema", "单个变量缺少初始值")); - } - } - // 数组变量 - // 计算数组总元素个数 - size_t total_elements = 1; - for (auto dim : dimensions) { - total_elements *= dim; - } - - // 检查初始化器,传递总元素个数进行检查 - CheckInitVal(*init_val, dimensions, is_int, total_elements); + return {}; + } + // while 语句 + if (ctx->While()) { + if (ctx->cond()) ctx->cond()->accept(this); + auto stmts = ctx->stmt(); + for (auto* s : stmts) { + if (s) s->accept(this); } - - // 添加到符号表 - table_.AddVar(name, var_def); + return {}; + } + // return 语句 + if (ctx->Return()) { + if (ctx->exp()) ctx->exp()->accept(this); + return {}; } - + // break / continue:无变量引用 return {}; } - std::any visitVarDef(SysYParser::VarDefContext* ctx) { - // 此方法由visitVarDecl调用,不需要单独处理 + // lVar:绑定变量使用 + std::any visitLVar(SysYParser::LVarContext* ctx) override { + if (!ctx || !ctx->Ident()) return {}; + std::string name = ctx->Ident()->getText(); + auto* decl = scope_.Lookup(name); + if (!decl) { + // 可能是外部函数或未声明变量,不强制报错(IRGen 会处理外部调用) + // 但对变量使用报错 + throw std::runtime_error( + FormatError("sema", "未声明变量: " + name)); + } + sema_.BindVarUse(ctx->Ident(), decl); + // 下标表达式也需要访问 + for (auto* e : ctx->exp()) { + if (e) e->accept(this); + } return {}; } + // 表达式:通过 visitChildren 自动递归 + std::any visitExp(SysYParser::ExpContext* ctx) override { + return visitChildren(ctx); + } + std::any visitAddExp(SysYParser::AddExpContext* ctx) override { + return visitChildren(ctx); + } + std::any visitMulExp(SysYParser::MulExpContext* ctx) override { + return visitChildren(ctx); + } + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override { + // 函数调用:Ident L_PAREN ... + if (ctx->Ident() && ctx->L_PAREN()) { + if (ctx->funcRParams()) ctx->funcRParams()->accept(this); + return {}; + } + return visitChildren(ctx); + } + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override { + return visitChildren(ctx); + } + std::any visitCond(SysYParser::CondContext* ctx) override { + return visitChildren(ctx); + } + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override { + return visitChildren(ctx); + } + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override { + return visitChildren(ctx); + } + std::any visitRelExp(SysYParser::RelExpContext* ctx) override { + return visitChildren(ctx); + } + std::any visitEqExp(SysYParser::EqExpContext* ctx) override { + return visitChildren(ctx); + } + std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override { + return visitChildren(ctx); + } + std::any visitInitVal(SysYParser::InitValContext* ctx) override { + return visitChildren(ctx); + } + std::any visitConstInitVal(SysYParser::ConstInitValContext* ctx) override { + // 常量初始化器中通常没有变量引用(只有字面量), + // 但如果有 constExp 引用其他常量则需要访问 + return visitChildren(ctx); + } + std::any visitConstExp(SysYParser::ConstExpContext* ctx) override { + return visitChildren(ctx); + } + SemanticContext TakeSemanticContext() { return std::move(sema_); } private: - SymbolTable table_; + ScopeStack scope_; SemanticContext sema_; - bool seen_return_ = false; - size_t current_item_index_ = 0; - size_t total_items_ = 0; }; } // namespace @@ -218,4 +283,4 @@ SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { SemaVisitor visitor; comp_unit.accept(&visitor); return visitor.TakeSemanticContext(); -} \ No newline at end of file +} diff --git a/test_const_float.sy b/test_const_float.sy new file mode 100644 index 0000000..19c20c4 --- /dev/null +++ b/test_const_float.sy @@ -0,0 +1,9 @@ +const float PI = 3.14; +const float E = 2.718; + +int main() { + float r = 5.0; + float area = PI * r * r; + float sum = PI + E; + return 0; +} diff --git a/test_float.sy b/test_float.sy new file mode 100644 index 0000000..1e76fba --- /dev/null +++ b/test_float.sy @@ -0,0 +1,8 @@ +int main() { + float a = 3.14; + float b = 2.0; + float c = a + b; + int d = 10; + float e = d + a; // 隐式转换 int -> float + return 0; +} diff --git a/test_float_full.sy b/test_float_full.sy new file mode 100644 index 0000000..ccb4e9a --- /dev/null +++ b/test_float_full.sy @@ -0,0 +1,30 @@ +float test_float(float x, int y) { + float z = x * 2.5; + float w = y + z; // 隐式转换 int -> float + return w / 3.0; +} + +int main() { + float a = 1.5; + float b = 2.5; + + // 浮点运算 + float c = a + b; + float d = a - b; + float e = a * b; + float f = a / b; + + // 浮点比较 + int g = a < b; + int h = a == b; + + // 一元运算 + float i = -a; + int j = !a; + + // 混合运算(隐式转换) + int k = 10; + float m = k + a; + + return 0; +} diff --git a/test_float_simple.sy b/test_float_simple.sy new file mode 100644 index 0000000..1eff332 --- /dev/null +++ b/test_float_simple.sy @@ -0,0 +1,18 @@ +float test_float(float x, int y) { + float z = x * 2.5; + float w = y + z; + return w / 3.0; +} + +int main() { + float a = 1.5; + float b = 2.5; + float c = a + b; + float d = a - b; + float e = a * b; + float f = a / b; + float g = -a; + int k = 10; + float m = k + a; + return 0; +}