diff --git a/include/ir/IR.h b/include/ir/IR.h index b961192..3cc880e 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -1,37 +1,15 @@ -// 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。 -// -// 当前已经实现: -// 1. 基础类型系统:void / i32 / i32* -// 2. Value 体系:Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction -// 3. 最小指令集:Add / Alloca / Load / Store / Ret -// 4. BasicBlock / Function / Module 三层组织结构 -// 5. IRBuilder:便捷创建常量和最小指令 -// 6. def-use 关系的轻量实现: -// - Instruction 保存 operand 列表 -// - Value 保存 uses -// - 支持 ReplaceAllUsesWith 的简化实现 -// -// 当前尚未实现或只做了最小占位: -// 1. 完整类型系统:数组、函数类型、label 类型等 -// 2. 更完整的指令系统:br / condbr / call / phi / gep 等 -// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构) -// 4. 更完整的 IR verifier 和优化基础设施 -// -// 当前需要特别说明的两个简化点: -// 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位, -// 后续如果补 label type,可以再改成更合理的块标签类型。 -// 2. ConstantValue 体系目前只实现了 ConstantInt,后续可以继续补 ConstantFloat、 -// ConstantArray等更完整的常量种类。 -// -// 建议的扩展顺序: -// 1. 先补更多指令和类型 -// 2. 再补控制流相关 IR -// 3. 
最后再考虑把 Value/User/Use 进一步抽象成更完整的框架 +// 扩展后的 IR 库: +// - 完整基础类型:void/i1/i32/float/ptr/array/function/label +// - 指令:算术、比较、分支、调用、phi、gep、类型转换等 +// - 常量:int/float/array +// - 基本块/函数/模块/IRBuilder 的完整接口 #pragma once +#include #include #include +#include #include #include #include @@ -45,10 +23,14 @@ class Value; class User; class ConstantValue; class ConstantInt; +class ConstantFloat; +class ConstantArray; class GlobalValue; +class GlobalVariable; class Instruction; class BasicBlock; class Function; +class Argument; // Use 表示一个 Value 的一次使用记录。 // 当前实现设计: @@ -83,31 +65,65 @@ class Context { ~Context(); // 去重创建 i32 常量。 ConstantInt* GetConstInt(int v); + ConstantInt* GetConstBool(bool v); + ConstantFloat* GetConstFloat(float v); + ConstantArray* CreateConstArray(std::shared_ptr array_ty, + std::vector elements); std::string NextTemp(); private: std::unordered_map> const_ints_; + std::unordered_map> const_bools_; + std::unordered_map> const_floats_; + std::vector> const_arrays_; int temp_index_ = -1; }; class Type { public: - enum class Kind { Void, Int32, PtrInt32 }; + enum class Kind { Void, Int1, Int32, Float, Pointer, Array, Function, Label }; explicit Type(Kind k); + Type(Kind k, std::shared_ptr elem, size_t count); + Type(Kind k, std::shared_ptr ret, std::vector> params, + bool is_vararg); // 使用静态共享对象获取类型。 // 同一类型可直接比较返回值是否相等,例如: // Type::GetInt32Type() == Type::GetInt32Type() static const std::shared_ptr& GetVoidType(); + static const std::shared_ptr& GetInt1Type(); static const std::shared_ptr& GetInt32Type(); - static const std::shared_ptr& GetPtrInt32Type(); + static const std::shared_ptr& GetFloatType(); + static const std::shared_ptr& GetLabelType(); + static std::shared_ptr GetPointerType(std::shared_ptr elem); + static std::shared_ptr GetArrayType(std::shared_ptr elem, + size_t count); + static std::shared_ptr GetFunctionType( + std::shared_ptr ret, std::vector> params, + bool is_vararg = false); Kind GetKind() const; bool IsVoid() const; + bool IsInt1() const; 
bool IsInt32() const; - bool IsPtrInt32() const; + bool IsFloat() const; + bool IsPointer() const; + bool IsArray() const; + bool IsFunction() const; + bool IsLabel() const; + const std::shared_ptr& GetElementType() const; + size_t GetArraySize() const; + const std::shared_ptr& GetReturnType() const; + const std::vector>& GetParamTypes() const; + bool IsVarArg() const; + bool Equals(const Type& other) const; private: Kind kind_; + std::shared_ptr elem_type_; + size_t array_size_ = 0; + std::shared_ptr ret_type_; + std::vector> param_types_; + bool is_vararg_ = false; }; class Value { @@ -118,7 +134,12 @@ class Value { const std::string& GetName() const; void SetName(std::string n); bool IsVoid() const; + bool IsInt1() const; bool IsInt32() const; + bool IsFloat() const; + bool IsPointer() const; + bool IsArray() const; + bool IsFunctionType() const; bool IsPtrInt32() const; bool IsConstant() const; bool IsInstruction() const; @@ -151,8 +172,53 @@ class ConstantInt : public ConstantValue { int value_{}; }; +class ConstantFloat : public ConstantValue { + public: + ConstantFloat(std::shared_ptr ty, float v); + float GetValue() const { return value_; } + + private: + float value_{}; +}; + +class ConstantArray : public ConstantValue { + public: + ConstantArray(std::shared_ptr ty, std::vector elements); + const std::vector& GetElements() const { return elements_; } + + private: + std::vector elements_; +}; + // 后续还需要扩展更多指令类型。 -enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; +enum class Opcode { + Add, + Sub, + Mul, + SDiv, + SRem, + FAdd, + FSub, + FMul, + FDiv, + Alloca, + Load, + Store, + Ret, + Br, + CondBr, + ICmp, + FCmp, + Call, + Phi, + Gep, + SIToFP, + FPToSI, + ZExt +}; + +enum class ICmpPredicate { Eq, Ne, Slt, Sle, Sgt, Sge }; +enum class FCmpPredicate { Oeq, One, Olt, Ole, Ogt, Oge }; // User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 // 当前实现中只有 Instruction 继承自 User。 @@ -178,6 +244,20 @@ class GlobalValue : public User { GlobalValue(std::shared_ptr ty, 
std::string name); }; +class GlobalVariable : public GlobalValue { + public: + GlobalVariable(std::shared_ptr value_ty, std::string name, + ConstantValue* init, bool is_const); + const std::shared_ptr& GetValueType() const; + ConstantValue* GetInitializer() const; + bool IsConst() const; + + private: + std::shared_ptr value_type_; + ConstantValue* initializer_ = nullptr; + bool is_const_ = false; +}; + class Instruction : public User { public: Instruction(Opcode op, std::shared_ptr ty, std::string name = ""); @@ -196,18 +276,67 @@ class BinaryInst : public Instruction { BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name); Value* GetLhs() const; - Value* GetRhs() const; + Value* GetRhs() const; +}; + +class ICmpInst : public Instruction { + public: + ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name); + ICmpPredicate GetPredicate() const { return pred_; } + Value* GetLhs() const; + Value* GetRhs() const; + + private: + ICmpPredicate pred_; +}; + +class FCmpInst : public Instruction { + public: + FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name); + FCmpPredicate GetPredicate() const { return pred_; } + Value* GetLhs() const; + Value* GetRhs() const; + + private: + FCmpPredicate pred_; +}; + +class CastInst : public Instruction { + public: + CastInst(Opcode op, std::shared_ptr dst_ty, Value* src, + std::string name); + Value* GetValue() const; +}; + +class BranchInst : public Instruction { + public: + explicit BranchInst(BasicBlock* dest); + BasicBlock* GetDest() const; +}; + +class CondBrInst : public Instruction { + public: + CondBrInst(Value* cond, BasicBlock* true_dest, BasicBlock* false_dest); + Value* GetCond() const; + BasicBlock* GetTrueDest() const; + BasicBlock* GetFalseDest() const; }; class ReturnInst : public Instruction { public: + explicit ReturnInst(std::shared_ptr void_ty); ReturnInst(std::shared_ptr void_ty, Value* val); + bool HasReturnValue() const; Value* GetValue() const; }; 
class AllocaInst : public Instruction { public: - AllocaInst(std::shared_ptr ptr_ty, std::string name); + AllocaInst(std::shared_ptr allocated_ty, std::string name); + const std::shared_ptr& GetAllocatedType() const; + + private: + std::shared_ptr allocated_type_; }; class LoadInst : public Instruction { @@ -223,8 +352,41 @@ class StoreInst : public Instruction { Value* GetPtr() const; }; -// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。 -// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。 +class CallInst : public Instruction { + public: + CallInst(std::shared_ptr ret_ty, Value* callee, + std::vector args, std::string name); + Value* GetCallee() const; + const std::vector& GetArgs() const { return args_; } + + private: + std::vector args_; +}; + +class PhiInst : public Instruction { + public: + PhiInst(std::shared_ptr ty, std::string name); + void AddIncoming(Value* value, BasicBlock* block); + const std::vector& GetIncomingValues() const; + const std::vector& GetIncomingBlocks() const; + + private: + std::vector incoming_values_; + std::vector incoming_blocks_; +}; + +class GepInst : public Instruction { + public: + GepInst(std::shared_ptr result_ptr_ty, Value* base_ptr, + std::vector indices, std::string name); + Value* GetBasePtr() const; + const std::vector& GetIndices() const { return indices_; } + + private: + std::vector indices_; +}; + +// BasicBlock 已纳入 Value 体系,使用 label type。 class BasicBlock : public Value { public: explicit BasicBlock(std::string name); @@ -234,6 +396,8 @@ class BasicBlock : public Value { const std::vector>& GetInstructions() const; const std::vector& GetPredecessors() const; const std::vector& GetSuccessors() const; + void AddPredecessor(BasicBlock* pred); + void AddSuccessor(BasicBlock* succ); template T* Append(Args&&... 
args) { if (HasTerminator()) { @@ -244,6 +408,7 @@ class BasicBlock : public Value { auto* ptr = inst.get(); ptr->SetParent(this); instructions_.push_back(std::move(inst)); + LinkSuccessorsIfNeeded(ptr); return ptr; } @@ -252,6 +417,7 @@ class BasicBlock : public Value { std::vector> instructions_; std::vector predecessors_; std::vector successors_; + void LinkSuccessorsIfNeeded(Instruction* inst); }; // Function 当前也采用了最小实现。 @@ -262,16 +428,34 @@ class BasicBlock : public Value { // 形参和调用,通常需要引入专门的函数类型表示。 class Function : public Value { public: - // 当前构造函数接收的也是返回类型,而不是完整函数类型。 - Function(std::string name, std::shared_ptr ret_type); + Function(std::string name, std::shared_ptr func_type, + bool is_declaration = false); BasicBlock* CreateBlock(const std::string& name); BasicBlock* GetEntry(); const BasicBlock* GetEntry() const; const std::vector>& GetBlocks() const; + const std::vector>& GetArguments() const; + size_t GetNumArgs() const; + Argument* GetArg(size_t index); + std::shared_ptr GetFunctionType() const; + std::shared_ptr GetReturnType() const; + bool IsDeclaration() const; private: BasicBlock* entry_ = nullptr; std::vector> blocks_; + std::vector> args_; + std::unordered_map block_name_counts_; + bool is_declaration_ = false; +}; + +class Argument : public Value { + public: + Argument(std::shared_ptr ty, std::string name, size_t index); + size_t GetIndex() const { return index_; } + + private: + size_t index_ = 0; }; class Module { @@ -282,11 +466,20 @@ class Module { // 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。 Function* CreateFunction(const std::string& name, std::shared_ptr ret_type); + Function* CreateFunctionWithType(const std::string& name, + std::shared_ptr func_type); + Function* CreateFunctionDecl(const std::string& name, + std::shared_ptr func_type); + GlobalVariable* CreateGlobalVariable(const std::string& name, + std::shared_ptr value_type, + ConstantValue* init, bool is_const); const std::vector>& GetFunctions() const; + const std::vector>& 
GetGlobals() const; private: Context context_; std::vector> functions_; + std::vector> globals_; }; class IRBuilder { @@ -297,13 +490,44 @@ class IRBuilder { // 构造常量、二元运算、返回指令的最小集合。 ConstantInt* CreateConstInt(int v); + ConstantInt* CreateConstBool(bool v); + ConstantFloat* CreateConstFloat(float v); BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateSDiv(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateSRem(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFAdd(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFSub(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFMul(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFDiv(Value* lhs, Value* rhs, const std::string& name); + ICmpInst* CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs, + const std::string& name); + FCmpInst* CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs, + const std::string& name); + CastInst* CreateSIToFP(Value* src, std::shared_ptr dst_ty, + const std::string& name); + CastInst* CreateFPToSI(Value* src, std::shared_ptr dst_ty, + const std::string& name); + CastInst* CreateZExt(Value* src, std::shared_ptr dst_ty, + const std::string& name); + AllocaInst* CreateAlloca(std::shared_ptr ty, + const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name); StoreInst* CreateStore(Value* val, Value* ptr); + GepInst* CreateGep(Value* base_ptr, std::vector indices, + const std::string& name); + CallInst* CreateCall(Value* callee, std::vector args, + const std::string& name); + PhiInst* CreatePhi(std::shared_ptr ty, const std::string& name); + 
BranchInst* CreateBr(BasicBlock* dest); + CondBrInst* CreateCondBr(Value* cond, BasicBlock* true_dest, + BasicBlock* false_dest); ReturnInst* CreateRet(Value* v); + ReturnInst* CreateRetVoid(); private: Context& ctx_; diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 231ba90..a18ff6a 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -1,58 +1,112 @@ -// 将语法树翻译为 IR。 -// 实现拆分在 IRGenFunc/IRGenStmt/IRGenExp/IRGenDecl。 - #pragma once #include #include -#include #include #include "SysYBaseVisitor.h" -#include "SysYParser.h" #include "ir/IR.h" #include "sem/Sema.h" -namespace ir { -class Module; -class Function; -class IRBuilder; -class Value; -} - class IRGenImpl final : public SysYBaseVisitor { public: IRGenImpl(ir::Module& module, const SemanticContext& sema); std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + std::any visitBlock(SysYParser::BlockContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitDecl(SysYParser::DeclContext* ctx) override; - std::any visitStmt(SysYParser::StmtContext* ctx) override; + std::any visitConstDef(SysYParser::ConstDefContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override; - std::any visitVarExp(SysYParser::VarExpContext* ctx) override; - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override; + std::any visitStmt(SysYParser::StmtContext* ctx) override; - private: - enum class BlockFlow { - Continue, - Terminated, - }; + std::any visitExp(SysYParser::ExpContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any 
visitMulExp(SysYParser::MulExpContext* ctx) override; // 新增 + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; // 新增 + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; + std::any visitCond(SysYParser::CondContext* ctx) override; + std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override; + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; + std::any visitNumber(SysYParser::NumberContext* ctx) override; + std::any visitLVal(SysYParser::LValContext* ctx) override; + std::any visitConstExp(SysYParser::ConstExpContext* ctx) override; + std::any visitConstInitVal(SysYParser::ConstInitValContext* ctx) override; + std::any visitInitVal(SysYParser::InitValContext* ctx) override; + private: + ir::Value* EvalExp(SysYParser::ExpContext* ctx); + ir::Value* EvalCondValue(SysYParser::CondContext* ctx); + void EmitCondBr(SysYParser::CondContext* ctx, ir::BasicBlock* true_bb, + ir::BasicBlock* false_bb); + void EmitLOrCond(SysYParser::LOrExpContext* ctx, ir::BasicBlock* true_bb, + ir::BasicBlock* false_bb); + void EmitLAndCond(SysYParser::LAndExpContext* ctx, ir::BasicBlock* true_bb, + ir::BasicBlock* false_bb); + ir::Value* EmitRelEq(SysYParser::RelExpContext* ctx); + ir::Value* EmitEq(SysYParser::EqExpContext* ctx); + ir::Value* CastToFloat(ir::Value* v); + ir::Value* CastToInt(ir::Value* v); + ir::Value* MakeBool(ir::Value* v); + ir::Value* GetLValAddress(SysYParser::LValContext* ctx); + ir::Value* LoadIfNeeded(ir::Value* addr_or_val, const TypeDesc& ty, + bool as_rvalue); + std::shared_ptr ToIRType(const TypeDesc& ty); + std::shared_ptr ToIRParamType(const TypeDesc& ty); + ir::Value* DefaultValue(const TypeDesc& ty); + void InitArray(ir::Value* base_ptr, const TypeDesc& ty, + SysYParser::InitValContext* init); + void 
InitConstArray(ir::Value* base_ptr, const TypeDesc& ty, + SysYParser::ConstInitValContext* init); + size_t FillArrayValues(const TypeDesc& ty, SysYParser::InitValContext* init, + std::vector& values, size_t base, + size_t idx, size_t dim); + size_t FillConstArrayValues(const TypeDesc& ty, + SysYParser::ConstInitValContext* init, + std::vector& values, size_t base, + size_t idx, size_t dim); + size_t ArrayStride(const TypeDesc& ty, size_t dim) const; + size_t ArrayTotalSize(const TypeDesc& ty) const; + void PushLoop(ir::BasicBlock* break_bb, ir::BasicBlock* cont_bb); + void PopLoop(); + ir::BasicBlock* CurrentBreak() const; + ir::BasicBlock* CurrentContinue() const; + ir::ConstantValue* EvalConstScalar(SysYParser::ExpContext* ctx); + ir::ConstantValue* EvalConstScalar(SysYParser::ConstExpContext* ctx); + ir::ConstantValue* EvalConstAdd(SysYParser::AddExpContext* ctx); + ir::ConstantValue* EvalConstMul(SysYParser::MulExpContext* ctx); + ir::ConstantValue* EvalConstUnary(SysYParser::UnaryExpContext* ctx); + ir::ConstantValue* EvalConstPrimary(SysYParser::PrimaryExpContext* ctx); + ir::ConstantValue* EvalConstNumber(SysYParser::NumberContext* ctx); + ir::ConstantValue* EvalConstLVal(SysYParser::LValContext* ctx); + size_t InitGlobalArray(const TypeDesc& ty, SysYParser::InitValContext* init, + std::vector& values, size_t base, + size_t idx, size_t dim); + size_t InitGlobalConstArray(const TypeDesc& ty, + SysYParser::ConstInitValContext* init, + std::vector& values, + size_t base, size_t idx, size_t dim); + enum class BlockFlow { Continue, Terminated }; BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); - ir::Value* EvalExpr(SysYParser::ExpContext& expr); ir::Module& module_; const SemanticContext& sema_; ir::Function* func_; ir::IRBuilder builder_; - // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 - std::unordered_map storage_map_; + std::unordered_map var_storage_; + std::unordered_map const_storage_; + std::unordered_map param_storage_; + 
std::unordered_map func_map_; + std::unordered_map + global_var_storage_; + std::unordered_map + global_const_storage_; + std::vector> loop_stack_; }; std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, - const SemanticContext& sema); + const SemanticContext& sema); \ No newline at end of file diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 9ac057b..40d1544 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -1,30 +1,91 @@ -// 基于语法树的语义检查与名称绑定。 +// 基于语法树的语义检查与名称绑定(Lab2 扩展) #pragma once +#include #include +#include #include "SysYParser.h" +#include "sem/SymbolTable.h" + +struct FuncTypeDesc { + TypeDesc ret; + std::vector params; +}; + +struct BoundDecl { + enum class Kind { Var, Const, Param } kind = Kind::Var; + SysYParser::VarDefContext* var_decl = nullptr; + SysYParser::ConstDefContext* const_decl = nullptr; + SysYParser::FuncFParamContext* param_decl = nullptr; +}; class SemanticContext { public: - void BindVarUse(SysYParser::VarContext* use, - SysYParser::VarDefContext* decl) { + void BindVarUse(SysYParser::LValContext* use, BoundDecl decl) { var_uses_[use] = decl; } - SysYParser::VarDefContext* ResolveVarUse( - const SysYParser::VarContext* use) const { + BoundDecl ResolveVarUse(const SysYParser::LValContext* use) const { auto it = var_uses_.find(use); - return it == var_uses_.end() ? nullptr : it->second; + return it == var_uses_.end() ? 
BoundDecl{} : it->second; + } + + void RegisterVarDecl(SysYParser::VarDefContext* decl, TypeDesc ty) { + var_types_[decl] = std::move(ty); + } + + void RegisterConstDecl(SysYParser::ConstDefContext* decl, TypeDesc ty) { + const_types_[decl] = std::move(ty); + } + + void RegisterParam(SysYParser::FuncFParamContext* decl, TypeDesc ty) { + param_types_[decl] = std::move(ty); + } + + void RegisterFunc(SysYParser::FuncDefContext* decl, FuncTypeDesc ty) { + func_types_[decl] = std::move(ty); + } + + const TypeDesc* GetVarType(const SysYParser::VarDefContext* decl) const { + auto it = var_types_.find(decl); + return it == var_types_.end() ? nullptr : &it->second; + } + + const TypeDesc* GetConstType(const SysYParser::ConstDefContext* decl) const { + auto it = const_types_.find(decl); + return it == const_types_.end() ? nullptr : &it->second; + } + + const TypeDesc* GetParamType(const SysYParser::FuncFParamContext* decl) const { + auto it = param_types_.find(decl); + return it == param_types_.end() ? nullptr : &it->second; + } + + const FuncTypeDesc* GetFuncType(const SysYParser::FuncDefContext* decl) const { + auto it = func_types_.find(decl); + return it == func_types_.end() ? nullptr : &it->second; + } + + void BindFuncCall(SysYParser::UnaryExpContext* call, + SysYParser::FuncDefContext* decl) { + func_calls_[call] = decl; + } + + SysYParser::FuncDefContext* ResolveFuncCall( + const SysYParser::UnaryExpContext* call) const { + auto it = func_calls_.find(call); + return it == func_calls_.end() ? 
nullptr : it->second; } private: - std::unordered_map - var_uses_; + std::unordered_map var_uses_; + std::unordered_map var_types_; + std::unordered_map const_types_; + std::unordered_map param_types_; + std::unordered_map func_types_; + std::unordered_map + func_calls_; }; -// 目前仅检查: -// - 变量先声明后使用 -// - 局部变量不允许重复定义 -SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); +SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); \ No newline at end of file diff --git a/include/sem/SymbolTable.h b/include/sem/SymbolTable.h index c9396dd..44110bf 100644 --- a/include/sem/SymbolTable.h +++ b/include/sem/SymbolTable.h @@ -1,17 +1,42 @@ -// 极简符号表:记录局部变量定义点。 +// 符号表:记录局部变量/常量/参数定义。 #pragma once +#include #include #include +#include -#include "SysYParser.h" + #include "SysYParser.h" + +enum class BaseTypeKind { Int, Float, Void }; + +struct TypeDesc { + BaseTypeKind base = BaseTypeKind::Int; + std::vector dims; // 为空表示标量;数组维度允许首维为 -1 表示形参不定长 + bool is_const = false; +}; + +enum class SymbolKind { Var, Const, Param }; + +struct SymbolEntry { + SymbolKind kind = SymbolKind::Var; + SysYParser::VarDefContext* var_decl = nullptr; + SysYParser::ConstDefContext* const_decl = nullptr; + SysYParser::FuncFParamContext* param_decl = nullptr; + TypeDesc type; // 记录类型信息 + bool is_const = false; + std::optional const_value; +}; class SymbolTable { public: - void Add(const std::string& name, SysYParser::VarDefContext* decl); - bool Contains(const std::string& name) const; - SysYParser::VarDefContext* Lookup(const std::string& name) const; + void EnterScope(); + void ExitScope(); + + bool ContainsInCurrentScope(const std::string& name) const; + void Add(const std::string& name, const SymbolEntry& entry); + const SymbolEntry* Lookup(const std::string& name) const; private: - std::unordered_map table_; -}; + std::vector> scopes_; +}; \ No newline at end of file diff --git a/scripts/run_lab2.sh b/scripts/run_lab2.sh new file mode 100644 index 0000000..c4adf45 --- /dev/null 
+++ b/scripts/run_lab2.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +set -euo pipefail + +# Reconfigure with IR pipeline enabled, build, then run the Lab2 test script; all output is also written to RESULT_FILE. +RESULT_FILE="test/test_result/run_lab2_result.log" +mkdir -p "$(dirname "$RESULT_FILE")" +: > "$RESULT_FILE" + +{ + echo "[run_lab2] start: $(date '+%Y-%m-%d %H:%M:%S')" + echo "[run_lab2] logging to: $RESULT_FILE" + + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF + cmake --build build -j "$(nproc)" + bash scripts/test_lab2.sh + + echo "[run_lab2] end: $(date '+%Y-%m-%d %H:%M:%S')" +} 2>&1 | tee "$RESULT_FILE" diff --git a/scripts/test_lab2.sh b/scripts/test_lab2.sh new file mode 100644 index 0000000..5d6c8a3 --- /dev/null +++ b/scripts/test_lab2.sh @@ -0,0 +1,115 @@ +#!/usr/bin/env bash + +set -euo pipefail + +# Lab2 quick/full verification helper. +# Usage: +# bash scripts/test_lab2.sh +# Optional env vars: +# COMPILER=./build/bin/compiler +# CASE_DIR=test/test_case/functional +# OUT_DIR=test/test_result/lab2_ir +# LOG_FILE=test/test_result/lab2_test.log + +COMPILER="${COMPILER:-./build/bin/compiler}" +CASE_DIR="${CASE_DIR:-test/test_case/functional}" +OUT_DIR="${OUT_DIR:-test/test_result/lab2_ir}" +LOG_FILE="${LOG_FILE:-test/test_result/lab2_test.log}" +VERIFY_SCRIPT="./scripts/verify_ir.sh" + +if [[ ! -x "$COMPILER" ]]; then + echo "compiler not found or not executable: $COMPILER" >&2 + echo "build first:" >&2 + echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release" >&2 + echo " cmake --build build -j \"\$(nproc)\"" >&2 + exit 1 +fi + +if [[ ! -x "$VERIFY_SCRIPT" ]]; then + echo "verify script not found or not executable: $VERIFY_SCRIPT" >&2 + exit 1 +fi + +if [[ ! -d "$CASE_DIR" ]]; then + echo "case dir not found: $CASE_DIR" >&2 + exit 1 +fi + +mkdir -p "$OUT_DIR" + +# Preflight: ensure compiler supports IR emission (not parse-only build). 
+probe_input="$CASE_DIR/simple_add.sy" +probe_err="$OUT_DIR/.lab2_probe.err" +if [[ -f "$probe_input" ]]; then + set +e + "$COMPILER" --emit-ir "$probe_input" > /dev/null 2> "$probe_err" + probe_rc=$? + set -e + if [[ $probe_rc -ne 0 ]] && grep -Eiq "parse-only|IR/汇编输出已禁用" "$probe_err"; then + echo "detected parse-only compiler build, cannot run Lab2 IR tests." >&2 + echo "rebuild with IR enabled:" >&2 + echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF" >&2 + echo " cmake --build build -j \"\$(nproc)\"" >&2 + rm -f "$probe_err" + exit 2 + fi + rm -f "$probe_err" +fi + +mkdir -p "$(dirname "$LOG_FILE")" +: > "$LOG_FILE" + +echo "[Lab2] start test" | tee -a "$LOG_FILE" +echo "compiler : $COMPILER" | tee -a "$LOG_FILE" +echo "cases : $CASE_DIR" | tee -a "$LOG_FILE" +echo "out dir : $OUT_DIR" | tee -a "$LOG_FILE" + +echo "[Step 1] single sample check: simple_add.sy" | tee -a "$LOG_FILE" +if "$VERIFY_SCRIPT" "$CASE_DIR/simple_add.sy" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then + echo "single sample: PASS" | tee -a "$LOG_FILE" +else + echo "single sample: FAIL" | tee -a "$LOG_FILE" + echo "stop here. 
see log: $LOG_FILE" >&2 + exit 1 +fi + +echo "[Step 2] full functional regression" | tee -a "$LOG_FILE" + +pass=0 +fail=0 +total=0 +failed_list=() + +while IFS= read -r -d '' sy; do + total=$((total + 1)) + name="$(basename "$sy")" + echo "[$total] $name" | tee -a "$LOG_FILE" + + if "$VERIFY_SCRIPT" "$sy" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then + pass=$((pass + 1)) + echo " PASS" | tee -a "$LOG_FILE" + else + fail=$((fail + 1)) + failed_list+=("$sy") + echo " FAIL" | tee -a "$LOG_FILE" + fi +done < <(find "$CASE_DIR" -type f -name "*.sy" -print0 | sort -z) + +echo "" | tee -a "$LOG_FILE" +echo "[Summary]" | tee -a "$LOG_FILE" +echo "total: $total" | tee -a "$LOG_FILE" +echo "pass : $pass" | tee -a "$LOG_FILE" +echo "fail : $fail" | tee -a "$LOG_FILE" + +if [[ $fail -gt 0 ]]; then + echo "failed cases:" | tee -a "$LOG_FILE" + for f in "${failed_list[@]}"; do + echo " - $f" | tee -a "$LOG_FILE" + done + echo "Lab2 target is not fully met yet." | tee -a "$LOG_FILE" + echo "see details in $LOG_FILE" + exit 1 +fi + +echo "All functional cases passed. Lab2 target (functional regression) is met." | tee -a "$LOG_FILE" +echo "see details in $LOG_FILE" diff --git a/scripts/verify_ir.sh b/scripts/verify_ir.sh index f41f6b3..6600198 100755 --- a/scripts/verify_ir.sh +++ b/scripts/verify_ir.sh @@ -60,7 +60,7 @@ if [[ "$run_exec" == true ]]; then stdout_file="$out_dir/$stem.stdout" actual_file="$out_dir/$stem.actual.out" llc -filetype=obj "$out_file" -o "$obj" - clang "$obj" -o "$exe" + clang "$obj" sylib/sylib.c -o "$exe" echo "运行 $exe ..." 
set +e if [[ -f "$stdin_file" ]]; then diff --git a/src/ir/BasicBlock.cpp b/src/ir/BasicBlock.cpp index b18502c..8d9affa 100644 --- a/src/ir/BasicBlock.cpp +++ b/src/ir/BasicBlock.cpp @@ -13,9 +13,9 @@ namespace ir { -// 当前 BasicBlock 还没有专门的 label type,因此先用 void 作为占位类型。 +// BasicBlock 使用 label type。 BasicBlock::BasicBlock(std::string name) - : Value(Type::GetVoidType(), std::move(name)) {} + : Value(Type::GetLabelType(), std::move(name)) {} Function* BasicBlock::GetParent() const { return parent_; } @@ -42,4 +42,38 @@ const std::vector& BasicBlock::GetSuccessors() const { return successors_; } +void BasicBlock::AddPredecessor(BasicBlock* pred) { + if (!pred) return; + for (auto* p : predecessors_) { + if (p == pred) return; + } + predecessors_.push_back(pred); +} + +void BasicBlock::AddSuccessor(BasicBlock* succ) { + if (!succ) return; + for (auto* s : successors_) { + if (s == succ) return; + } + successors_.push_back(succ); +} + +void BasicBlock::LinkSuccessorsIfNeeded(Instruction* inst) { + if (!inst) return; + if (auto* br = dynamic_cast(inst)) { + auto* dest = br->GetDest(); + AddSuccessor(dest); + dest->AddPredecessor(this); + return; + } + if (auto* cbr = dynamic_cast(inst)) { + auto* t = cbr->GetTrueDest(); + auto* f = cbr->GetFalseDest(); + AddSuccessor(t); + AddSuccessor(f); + t->AddPredecessor(this); + f->AddPredecessor(this); + } +} + } // namespace ir diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index 16c982c..fce0697 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -1,6 +1,7 @@ // 管理基础类型、整型常量池和临时名生成。 #include "ir/IR.h" +#include #include namespace ir { @@ -15,9 +16,43 @@ ConstantInt* Context::GetConstInt(int v) { return inserted->second.get(); } +ConstantInt* Context::GetConstBool(bool v) { + int iv = v ? 
1 : 0; + auto it = const_bools_.find(iv); + if (it != const_bools_.end()) return it->second.get(); + auto inserted = const_bools_.emplace( + iv, std::make_unique(Type::GetInt1Type(), iv)).first; + return inserted->second.get(); +} + +static uint32_t FloatToBits(float v) { + uint32_t bits = 0; + std::memcpy(&bits, &v, sizeof(float)); + return bits; +} + +ConstantFloat* Context::GetConstFloat(float v) { + uint32_t bits = FloatToBits(v); + auto it = const_floats_.find(bits); + if (it != const_floats_.end()) return it->second.get(); + auto inserted = const_floats_.emplace( + bits, std::make_unique(Type::GetFloatType(), v)).first; + return inserted->second.get(); +} + +ConstantArray* Context::CreateConstArray(std::shared_ptr array_ty, + std::vector elements) { + if (!array_ty || !array_ty->IsArray()) { + throw std::runtime_error("CreateConstArray 需要 array type"); + } + const_arrays_.push_back( + std::make_unique(std::move(array_ty), std::move(elements))); + return const_arrays_.back().get(); +} + std::string Context::NextTemp() { std::ostringstream oss; - oss << "%" << ++temp_index_; + oss << "%t" << ++temp_index_; return oss.str(); } diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index cf14d48..f27dd98 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -5,13 +5,32 @@ namespace ir { -Function::Function(std::string name, std::shared_ptr ret_type) - : Value(std::move(ret_type), std::move(name)) { - entry_ = CreateBlock("entry"); +Function::Function(std::string name, std::shared_ptr func_type, + bool is_declaration) + : Value(std::move(func_type), std::move(name)), + is_declaration_(is_declaration) { + if (!type_ || !type_->IsFunction()) { + throw std::runtime_error("Function 需要 function type"); + } + const auto& params = type_->GetParamTypes(); + args_.reserve(params.size()); + for (size_t i = 0; i < params.size(); ++i) { + args_.push_back(std::make_unique(params[i], "%arg" + std::to_string(i), i)); + } + if (!is_declaration_) { + entry_ = 
CreateBlock("entry"); + } } BasicBlock* Function::CreateBlock(const std::string& name) { - auto block = std::make_unique(name); + std::string base = name.empty() ? "bb" : name; + auto& count = block_name_counts_[base]; + std::string final_name = base; + if (count > 0) { + final_name = base + "." + std::to_string(count); + } + ++count; + auto block = std::make_unique(final_name); auto* ptr = block.get(); ptr->SetParent(this); blocks_.push_back(std::move(block)); @@ -29,4 +48,31 @@ const std::vector>& Function::GetBlocks() const { return blocks_; } +const std::vector>& Function::GetArguments() const { + return args_; +} + +size_t Function::GetNumArgs() const { return args_.size(); } + +Argument* Function::GetArg(size_t index) { + if (index >= args_.size()) { + throw std::out_of_range("Function arg index out of range"); + } + return args_[index].get(); +} + +std::shared_ptr Function::GetFunctionType() const { return type_; } + +std::shared_ptr Function::GetReturnType() const { + if (!type_ || !type_->IsFunction()) { + throw std::runtime_error("Function type 缺失"); + } + return type_->GetReturnType(); +} + +bool Function::IsDeclaration() const { return is_declaration_; } + +Argument::Argument(std::shared_ptr ty, std::string name, size_t index) + : Value(std::move(ty), std::move(name)), index_(index) {} + } // namespace ir diff --git a/src/ir/GlobalValue.cpp b/src/ir/GlobalValue.cpp index 7c2abe1..eb6dc5a 100644 --- a/src/ir/GlobalValue.cpp +++ b/src/ir/GlobalValue.cpp @@ -8,4 +8,23 @@ namespace ir { GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) : User(std::move(ty), std::move(name)) {} +GlobalVariable::GlobalVariable(std::shared_ptr value_ty, std::string name, + ConstantValue* init, bool is_const) + : GlobalValue(Type::GetPointerType(value_ty), std::move(name)), + value_type_(std::move(value_ty)), + initializer_(init), + is_const_(is_const) { + if (!value_type_) { + throw std::runtime_error("GlobalVariable 缺少 value type"); + } +} + +const 
std::shared_ptr& GlobalVariable::GetValueType() const { + return value_type_; +} + +ConstantValue* GlobalVariable::GetInitializer() const { return initializer_; } + +bool GlobalVariable::IsConst() const { return is_const_; } + } // namespace ir diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 90f03c4..4987681 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -9,6 +9,42 @@ #include "utils/Log.h" namespace ir { + +static std::string TypeToString(const Type& ty) { + switch (ty.GetKind()) { + case Type::Kind::Void: + return "void"; + case Type::Kind::Int1: + return "i1"; + case Type::Kind::Int32: + return "i32"; + case Type::Kind::Float: + return "float"; + case Type::Kind::Label: + return "label"; + case Type::Kind::Pointer: + return TypeToString(*ty.GetElementType()) + "*"; + case Type::Kind::Array: { + return "[" + std::to_string(ty.GetArraySize()) + " x " + + TypeToString(*ty.GetElementType()) + "]"; + } + case Type::Kind::Function: { + std::string out = TypeToString(*ty.GetReturnType()) + " ("; + const auto& params = ty.GetParamTypes(); + for (size_t i = 0; i < params.size(); ++i) { + if (i > 0) out += ", "; + out += TypeToString(*params[i]); + } + if (ty.IsVarArg()) { + if (!params.empty()) out += ", "; + out += "..."; + } + out += ")"; + return out; + } + } + return "?"; +} IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) : ctx_(ctx), insert_block_(bb) {} @@ -42,11 +78,107 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, return CreateBinary(Opcode::Add, lhs, rhs, name); } -AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { +BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::Sub, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::Mul, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateSDiv(Value* lhs, Value* rhs, + const std::string& name) { + return 
CreateBinary(Opcode::SDiv, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateSRem(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::SRem, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFAdd(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::FAdd, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFSub(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::FSub, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFMul(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::FMul, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFDiv(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::FDiv, lhs, rhs, name); +} + +ICmpInst* IRBuilder::CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(pred, lhs, rhs, name); +} + +FCmpInst* IRBuilder::CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(pred, lhs, rhs, name); +} + +CastInst* IRBuilder::CreateSIToFP(Value* src, std::shared_ptr dst_ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Opcode::SIToFP, std::move(dst_ty), src, + name); +} + +CastInst* IRBuilder::CreateFPToSI(Value* src, std::shared_ptr dst_ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Opcode::FPToSI, std::move(dst_ty), src, + name); +} + +CastInst* IRBuilder::CreateZExt(Value* src, std::shared_ptr dst_ty, + const std::string& name) { if (!insert_block_) { throw 
std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - return insert_block_->Append(Type::GetPtrInt32Type(), name); + return insert_block_->Append(Opcode::ZExt, std::move(dst_ty), src, + name); +} + +ConstantInt* IRBuilder::CreateConstBool(bool v) { + return ctx_.GetConstBool(v); +} + +ConstantFloat* IRBuilder::CreateConstFloat(float v) { + return ctx_.GetConstFloat(v); +} + +AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(std::move(ty), name); +} + +AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { + return CreateAlloca(Type::GetInt32Type(), name); } LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { @@ -57,7 +189,11 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { throw std::runtime_error( FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); } - return insert_block_->Append(Type::GetInt32Type(), ptr, name); + if (!ptr->GetType() || !ptr->GetType()->IsPointer()) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateLoad ptr 不是指针")); + } + auto val_ty = ptr->GetType()->GetElementType(); + return insert_block_->Append(val_ty, ptr, name); } StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { @@ -75,6 +211,95 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { return insert_block_->Append(Type::GetVoidType(), val, ptr); } +static std::shared_ptr ResolveGepResultType(const std::shared_ptr& base_ptr_ty, + size_t index_count) { + if (!base_ptr_ty || !base_ptr_ty->IsPointer()) { + throw std::runtime_error("GEP base type 必须是指针"); + } + auto cur = base_ptr_ty->GetElementType(); + for (size_t i = 0; i < index_count; ++i) { + if (cur->IsArray()) { + cur = cur->GetElementType(); + continue; + } + if (cur->IsPointer()) { + cur = cur->GetElementType(); + continue; + } + } + return Type::GetPointerType(cur); +} + +GepInst* 
IRBuilder::CreateGep(Value* base_ptr, std::vector indices, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!base_ptr || !base_ptr->GetType() || !base_ptr->GetType()->IsPointer()) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateGep base_ptr 非指针")); + } + auto result_ty = ResolveGepResultType(base_ptr->GetType(), indices.size()); + return insert_block_->Append(result_ty, base_ptr, std::move(indices), + name); +} + +CallInst* IRBuilder::CreateCall(Value* callee, std::vector args, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!callee || !callee->GetType()) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall 缺少 callee")); + } + std::shared_ptr func_ty; + if (callee->GetType()->IsFunction()) { + func_ty = callee->GetType(); + } else if (callee->GetType()->IsPointer() && + callee->GetType()->GetElementType()->IsFunction()) { + func_ty = callee->GetType()->GetElementType(); + } else { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall callee 非函数")); + } + const auto& params = func_ty->GetParamTypes(); + if (!func_ty->IsVarArg() && params.size() != args.size()) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall 参数数量不匹配")); + } + for (size_t i = 0; i < params.size() && i < args.size(); ++i) { + if (!args[i] || !args[i]->GetType() || + !args[i]->GetType()->Equals(*params[i])) { + std::string msg = "IRBuilder::CreateCall 参数类型不匹配: arg" + + std::to_string(i) + " got " + + TypeToString(*args[i]->GetType()) + ", expect " + + TypeToString(*params[i]); + throw std::runtime_error(FormatError("ir", msg)); + } + } + auto ret_ty = func_ty->GetReturnType(); + return insert_block_->Append(ret_ty, callee, std::move(args), name); +} + +PhiInst* IRBuilder::CreatePhi(std::shared_ptr ty, const std::string& name) { + if (!insert_block_) { + throw 
std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(std::move(ty), name); +} + +BranchInst* IRBuilder::CreateBr(BasicBlock* dest) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(dest); +} + +CondBrInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* true_dest, + BasicBlock* false_dest) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(cond, true_dest, false_dest); +} + ReturnInst* IRBuilder::CreateRet(Value* v) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); @@ -86,4 +311,11 @@ ReturnInst* IRBuilder::CreateRet(Value* v) { return insert_block_->Append(Type::GetVoidType(), v); } +ReturnInst* IRBuilder::CreateRetVoid() { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Type::GetVoidType()); +} + } // namespace ir diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 30efbb6..7e6fa60 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -5,6 +5,11 @@ #include "ir/IR.h" #include +#include +#include +#include +#include +#include #include #include @@ -12,14 +17,41 @@ namespace ir { -static const char* TypeToString(const Type& ty) { +static std::string TypeToString(const Type& ty) { switch (ty.GetKind()) { case Type::Kind::Void: return "void"; + case Type::Kind::Int1: + return "i1"; case Type::Kind::Int32: return "i32"; - case Type::Kind::PtrInt32: - return "i32*"; + case Type::Kind::Float: + return "float"; + case Type::Kind::Label: + return "label"; + case Type::Kind::Pointer: + return TypeToString(*ty.GetElementType()) + "*"; + case Type::Kind::Array: { + std::ostringstream oss; + oss << "[" << ty.GetArraySize() << " x " + << TypeToString(*ty.GetElementType()) << "]"; + return oss.str(); + } + case Type::Kind::Function: { + 
std::ostringstream oss; + oss << TypeToString(*ty.GetReturnType()) << " ("; + const auto& params = ty.GetParamTypes(); + for (size_t i = 0; i < params.size(); ++i) { + if (i > 0) oss << ", "; + oss << TypeToString(*params[i]); + } + if (ty.IsVarArg()) { + if (!params.empty()) oss << ", "; + oss << "..."; + } + oss << ")"; + return oss.str(); + } } throw std::runtime_error(FormatError("ir", "未知类型")); } @@ -32,6 +64,18 @@ static const char* OpcodeToString(Opcode op) { return "sub"; case Opcode::Mul: return "mul"; + case Opcode::SDiv: + return "sdiv"; + case Opcode::SRem: + return "srem"; + case Opcode::FAdd: + return "fadd"; + case Opcode::FSub: + return "fsub"; + case Opcode::FMul: + return "fmul"; + case Opcode::FDiv: + return "fdiv"; case Opcode::Alloca: return "alloca"; case Opcode::Load: @@ -40,21 +84,161 @@ static const char* OpcodeToString(Opcode op) { return "store"; case Opcode::Ret: return "ret"; + case Opcode::Br: + return "br"; + case Opcode::CondBr: + return "br"; + case Opcode::ICmp: + return "icmp"; + case Opcode::FCmp: + return "fcmp"; + case Opcode::Call: + return "call"; + case Opcode::Phi: + return "phi"; + case Opcode::Gep: + return "getelementptr"; + case Opcode::SIToFP: + return "sitofp"; + case Opcode::FPToSI: + return "fptosi"; + case Opcode::ZExt: + return "zext"; } return "?"; } -static std::string ValueToString(const Value* v) { - if (auto* ci = dynamic_cast(v)) { +static std::string FloatToString(float v) { + std::uint32_t bits = 0; + static_assert(sizeof(bits) == sizeof(v), "float size mismatch"); + std::memcpy(&bits, &v, sizeof(bits)); + std::ostringstream oss; + oss << "bitcast (i32 " << std::dec << static_cast(bits) + << " to float)"; + return oss.str(); +} + +static std::string ConstantToString(const ConstantValue* c) { + if (auto* ci = dynamic_cast(c)) { return std::to_string(ci->GetValue()); } + if (auto* cf = dynamic_cast(c)) { + return FloatToString(cf->GetValue()); + } + if (auto* ca = dynamic_cast(c)) { + std::ostringstream oss; 
+ oss << "["; + const auto& elems = ca->GetElements(); + for (size_t i = 0; i < elems.size(); ++i) { + if (i > 0) oss << ", "; + oss << TypeToString(*elems[i]->GetType()) << " " + << ConstantToString(elems[i]); + } + oss << "]"; + return oss.str(); + } + return ""; +} + +static std::string ValueToString(const Value* v) { + if (auto* c = dynamic_cast(v)) { + return ConstantToString(c); + } + if (auto* func = dynamic_cast(v)) { + const auto& name = func->GetName(); + if (!name.empty() && name[0] == '@') return name; + return "@" + name; + } + if (auto* gv = dynamic_cast(v)) { + const auto& name = gv->GetName(); + if (!name.empty() && name[0] == '@') return name; + return "@" + name; + } return v ? v->GetName() : ""; } +static std::string LabelToString(const BasicBlock* bb) { + if (!bb) return "%"; + const auto& name = bb->GetName(); + if (!name.empty() && name[0] == '%') return name; + return "%" + name; +} + +static const char* ICmpPredToString(ICmpPredicate pred) { + switch (pred) { + case ICmpPredicate::Eq: + return "eq"; + case ICmpPredicate::Ne: + return "ne"; + case ICmpPredicate::Slt: + return "slt"; + case ICmpPredicate::Sle: + return "sle"; + case ICmpPredicate::Sgt: + return "sgt"; + case ICmpPredicate::Sge: + return "sge"; + } + return "?"; +} + +static const char* FCmpPredToString(FCmpPredicate pred) { + switch (pred) { + case FCmpPredicate::Oeq: + return "oeq"; + case FCmpPredicate::One: + return "one"; + case FCmpPredicate::Olt: + return "olt"; + case FCmpPredicate::Ole: + return "ole"; + case FCmpPredicate::Ogt: + return "ogt"; + case FCmpPredicate::Oge: + return "oge"; + } + return "?"; +} + void IRPrinter::Print(const Module& module, std::ostream& os) { + for (const auto& g : module.GetGlobals()) { + if (!g) continue; + os << "@" << g->GetName() << " = " + << (g->IsConst() ? 
"constant " : "global ") + << TypeToString(*g->GetValueType()) << " "; + if (auto* init = g->GetInitializer()) { + os << ConstantToString(init); + } else { + if (g->GetValueType()->IsArray()) { + os << "zeroinitializer"; + } else if (g->GetValueType()->IsFloat()) { + os << "0.0"; + } else { + os << "0"; + } + } + os << "\n"; + } for (const auto& func : module.GetFunctions()) { - os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() - << "() {\n"; + if (func->IsDeclaration()) { + os << "declare " << TypeToString(*func->GetReturnType()) << " @" + << func->GetName() << "("; + const auto& args = func->GetArguments(); + for (size_t i = 0; i < args.size(); ++i) { + if (i > 0) os << ", "; + os << TypeToString(*args[i]->GetType()); + } + os << ")\n"; + continue; + } + os << "define " << TypeToString(*func->GetReturnType()) << " @" + << func->GetName() << "("; + const auto& args = func->GetArguments(); + for (size_t i = 0; i < args.size(); ++i) { + if (i > 0) os << ", "; + os << TypeToString(*args[i]->GetType()) << " " << args[i]->GetName(); + } + os << ") {\n"; for (const auto& bb : func->GetBlocks()) { if (!bb) { continue; @@ -65,7 +249,13 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { switch (inst->GetOpcode()) { case Opcode::Add: case Opcode::Sub: - case Opcode::Mul: { + case Opcode::Mul: + case Opcode::SDiv: + case Opcode::SRem: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: { auto* bin = static_cast(inst); os << " " << bin->GetName() << " = " << OpcodeToString(bin->GetOpcode()) << " " @@ -74,27 +264,122 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { << ValueToString(bin->GetRhs()) << "\n"; break; } + case Opcode::ICmp: { + auto* cmp = static_cast(inst); + os << " " << cmp->GetName() << " = icmp " + << ICmpPredToString(cmp->GetPredicate()) << " " + << TypeToString(*cmp->GetLhs()->GetType()) << " " + << ValueToString(cmp->GetLhs()) << ", " + << 
ValueToString(cmp->GetRhs()) << "\n"; + break; + } + case Opcode::FCmp: { + auto* cmp = static_cast(inst); + os << " " << cmp->GetName() << " = fcmp " + << FCmpPredToString(cmp->GetPredicate()) << " " + << TypeToString(*cmp->GetLhs()->GetType()) << " " + << ValueToString(cmp->GetLhs()) << ", " + << ValueToString(cmp->GetRhs()) << "\n"; + break; + } + case Opcode::SIToFP: + case Opcode::FPToSI: + case Opcode::ZExt: { + auto* cast = static_cast(inst); + os << " " << cast->GetName() << " = " + << OpcodeToString(cast->GetOpcode()) << " " + << TypeToString(*cast->GetValue()->GetType()) << " " + << ValueToString(cast->GetValue()) << " to " + << TypeToString(*cast->GetType()) << "\n"; + break; + } case Opcode::Alloca: { auto* alloca = static_cast(inst); - os << " " << alloca->GetName() << " = alloca i32\n"; + os << " " << alloca->GetName() << " = alloca " + << TypeToString(*alloca->GetAllocatedType()) << "\n"; break; } case Opcode::Load: { auto* load = static_cast(inst); - os << " " << load->GetName() << " = load i32, i32* " + os << " " << load->GetName() << " = load " + << TypeToString(*load->GetType()) << ", " + << TypeToString(*load->GetPtr()->GetType()) << " " << ValueToString(load->GetPtr()) << "\n"; break; } case Opcode::Store: { auto* store = static_cast(inst); - os << " store i32 " << ValueToString(store->GetValue()) - << ", i32* " << ValueToString(store->GetPtr()) << "\n"; + os << " store " << TypeToString(*store->GetValue()->GetType()) + << " " << ValueToString(store->GetValue()) << ", " + << TypeToString(*store->GetPtr()->GetType()) << " " + << ValueToString(store->GetPtr()) << "\n"; + break; + } + case Opcode::Br: { + auto* br = static_cast(inst); + os << " br label " << LabelToString(br->GetDest()) << "\n"; + break; + } + case Opcode::CondBr: { + auto* cbr = static_cast(inst); + os << " br i1 " << ValueToString(cbr->GetCond()) + << ", label " << LabelToString(cbr->GetTrueDest()) + << ", label " << LabelToString(cbr->GetFalseDest()) << "\n"; + break; + } + 
case Opcode::Call: { + auto* call = static_cast(inst); + const auto& args = call->GetArgs(); + if (!call->GetType()->IsVoid()) { + os << " " << call->GetName() << " = "; + } else { + os << " "; + } + os << "call " << TypeToString(*call->GetType()) << " " + << ValueToString(call->GetCallee()) << "("; + for (size_t i = 0; i < args.size(); ++i) { + if (i > 0) os << ", "; + os << TypeToString(*args[i]->GetType()) << " " + << ValueToString(args[i]); + } + os << ")\n"; + break; + } + case Opcode::Phi: { + auto* phi = static_cast(inst); + os << " " << phi->GetName() << " = phi " + << TypeToString(*phi->GetType()) << " "; + const auto& values = phi->GetIncomingValues(); + const auto& blocks = phi->GetIncomingBlocks(); + for (size_t i = 0; i < values.size(); ++i) { + if (i > 0) os << ", "; + os << "[ " << ValueToString(values[i]) << ", " + << LabelToString(blocks[i]) << " ]"; + } + os << "\n"; + break; + } + case Opcode::Gep: { + auto* gep = static_cast(inst); + os << " " << gep->GetName() << " = getelementptr " + << TypeToString(*gep->GetBasePtr()->GetType()->GetElementType()) + << ", " << TypeToString(*gep->GetBasePtr()->GetType()) << " " + << ValueToString(gep->GetBasePtr()); + const auto& idx = gep->GetIndices(); + for (auto* v : idx) { + os << ", i32 " << ValueToString(v); + } + os << "\n"; break; } case Opcode::Ret: { auto* ret = static_cast(inst); - os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " - << ValueToString(ret->GetValue()) << "\n"; + if (ret->HasReturnValue()) { + os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " + << ValueToString(ret->GetValue()) << "\n"; + } else { + os << " ret void\n"; + } break; } } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 7928716..6e053c8 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -52,17 +52,30 @@ Instruction::Instruction(Opcode op, std::shared_ptr ty, std::string name) Opcode Instruction::GetOpcode() const { return opcode_; } -bool 
Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; } +bool Instruction::IsTerminator() const { + return opcode_ == Opcode::Ret || opcode_ == Opcode::Br || + opcode_ == Opcode::CondBr; +} BasicBlock* Instruction::GetParent() const { return parent_; } void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } +static bool IsIntBinaryOp(Opcode op) { + return op == Opcode::Add || op == Opcode::Sub || op == Opcode::Mul || + op == Opcode::SDiv || op == Opcode::SRem; +} + +static bool IsFloatBinaryOp(Opcode op) { + return op == Opcode::FAdd || op == Opcode::FSub || op == Opcode::FMul || + op == Opcode::FDiv; +} + BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name) : Instruction(op, std::move(ty), std::move(name)) { - if (op != Opcode::Add) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add")); + if (!IsIntBinaryOp(op) && !IsFloatBinaryOp(op)) { + throw std::runtime_error(FormatError("ir", "BinaryInst 非算术 op")); } if (!lhs || !rhs) { throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); @@ -70,12 +83,15 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, if (!type_ || !lhs->GetType() || !rhs->GetType()) { throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息")); } - if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() || - type_->GetKind() != lhs->GetType()->GetKind()) { + if (!lhs->GetType()->Equals(*rhs->GetType()) || + !type_->Equals(*lhs->GetType())) { throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); } - if (!type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32")); + if (IsIntBinaryOp(op) && !type_->IsInt32()) { + throw std::runtime_error(FormatError("ir", "整数二元只支持 i32")); + } + if (IsFloatBinaryOp(op) && !type_->IsFloat()) { + throw std::runtime_error(FormatError("ir", "浮点二元只支持 float")); } AddOperand(lhs); AddOperand(rhs); @@ -85,6 +101,127 @@ Value* BinaryInst::GetLhs() const { return 
GetOperand(0); } Value* BinaryInst::GetRhs() const { return GetOperand(1); } +ICmpInst::ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name) + : Instruction(Opcode::ICmp, Type::GetInt1Type(), std::move(name)), + pred_(pred) { + if (!lhs || !rhs) { + throw std::runtime_error(FormatError("ir", "ICmpInst 缺少操作数")); + } + if (!lhs->GetType() || !rhs->GetType() || + !lhs->GetType()->Equals(*rhs->GetType())) { + throw std::runtime_error(FormatError("ir", "ICmpInst 类型不匹配")); + } + if (!lhs->GetType()->IsInt1() && !lhs->GetType()->IsInt32()) { + throw std::runtime_error(FormatError("ir", "ICmpInst 仅支持整型")); + } + AddOperand(lhs); + AddOperand(rhs); +} + +Value* ICmpInst::GetLhs() const { return GetOperand(0); } + +Value* ICmpInst::GetRhs() const { return GetOperand(1); } + +FCmpInst::FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name) + : Instruction(Opcode::FCmp, Type::GetInt1Type(), std::move(name)), + pred_(pred) { + if (!lhs || !rhs) { + throw std::runtime_error(FormatError("ir", "FCmpInst 缺少操作数")); + } + if (!lhs->GetType() || !rhs->GetType() || + !lhs->GetType()->Equals(*rhs->GetType())) { + throw std::runtime_error(FormatError("ir", "FCmpInst 类型不匹配")); + } + if (!lhs->GetType()->IsFloat()) { + throw std::runtime_error(FormatError("ir", "FCmpInst 仅支持 float")); + } + AddOperand(lhs); + AddOperand(rhs); +} + +Value* FCmpInst::GetLhs() const { return GetOperand(0); } + +Value* FCmpInst::GetRhs() const { return GetOperand(1); } + +CastInst::CastInst(Opcode op, std::shared_ptr dst_ty, Value* src, + std::string name) + : Instruction(op, std::move(dst_ty), std::move(name)) { + if (op != Opcode::SIToFP && op != Opcode::FPToSI && op != Opcode::ZExt) { + throw std::runtime_error(FormatError("ir", "CastInst 不支持的 op")); + } + if (!src) { + throw std::runtime_error(FormatError("ir", "CastInst 缺少 src")); + } + if (op == Opcode::SIToFP) { + if (!src->GetType()->IsInt32() && !src->GetType()->IsInt1()) { + throw std::runtime_error(FormatError("ir", 
"SIToFP 仅支持整型")); + } + if (!type_ || !type_->IsFloat()) { + throw std::runtime_error(FormatError("ir", "SIToFP 目标必须是 float")); + } + } else if (op == Opcode::FPToSI) { + if (!src->GetType()->IsFloat()) { + throw std::runtime_error(FormatError("ir", "FPToSI 仅支持 float")); + } + if (!type_ || !type_->IsInt32()) { + throw std::runtime_error(FormatError("ir", "FPToSI 目标必须是 i32")); + } + } else { + if (!src->GetType()->IsInt1()) { + throw std::runtime_error(FormatError("ir", "ZExt 仅支持 i1")); + } + if (!type_ || !type_->IsInt32()) { + throw std::runtime_error(FormatError("ir", "ZExt 目标必须是 i32")); + } + } + AddOperand(src); +} + +Value* CastInst::GetValue() const { return GetOperand(0); } + +BranchInst::BranchInst(BasicBlock* dest) + : Instruction(Opcode::Br, Type::GetVoidType(), "") { + if (!dest) { + throw std::runtime_error(FormatError("ir", "BranchInst 缺少目标块")); + } + AddOperand(dest); +} + +BasicBlock* BranchInst::GetDest() const { + return static_cast(GetOperand(0)); +} + +CondBrInst::CondBrInst(Value* cond, BasicBlock* true_dest, + BasicBlock* false_dest) + : Instruction(Opcode::CondBr, Type::GetVoidType(), "") { + if (!cond || !true_dest || !false_dest) { + throw std::runtime_error(FormatError("ir", "CondBrInst 缺少参数")); + } + if (!cond->GetType() || !cond->GetType()->IsInt1()) { + throw std::runtime_error(FormatError("ir", "CondBrInst cond 必须是 i1")); + } + AddOperand(cond); + AddOperand(true_dest); + AddOperand(false_dest); +} + +Value* CondBrInst::GetCond() const { return GetOperand(0); } + +BasicBlock* CondBrInst::GetTrueDest() const { + return static_cast(GetOperand(1)); +} + +BasicBlock* CondBrInst::GetFalseDest() const { + return static_cast(GetOperand(2)); +} + +ReturnInst::ReturnInst(std::shared_ptr void_ty) + : Instruction(Opcode::Ret, std::move(void_ty), "") { + if (!type_ || !type_->IsVoid()) { + throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void")); + } +} + ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) : 
Instruction(Opcode::Ret, std::move(void_ty), "") { if (!val) { @@ -96,26 +233,36 @@ ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) AddOperand(val); } -Value* ReturnInst::GetValue() const { return GetOperand(0); } +bool ReturnInst::HasReturnValue() const { return GetNumOperands() > 0; } + +Value* ReturnInst::GetValue() const { + if (!HasReturnValue()) return nullptr; + return GetOperand(0); +} -AllocaInst::AllocaInst(std::shared_ptr ptr_ty, std::string name) - : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { - if (!type_ || !type_->IsPtrInt32()) { - throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*")); +AllocaInst::AllocaInst(std::shared_ptr allocated_ty, std::string name) + : Instruction(Opcode::Alloca, Type::GetPointerType(allocated_ty), + std::move(name)), + allocated_type_(std::move(allocated_ty)) { + if (!allocated_type_) { + throw std::runtime_error(FormatError("ir", "AllocaInst 缺少类型")); } } +const std::shared_ptr& AllocaInst::GetAllocatedType() const { + return allocated_type_; +} + LoadInst::LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name) : Instruction(Opcode::Load, std::move(val_ty), std::move(name)) { if (!ptr) { throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr")); } - if (!type_ || !type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32")); + if (!ptr->GetType() || !ptr->GetType()->IsPointer()) { + throw std::runtime_error(FormatError("ir", "LoadInst ptr 不是指针")); } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { - throw std::runtime_error( - FormatError("ir", "LoadInst 当前只支持从 i32* 加载")); + if (!type_ || !ptr->GetType()->GetElementType()->Equals(*type_)) { + throw std::runtime_error(FormatError("ir", "LoadInst 类型不匹配")); } AddOperand(ptr); } @@ -133,12 +280,11 @@ StoreInst::StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr) if (!type_ || !type_->IsVoid()) { throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void")); } - if 
(!val->GetType() || !val->GetType()->IsInt32()) { - throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32")); + if (!ptr->GetType() || !ptr->GetType()->IsPointer()) { + throw std::runtime_error(FormatError("ir", "StoreInst ptr 不是指针")); } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { - throw std::runtime_error( - FormatError("ir", "StoreInst 当前只支持写入 i32*")); + if (!ptr->GetType()->GetElementType()->Equals(*val->GetType())) { + throw std::runtime_error(FormatError("ir", "StoreInst 类型不匹配")); } AddOperand(val); AddOperand(ptr); @@ -148,4 +294,70 @@ Value* StoreInst::GetValue() const { return GetOperand(0); } Value* StoreInst::GetPtr() const { return GetOperand(1); } +CallInst::CallInst(std::shared_ptr ret_ty, Value* callee, + std::vector args, std::string name) + : Instruction(Opcode::Call, std::move(ret_ty), std::move(name)), + args_(std::move(args)) { + if (!callee) { + throw std::runtime_error(FormatError("ir", "CallInst 缺少 callee")); + } + AddOperand(callee); + for (auto* arg : args_) { + if (!arg) { + throw std::runtime_error(FormatError("ir", "CallInst arg 为空")); + } + AddOperand(arg); + } +} + +Value* CallInst::GetCallee() const { return GetOperand(0); } + +PhiInst::PhiInst(std::shared_ptr ty, std::string name) + : Instruction(Opcode::Phi, std::move(ty), std::move(name)) {} + +void PhiInst::AddIncoming(Value* value, BasicBlock* block) { + if (!value || !block) { + throw std::runtime_error(FormatError("ir", "PhiInst incoming 为空")); + } + if (!value->GetType() || !type_ || !value->GetType()->Equals(*type_)) { + throw std::runtime_error(FormatError("ir", "PhiInst 类型不匹配")); + } + incoming_values_.push_back(value); + incoming_blocks_.push_back(block); + AddOperand(value); + AddOperand(block); +} + +const std::vector& PhiInst::GetIncomingValues() const { + return incoming_values_; +} + +const std::vector& PhiInst::GetIncomingBlocks() const { + return incoming_blocks_; +} + +GepInst::GepInst(std::shared_ptr result_ptr_ty, Value* base_ptr, + 
std::vector indices, std::string name) + : Instruction(Opcode::Gep, std::move(result_ptr_ty), std::move(name)), + indices_(std::move(indices)) { + if (!base_ptr) { + throw std::runtime_error(FormatError("ir", "GepInst 缺少 base_ptr")); + } + if (!base_ptr->GetType() || !base_ptr->GetType()->IsPointer()) { + throw std::runtime_error(FormatError("ir", "GepInst base_ptr 不是指针")); + } + if (!type_ || !type_->IsPointer()) { + throw std::runtime_error(FormatError("ir", "GepInst 结果必须是指针")); + } + AddOperand(base_ptr); + for (auto* idx : indices_) { + if (!idx || !idx->GetType() || !idx->GetType()->IsInt32()) { + throw std::runtime_error(FormatError("ir", "GepInst index 必须是 i32")); + } + AddOperand(idx); + } +} + +Value* GepInst::GetBasePtr() const { return GetOperand(0); } + } // namespace ir diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index 928efdc..37bbb6f 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -10,12 +10,39 @@ const Context& Module::GetContext() const { return context_; } Function* Module::CreateFunction(const std::string& name, std::shared_ptr ret_type) { - functions_.push_back(std::make_unique(name, std::move(ret_type))); + auto func_ty = Type::GetFunctionType(std::move(ret_type), {}); + functions_.push_back(std::make_unique(name, std::move(func_ty))); return functions_.back().get(); } +Function* Module::CreateFunctionWithType(const std::string& name, + std::shared_ptr func_type) { + functions_.push_back( + std::make_unique(name, std::move(func_type), false)); + return functions_.back().get(); +} + +Function* Module::CreateFunctionDecl(const std::string& name, + std::shared_ptr func_type) { + functions_.push_back( + std::make_unique(name, std::move(func_type), true)); + return functions_.back().get(); +} + +GlobalVariable* Module::CreateGlobalVariable(const std::string& name, + std::shared_ptr value_type, + ConstantValue* init, bool is_const) { + globals_.push_back(std::make_unique( + std::move(value_type), name, init, is_const)); + return 
globals_.back().get(); +} + const std::vector>& Module::GetFunctions() const { return functions_; } +const std::vector>& Module::GetGlobals() const { + return globals_; +} + } // namespace ir diff --git a/src/ir/Type.cpp b/src/ir/Type.cpp index 3e1684d..aa8eb35 100644 --- a/src/ir/Type.cpp +++ b/src/ir/Type.cpp @@ -1,31 +1,148 @@ -// 当前仅支持 void、i32 和 i32*。 +// 支持 void/i1/i32/float/ptr/array/function/label。 #include "ir/IR.h" namespace ir { Type::Type(Kind k) : kind_(k) {} +Type::Type(Kind k, std::shared_ptr elem, size_t count) + : kind_(k), elem_type_(std::move(elem)), array_size_(count) {} + +Type::Type(Kind k, std::shared_ptr ret, + std::vector> params, bool is_vararg) + : kind_(k), ret_type_(std::move(ret)), param_types_(std::move(params)), + is_vararg_(is_vararg) {} + const std::shared_ptr& Type::GetVoidType() { static const std::shared_ptr type = std::make_shared(Kind::Void); return type; } +const std::shared_ptr& Type::GetInt1Type() { + static const std::shared_ptr type = std::make_shared(Kind::Int1); + return type; +} + const std::shared_ptr& Type::GetInt32Type() { static const std::shared_ptr type = std::make_shared(Kind::Int32); return type; } -const std::shared_ptr& Type::GetPtrInt32Type() { - static const std::shared_ptr type = std::make_shared(Kind::PtrInt32); +const std::shared_ptr& Type::GetFloatType() { + static const std::shared_ptr type = std::make_shared(Kind::Float); + return type; +} + +const std::shared_ptr& Type::GetLabelType() { + static const std::shared_ptr type = std::make_shared(Kind::Label); return type; } +std::shared_ptr Type::GetPointerType(std::shared_ptr elem) { + if (!elem) { + throw std::runtime_error("PointerType 缺少 element type"); + } + return std::make_shared(Kind::Pointer, std::move(elem), 0); +} + +std::shared_ptr Type::GetArrayType(std::shared_ptr elem, + size_t count) { + if (!elem) { + throw std::runtime_error("ArrayType 缺少 element type"); + } + return std::make_shared(Kind::Array, std::move(elem), count); +} + 
+std::shared_ptr Type::GetFunctionType( + std::shared_ptr ret, std::vector> params, + bool is_vararg) { + if (!ret) { + throw std::runtime_error("FunctionType 缺少 return type"); + } + return std::make_shared(Kind::Function, std::move(ret), + std::move(params), is_vararg); +} + Type::Kind Type::GetKind() const { return kind_; } bool Type::IsVoid() const { return kind_ == Kind::Void; } +bool Type::IsInt1() const { return kind_ == Kind::Int1; } + bool Type::IsInt32() const { return kind_ == Kind::Int32; } -bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } +bool Type::IsFloat() const { return kind_ == Kind::Float; } + +bool Type::IsPointer() const { return kind_ == Kind::Pointer; } + +bool Type::IsArray() const { return kind_ == Kind::Array; } + +bool Type::IsFunction() const { return kind_ == Kind::Function; } + +bool Type::IsLabel() const { return kind_ == Kind::Label; } + +const std::shared_ptr& Type::GetElementType() const { + if (!elem_type_) { + throw std::runtime_error("Type 没有 element type"); + } + return elem_type_; +} + +size_t Type::GetArraySize() const { + if (!IsArray()) { + throw std::runtime_error("Type 不是 array"); + } + return array_size_; +} + +const std::shared_ptr& Type::GetReturnType() const { + if (!IsFunction()) { + throw std::runtime_error("Type 不是 function"); + } + return ret_type_; +} + +const std::vector>& Type::GetParamTypes() const { + if (!IsFunction()) { + throw std::runtime_error("Type 不是 function"); + } + return param_types_; +} + +bool Type::IsVarArg() const { + if (!IsFunction()) { + throw std::runtime_error("Type 不是 function"); + } + return is_vararg_; +} + +bool Type::Equals(const Type& other) const { + if (kind_ != other.kind_) return false; + switch (kind_) { + case Kind::Pointer: + return elem_type_ && other.elem_type_ && + elem_type_->Equals(*other.elem_type_); + case Kind::Array: + return array_size_ == other.array_size_ && elem_type_ && + other.elem_type_ && elem_type_->Equals(*other.elem_type_); + case 
Kind::Function: { + if (!ret_type_ || !other.ret_type_ || + !ret_type_->Equals(*other.ret_type_) || + is_vararg_ != other.is_vararg_ || + param_types_.size() != other.param_types_.size()) { + return false; + } + for (size_t i = 0; i < param_types_.size(); ++i) { + if (!param_types_[i] || !other.param_types_[i] || + !param_types_[i]->Equals(*other.param_types_[i])) { + return false; + } + } + return true; + } + default: + return true; + } +} } // namespace ir diff --git a/src/ir/Value.cpp b/src/ir/Value.cpp index 2e9f4c1..73d06f7 100644 --- a/src/ir/Value.cpp +++ b/src/ir/Value.cpp @@ -18,9 +18,21 @@ void Value::SetName(std::string n) { name_ = std::move(n); } bool Value::IsVoid() const { return type_ && type_->IsVoid(); } +bool Value::IsInt1() const { return type_ && type_->IsInt1(); } + bool Value::IsInt32() const { return type_ && type_->IsInt32(); } -bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } +bool Value::IsFloat() const { return type_ && type_->IsFloat(); } + +bool Value::IsPointer() const { return type_ && type_->IsPointer(); } + +bool Value::IsArray() const { return type_ && type_->IsArray(); } + +bool Value::IsFunctionType() const { return type_ && type_->IsFunction(); } + +bool Value::IsPtrInt32() const { + return type_ && type_->IsPointer() && type_->GetElementType()->IsInt32(); +} bool Value::IsConstant() const { return dynamic_cast(this) != nullptr; @@ -78,6 +90,25 @@ ConstantValue::ConstantValue(std::shared_ptr ty, std::string name) : Value(std::move(ty), std::move(name)) {} ConstantInt::ConstantInt(std::shared_ptr ty, int v) - : ConstantValue(std::move(ty), ""), value_(v) {} + : ConstantValue(std::move(ty), ""), value_(v) { + if (!type_ || (!type_->IsInt32() && !type_->IsInt1())) { + throw std::runtime_error("ConstantInt 需要 i1/i32 类型"); + } +} + +ConstantFloat::ConstantFloat(std::shared_ptr ty, float v) + : ConstantValue(std::move(ty), ""), value_(v) { + if (!type_ || !type_->IsFloat()) { + throw 
std::runtime_error("ConstantFloat 需要 float 类型"); + } +} + +ConstantArray::ConstantArray(std::shared_ptr ty, + std::vector elements) + : ConstantValue(std::move(ty), ""), elements_(std::move(elements)) { + if (!type_ || !type_->IsArray()) { + throw std::runtime_error("ConstantArray 需要 array 类型"); + } +} } // namespace ir diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 0eb62ae..afa3156 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -1,46 +1,32 @@ #include "irgen/IRGen.h" +#include #include #include "SysYParser.h" #include "ir/IR.h" #include "utils/Log.h" -namespace { - -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("irgen", "非法左值")); - } - return lvalue.ID()->getText(); -} - -} // namespace - -std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少语句块")); - } +std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { + if (!ctx) return BlockFlow::Continue; + BlockFlow flow = BlockFlow::Continue; for (auto* item : ctx->blockItem()) { if (item) { if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { - // 当前语法要求 return 为块内最后一条语句;命中后可停止生成。 + flow = BlockFlow::Terminated; break; } } } - return {}; + return flow; } -IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( - SysYParser::BlockItemContext& item) { +IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(SysYParser::BlockItemContext& item) { return std::any_cast(item.accept(this)); } std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少块内项")); - } + if (!ctx) return BlockFlow::Continue; if (ctx->decl()) { ctx->decl()->accept(this); return BlockFlow::Continue; @@ -48,60 +34,169 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { if (ctx->stmt()) { return ctx->stmt()->accept(this); } - throw 
std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); + return BlockFlow::Continue; } -// 变量声明的 IR 生成目前也是最小实现: -// - 先检查声明的基础类型,当前仅支持局部 int; -// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。 -// -// 和更完整的版本相比,这里还没有: -// - 一个 Decl 中多个变量定义的顺序处理; -// - const、数组、全局变量等不同声明形态; -// - 更丰富的类型系统。 std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少变量声明")); - } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); + if (!ctx) return {}; + if (auto* constDecl = ctx->constDecl()) { + for (auto* def : constDecl->constDef()) { + def->accept(this); + } + return {}; } - auto* var_def = ctx->varDef(); - if (!var_def) { - throw std::runtime_error(FormatError("irgen", "非法变量声明")); + if (auto* varDecl = ctx->varDecl()) { + for (auto* varDef : varDecl->varDef()) { + varDef->accept(this); + } + return {}; } - var_def->accept(this); return {}; } - -// 当前仍是教学用的最小版本,因此这里只支持: -// - 局部 int 变量; -// - 标量初始化; -// - 一个 VarDef 对应一个槽位。 std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少变量定义")); - } - if (!ctx->lValue()) { + if (!ctx) return {}; + if (!ctx->ID()) { throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); } - GetLValueName(*ctx->lValue()); - if (storage_map_.find(ctx) != storage_map_.end()) { - throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); + if (!func_) { + const TypeDesc* ty = sema_.GetVarType(ctx); + if (!ty) { + throw std::runtime_error(FormatError("irgen", "全局变量类型缺失")); + } + if (global_var_storage_.find(ctx) != global_var_storage_.end()) { + throw std::runtime_error(FormatError("irgen", "重复生成全局变量")); + } + ir::ConstantValue* init = nullptr; + if (ty->dims.empty()) { + if (auto* initVal = ctx->initVal()) { + if (!initVal->exp()) { + throw std::runtime_error(FormatError("irgen", "全局变量初始化非法")); + } + init = EvalConstScalar(initVal->exp()); + if (ty->base == 
BaseTypeKind::Int && + dynamic_cast(init)) { + auto* cf = static_cast(init); + init = module_.GetContext().GetConstInt(static_cast(cf->GetValue())); + } else if (ty->base == BaseTypeKind::Float && + dynamic_cast(init)) { + auto* ci = static_cast(init); + init = module_.GetContext().GetConstFloat(static_cast(ci->GetValue())); + } + } + } else if (auto* initVal = ctx->initVal()) { + size_t total = ArrayTotalSize(*ty); + std::vector values( + total, + ty->base == BaseTypeKind::Float + ? static_cast( + module_.GetContext().GetConstFloat(0.0f)) + : static_cast( + module_.GetContext().GetConstInt(0))); + InitGlobalArray(*ty, initVal, values, 0, 0, 0); + init = module_.GetContext().CreateConstArray(ToIRType(*ty), values); + } + auto* gv = module_.CreateGlobalVariable(ctx->ID()->getText(), + ToIRType(*ty), init, false); + global_var_storage_[ctx] = gv; + return {}; + } + if (var_storage_.find(ctx) != var_storage_.end()) { + throw std::runtime_error(FormatError("irgen", "重复生成存储槽位")); } - auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); - storage_map_[ctx] = slot; + const TypeDesc* ty = sema_.GetVarType(ctx); + if (!ty) { + throw std::runtime_error(FormatError("irgen", "变量类型缺失")); + } + auto* slot = builder_.CreateAlloca(ToIRType(*ty), module_.GetContext().NextTemp()); + var_storage_[ctx] = slot; - ir::Value* init = nullptr; - if (auto* init_value = ctx->initValue()) { - if (!init_value->exp()) { - throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); + if (ty->dims.empty()) { + ir::Value* init = nullptr; + if (auto* initVal = ctx->initVal()) { + if (!initVal->exp()) { + throw std::runtime_error(FormatError("irgen", "标量初始化非法")); + } + init = EvalExp(initVal->exp()); + } else { + init = DefaultValue(*ty); } - init = EvalExpr(*init_value->exp()); + builder_.CreateStore(init, slot); } else { - init = builder_.CreateConstInt(0); + InitArray(slot, *ty, ctx->initVal()); } - builder_.CreateStore(init, slot); return {}; } + +std::any 
IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { + if (!ctx || !ctx->ID()) { + throw std::runtime_error(FormatError("irgen", "常量声明缺少名称")); + } + if (!func_) { + const TypeDesc* ty = sema_.GetConstType(ctx); + if (!ty) { + throw std::runtime_error(FormatError("irgen", "全局常量类型缺失")); + } + if (global_const_storage_.find(ctx) != global_const_storage_.end()) { + throw std::runtime_error(FormatError("irgen", "重复生成全局常量")); + } + ir::ConstantValue* init = nullptr; + if (ty->dims.empty()) { + if (auto* initVal = ctx->constInitVal()) { + if (!initVal->constExp()) { + throw std::runtime_error(FormatError("irgen", "全局常量初始化非法")); + } + init = EvalConstScalar(initVal->constExp()); + if (ty->base == BaseTypeKind::Int && + dynamic_cast(init)) { + auto* cf = static_cast(init); + init = module_.GetContext().GetConstInt(static_cast(cf->GetValue())); + } else if (ty->base == BaseTypeKind::Float && + dynamic_cast(init)) { + auto* ci = static_cast(init); + init = module_.GetContext().GetConstFloat(static_cast(ci->GetValue())); + } + } + } else if (auto* initVal = ctx->constInitVal()) { + size_t total = ArrayTotalSize(*ty); + std::vector values( + total, + ty->base == BaseTypeKind::Float + ? 
static_cast( + module_.GetContext().GetConstFloat(0.0f)) + : static_cast( + module_.GetContext().GetConstInt(0))); + InitGlobalConstArray(*ty, initVal, values, 0, 0, 0); + init = module_.GetContext().CreateConstArray(ToIRType(*ty), values); + } + auto* gv = module_.CreateGlobalVariable(ctx->ID()->getText(), + ToIRType(*ty), init, true); + global_const_storage_[ctx] = gv; + return {}; + } + if (const_storage_.find(ctx) != const_storage_.end()) { + throw std::runtime_error(FormatError("irgen", "重复生成常量存储")); + } + const TypeDesc* ty = sema_.GetConstType(ctx); + if (!ty) { + throw std::runtime_error(FormatError("irgen", "常量类型缺失")); + } + auto* slot = builder_.CreateAlloca(ToIRType(*ty), module_.GetContext().NextTemp()); + const_storage_[ctx] = slot; + + if (ty->dims.empty()) { + ir::Value* init = nullptr; + if (auto* initVal = ctx->constInitVal()) { + if (!initVal->constExp()) { + throw std::runtime_error(FormatError("irgen", "常量初始化非法")); + } + init = std::any_cast(initVal->constExp()->accept(this)); + } else { + init = DefaultValue(*ty); + } + builder_.CreateStore(init, slot); + } else { + InitConstArray(slot, *ty, ctx->constInitVal()); + } + return {}; +} \ No newline at end of file diff --git a/src/irgen/IRGenDriver.cpp b/src/irgen/IRGenDriver.cpp index ff94412..dd12394 100644 --- a/src/irgen/IRGenDriver.cpp +++ b/src/irgen/IRGenDriver.cpp @@ -4,12 +4,11 @@ #include "SysYParser.h" #include "ir/IR.h" -#include "utils/Log.h" std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, const SemanticContext& sema) { - auto module = std::make_unique(); - IRGenImpl gen(*module, sema); - tree.accept(&gen); + auto module = std::make_unique(); // 无参构造 + IRGenImpl visitor(*module, sema); + tree.accept(&visitor); return module; -} +} \ No newline at end of file diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index cf4797c..37d338f 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -1,80 +1,1041 @@ #include "irgen/IRGen.h" +#include #include 
#include "SysYParser.h" #include "ir/IR.h" #include "utils/Log.h" -// 表达式生成当前也只实现了很小的一个子集。 -// 目前支持: -// - 整数字面量 -// - 普通局部变量读取 -// - 括号表达式 -// - 二元加法 -// -// 还未支持: -// - 减乘除与一元运算 -// - 赋值表达式 -// - 函数调用 -// - 数组、指针、下标访问 -// - 条件与比较表达式 -// - ... -ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { - return std::any_cast(expr.accept(this)); +namespace { +ir::Value* AnyToValue(const std::any& val) { + if (val.type() == typeid(ir::Value*)) { + return std::any_cast(val); + } + if (val.type() == typeid(ir::ConstantInt*)) { + return static_cast(std::any_cast(val)); + } + if (val.type() == typeid(ir::ConstantFloat*)) { + return static_cast(std::any_cast(val)); + } + if (val.type() == typeid(ir::Instruction*)) { + return std::any_cast(val); + } + std::cerr << "Unknown type in AnyToValue: " << val.type().name() << std::endl; + throw std::bad_any_cast(); +} + +ir::Function* FindFunctionByName(ir::Module& module, const std::string& name) { + for (const auto& fn : module.GetFunctions()) { + if (fn && fn->GetName() == name) return fn.get(); + } + return nullptr; +} +} // namespace + +ir::Value* IRGenImpl::EvalExp(SysYParser::ExpContext* ctx) { + if (!ctx) return nullptr; + return AnyToValue(ctx->accept(this)); +} + +std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) { + if (!ctx) return static_cast(nullptr); + if (ctx->addExp()) return ctx->addExp()->accept(this); + throw std::runtime_error(FormatError("irgen", "不支持的表达式")); +} + +std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { + auto muls = ctx->mulExp(); + if (muls.empty()) return static_cast(nullptr); + ir::Value* lhs = AnyToValue(muls[0]->accept(this)); + for (size_t i = 1; i < muls.size(); ++i) { + ir::Value* rhs = AnyToValue(muls[i]->accept(this)); + auto* node = ctx->children.at(2 * i - 1); + std::string text = node ? 
node->getText() : "+"; + bool use_float = lhs->IsFloat() || rhs->IsFloat(); + if (use_float) { + lhs = CastToFloat(lhs); + rhs = CastToFloat(rhs); + if (text == "+") { + lhs = builder_.CreateFAdd(lhs, rhs, module_.GetContext().NextTemp()); + } else { + lhs = builder_.CreateFSub(lhs, rhs, module_.GetContext().NextTemp()); + } + } else { + if (text == "+") { + lhs = builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp()); + } else { + lhs = builder_.CreateSub(lhs, rhs, module_.GetContext().NextTemp()); + } + } + } + return static_cast(lhs); +} + +std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { + auto unaries = ctx->unaryExp(); + if (unaries.empty()) return static_cast(nullptr); + ir::Value* lhs = AnyToValue(unaries[0]->accept(this)); + for (size_t i = 1; i < unaries.size(); ++i) { + ir::Value* rhs = AnyToValue(unaries[i]->accept(this)); + auto* node = ctx->children.at(2 * i - 1); + std::string text = node ? node->getText() : "*"; + bool use_float = lhs->IsFloat() || rhs->IsFloat(); + if (use_float) { + lhs = CastToFloat(lhs); + rhs = CastToFloat(rhs); + if (text == "*") { + lhs = builder_.CreateFMul(lhs, rhs, module_.GetContext().NextTemp()); + } else if (text == "/") { + lhs = builder_.CreateFDiv(lhs, rhs, module_.GetContext().NextTemp()); + } else { + throw std::runtime_error(FormatError("irgen", "float 不支持 %")); + } + } else { + if (text == "*") { + lhs = builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp()); + } else if (text == "/") { + lhs = builder_.CreateSDiv(lhs, rhs, module_.GetContext().NextTemp()); + } else { + lhs = builder_.CreateSRem(lhs, rhs, module_.GetContext().NextTemp()); + } + } + } + return static_cast(lhs); +} + +std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { + if (ctx->primaryExp()) return ctx->primaryExp()->accept(this); + if (ctx->ID() && ctx->LPAREN()) { + std::string name = ctx->ID()->getText(); + ir::Function* callee = FindFunctionByName(module_, name); + if (!callee) { + throw 
std::runtime_error(FormatError("irgen", "未定义函数: " + name)); + } + std::vector args; + if (ctx->funcRParams()) { + for (auto* exp : ctx->funcRParams()->exp()) { + args.push_back(EvalExp(exp)); + } + } + const auto& param_tys = callee->GetFunctionType()->GetParamTypes(); + for (size_t i = 0; i < args.size() && i < param_tys.size(); ++i) { + auto* arg = args[i]; + const auto& pty = param_tys[i]; + if (pty->IsPointer() && arg && arg->GetType() && arg->GetType()->IsPointer()) { + bool param_elem_array = pty->GetElementType()->IsArray(); + bool arg_elem_array = arg->GetType()->GetElementType()->IsArray(); + if (!param_elem_array && arg_elem_array) { + std::vector idx = {builder_.CreateConstInt(0), + builder_.CreateConstInt(0)}; + args[i] = builder_.CreateGep(arg, std::move(idx), + module_.GetContext().NextTemp()); + arg = args[i]; + } else if (param_elem_array && arg_elem_array) { + auto* param_elem = pty->GetElementType().get(); + auto* arg_elem = arg->GetType()->GetElementType().get(); + if (param_elem && arg_elem && arg_elem->IsArray() && + arg_elem->GetElementType()->Equals(*param_elem)) { + std::vector idx = {builder_.CreateConstInt(0)}; + args[i] = builder_.CreateGep(arg, std::move(idx), + module_.GetContext().NextTemp()); + arg = args[i]; + } + } else if (param_elem_array && !arg_elem_array) { + if (auto* gep = dynamic_cast(arg)) { + const auto& idx = gep->GetIndices(); + auto is_zero = [](ir::Value* v) { + auto* ci = dynamic_cast(v); + return ci && ci->GetValue() == 0; + }; + if (idx.size() == 2 && is_zero(idx[0]) && is_zero(idx[1])) { + auto* base = gep->GetBasePtr(); + if (base && base->GetType() && base->GetType()->IsPointer() && + base->GetType()->GetElementType()->IsArray()) { + args[i] = base; + arg = base; + } + } + } else if (auto* gep = dynamic_cast(arg)) { + auto* base = gep->GetBasePtr(); + if (base && base->GetType() && base->GetType()->IsPointer() && + base->GetType()->GetElementType()->IsArray()) { + auto* base_elem = 
base->GetType()->GetElementType().get(); + auto* param_elem = pty->GetElementType().get(); + if (base_elem && base_elem->IsArray() && param_elem && + base_elem->GetElementType()->Equals(*param_elem)) { + std::vector idx2 = {builder_.CreateConstInt(0)}; + args[i] = builder_.CreateGep(base, std::move(idx2), + module_.GetContext().NextTemp()); + arg = args[i]; + } + } + } + } + } + if (pty->IsFloat() && (arg->IsInt1() || arg->IsInt32())) { + args[i] = CastToFloat(arg); + } else if (pty->IsInt32() && (arg->IsInt1() || arg->IsFloat())) { + args[i] = CastToInt(arg); + } + } + std::string tmp = callee->GetReturnType()->IsVoid() + ? std::string("") + : module_.GetContext().NextTemp(); + auto* call = builder_.CreateCall(callee, std::move(args), tmp); + return static_cast(call); + } + if (ctx->unaryOp() && ctx->unaryExp()) { + std::string op = ctx->unaryOp()->getText(); + ir::Value* val = AnyToValue(ctx->unaryExp()->accept(this)); + if (op == "+") return static_cast(val); + if (op == "-") { + if (val->IsFloat()) { + auto* zero = builder_.CreateConstFloat(0.0f); + return static_cast( + builder_.CreateFSub(zero, val, module_.GetContext().NextTemp())); + } + auto* zero = builder_.CreateConstInt(0); + return static_cast( + builder_.CreateSub(zero, val, module_.GetContext().NextTemp())); + } + if (op == "!") { + ir::Value* b = MakeBool(val); + auto* zero = builder_.CreateConstBool(false); + return static_cast( + builder_.CreateICmp(ir::ICmpPredicate::Eq, b, zero, + module_.GetContext().NextTemp())); + } + } + return static_cast(nullptr); +} + +std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { + ir::Value* val = EmitRelEq(ctx); + return static_cast(val); +} + +std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { + ir::Value* val = EmitEq(ctx); + return static_cast(val); +} + +std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { + return ctx->eqExp(0)->accept(this); +} + +std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { + 
return ctx->lAndExp(0)->accept(this); +} + +std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { + if (ctx->lOrExp()) return ctx->lOrExp()->accept(this); + return static_cast(nullptr); +} + +std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) { + return ctx ? static_cast(nullptr) : static_cast(nullptr); +} + +std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { + if (!ctx) return static_cast(nullptr); + if (ctx->LPAREN() && ctx->exp()) return EvalExp(ctx->exp()); + if (ctx->lVal()) return ctx->lVal()->accept(this); + if (ctx->number()) return ctx->number()->accept(this); + throw std::runtime_error(FormatError("irgen", "不支持的 PrimaryExp")); +} + +std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法常量")); + } + if (ctx->INT_CONST()) { + const std::string text = ctx->getText(); + size_t idx = 0; + long long val = std::stoll(text, &idx, 0); + if (idx != text.size()) { + throw std::runtime_error(FormatError("irgen", "非法整数常量")); + } + return static_cast(builder_.CreateConstInt(val)); + } + if (ctx->FLOAT_CONST()) { + float val = std::stof(ctx->getText()); + return static_cast(builder_.CreateConstFloat(val)); + } + throw std::runtime_error(FormatError("irgen", "不支持的常量类型")); +} + +std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { + if (!ctx || !ctx->ID()) { + throw std::runtime_error(FormatError("irgen", "非法左值")); + } + ir::Value* addr = GetLValAddress(ctx); + BoundDecl bound = sema_.ResolveVarUse(ctx); + const TypeDesc* ty = nullptr; + if (bound.kind == BoundDecl::Kind::Var && bound.var_decl) { + ty = sema_.GetVarType(bound.var_decl); + } else if (bound.kind == BoundDecl::Kind::Const && bound.const_decl) { + ty = sema_.GetConstType(bound.const_decl); + } else if (bound.kind == BoundDecl::Kind::Param && bound.param_decl) { + ty = sema_.GetParamType(bound.param_decl); + } + if (!ty && ctx->ID()) { + const auto name = 
ctx->ID()->getText(); + for (const auto& kv : var_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetVarType(kv.first); + break; + } + } + if (!ty) { + for (const auto& kv : const_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetConstType(kv.first); + break; + } + } + } + if (!ty) { + for (const auto& kv : param_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetParamType(kv.first); + break; + } + } + } + if (!ty) { + for (const auto& kv : global_var_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetVarType(kv.first); + break; + } + } + } + if (!ty) { + for (const auto& kv : global_const_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetConstType(kv.first); + break; + } + } + } + } + if (!ty) { + throw std::runtime_error(FormatError("irgen", "无法解析左值类型")); + } + bool as_rvalue = true; + if (!ty->dims.empty() && ctx->LBRACK().empty()) { + as_rvalue = false; + } + return static_cast(LoadIfNeeded(addr, *ty, as_rvalue)); +} + +std::any IRGenImpl::visitConstExp(SysYParser::ConstExpContext* ctx) { + if (!ctx) return static_cast(nullptr); + return ctx->addExp()->accept(this); +} + +std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) { + return ctx ? static_cast(nullptr) : static_cast(nullptr); +} + +std::any IRGenImpl::visitInitVal(SysYParser::InitValContext* ctx) { + return ctx ? 
static_cast(nullptr) : static_cast(nullptr); +} + +ir::Value* IRGenImpl::CastToFloat(ir::Value* v) { + if (v->IsFloat()) return v; + if (v->IsInt1() || v->IsInt32()) { + return builder_.CreateSIToFP(v, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } + throw std::runtime_error(FormatError("irgen", "无法转换为 float")); +} + +ir::Value* IRGenImpl::CastToInt(ir::Value* v) { + if (v->IsInt32()) return v; + if (v->IsInt1()) { + return builder_.CreateZExt(v, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + if (v->IsFloat()) { + return builder_.CreateFPToSI(v, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + throw std::runtime_error(FormatError("irgen", "无法转换为 int")); +} + +ir::Value* IRGenImpl::MakeBool(ir::Value* v) { + if (v->IsInt1()) return v; + if (v->IsFloat()) { + auto* zero = builder_.CreateConstFloat(0.0f); + return builder_.CreateFCmp(ir::FCmpPredicate::One, v, zero, + module_.GetContext().NextTemp()); + } + auto* zero = builder_.CreateConstInt(0); + return builder_.CreateICmp(ir::ICmpPredicate::Ne, v, zero, + module_.GetContext().NextTemp()); +} + +ir::Value* IRGenImpl::EmitRelEq(SysYParser::RelExpContext* ctx) { + auto exps = ctx->addExp(); + if (exps.empty()) return nullptr; + ir::Value* lhs = AnyToValue(exps[0]->accept(this)); + if (exps.size() == 1) return lhs; + for (size_t i = 1; i < exps.size(); ++i) { + ir::Value* rhs = AnyToValue(exps[i]->accept(this)); + auto* node = ctx->children.at(2 * i - 1); + std::string text = node ? 
node->getText() : "<"; + bool use_float = lhs->IsFloat() || rhs->IsFloat(); + if (use_float) { + lhs = CastToFloat(lhs); + rhs = CastToFloat(rhs); + ir::FCmpPredicate pred = ir::FCmpPredicate::Olt; + if (text == "<") pred = ir::FCmpPredicate::Olt; + else if (text == "<=") pred = ir::FCmpPredicate::Ole; + else if (text == ">") pred = ir::FCmpPredicate::Ogt; + else pred = ir::FCmpPredicate::Oge; + lhs = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp()); + } else { + ir::ICmpPredicate pred = ir::ICmpPredicate::Slt; + if (text == "<") pred = ir::ICmpPredicate::Slt; + else if (text == "<=") pred = ir::ICmpPredicate::Sle; + else if (text == ">") pred = ir::ICmpPredicate::Sgt; + else pred = ir::ICmpPredicate::Sge; + lhs = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp()); + } + } + return lhs; +} + +ir::Value* IRGenImpl::EmitEq(SysYParser::EqExpContext* ctx) { + auto rels = ctx->relExp(); + if (rels.empty()) return nullptr; + ir::Value* lhs = EmitRelEq(rels[0]); + if (rels.size() == 1) return lhs; + for (size_t i = 1; i < rels.size(); ++i) { + ir::Value* rhs = EmitRelEq(rels[i]); + auto* node = ctx->children.at(2 * i - 1); + std::string text = node ? node->getText() : "=="; + if (lhs->IsFloat() || rhs->IsFloat()) { + lhs = CastToFloat(lhs); + rhs = CastToFloat(rhs); + ir::FCmpPredicate pred = text == "==" ? ir::FCmpPredicate::Oeq + : ir::FCmpPredicate::One; + lhs = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp()); + } else { + ir::ICmpPredicate pred = text == "==" ? 
ir::ICmpPredicate::Eq + : ir::ICmpPredicate::Ne; + lhs = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp()); + } + } + return lhs; +} + +ir::Value* IRGenImpl::EvalCondValue(SysYParser::CondContext* ctx) { + if (!ctx) return nullptr; + auto* tmp_true = func_->CreateBlock("cond.true"); + auto* tmp_false = func_->CreateBlock("cond.false"); + auto* merge = func_->CreateBlock("cond.merge"); + EmitCondBr(ctx, tmp_true, tmp_false); + + builder_.SetInsertPoint(tmp_true); + builder_.CreateBr(merge); + builder_.SetInsertPoint(tmp_false); + builder_.CreateBr(merge); + builder_.SetInsertPoint(merge); + auto* phi = builder_.CreatePhi(ir::Type::GetInt1Type(), + module_.GetContext().NextTemp()); + phi->AddIncoming(builder_.CreateConstBool(true), tmp_true); + phi->AddIncoming(builder_.CreateConstBool(false), tmp_false); + return phi; +} + +void IRGenImpl::EmitCondBr(SysYParser::CondContext* ctx, ir::BasicBlock* true_bb, + ir::BasicBlock* false_bb) { + if (!ctx || !ctx->lOrExp()) { + throw std::runtime_error(FormatError("irgen", "非法 cond")); + } + EmitLOrCond(ctx->lOrExp(), true_bb, false_bb); +} + +void IRGenImpl::EmitLOrCond(SysYParser::LOrExpContext* ctx, + ir::BasicBlock* true_bb, + ir::BasicBlock* false_bb) { + auto ands = ctx->lAndExp(); + if (ands.empty()) { + builder_.CreateBr(false_bb); + return; + } + for (size_t i = 0; i < ands.size(); ++i) { + if (i == ands.size() - 1) { + EmitLAndCond(ands[i], true_bb, false_bb); + } else { + auto* next = func_->CreateBlock("lor.next"); + EmitLAndCond(ands[i], true_bb, next); + builder_.SetInsertPoint(next); + } + } +} + +void IRGenImpl::EmitLAndCond(SysYParser::LAndExpContext* ctx, + ir::BasicBlock* true_bb, + ir::BasicBlock* false_bb) { + auto eqs = ctx->eqExp(); + if (eqs.empty()) { + builder_.CreateBr(false_bb); + return; + } + for (size_t i = 0; i < eqs.size(); ++i) { + ir::Value* cond = EmitEq(eqs[i]); + cond = MakeBool(cond); + if (i == eqs.size() - 1) { + builder_.CreateCondBr(cond, true_bb, false_bb); + } 
else { + auto* next = func_->CreateBlock("land.next"); + builder_.CreateCondBr(cond, next, false_bb); + builder_.SetInsertPoint(next); + } + } +} + +ir::Value* IRGenImpl::GetLValAddress(SysYParser::LValContext* ctx) { + BoundDecl bound = sema_.ResolveVarUse(ctx); + ir::Value* base_ptr = nullptr; + const TypeDesc* ty = nullptr; + if (bound.kind == BoundDecl::Kind::Var && bound.var_decl) { + ty = sema_.GetVarType(bound.var_decl); + base_ptr = var_storage_[bound.var_decl]; + } else if (bound.kind == BoundDecl::Kind::Const && bound.const_decl) { + ty = sema_.GetConstType(bound.const_decl); + base_ptr = const_storage_[bound.const_decl]; + } else if (bound.kind == BoundDecl::Kind::Param && bound.param_decl) { + ty = sema_.GetParamType(bound.param_decl); + base_ptr = param_storage_[bound.param_decl]; + } + if (!base_ptr && bound.kind == BoundDecl::Kind::Var && bound.var_decl) { + auto it = global_var_storage_.find(bound.var_decl); + if (it != global_var_storage_.end()) { + ty = sema_.GetVarType(bound.var_decl); + base_ptr = it->second; + } + } + if (!base_ptr && bound.kind == BoundDecl::Kind::Const && bound.const_decl) { + auto it = global_const_storage_.find(bound.const_decl); + if (it != global_const_storage_.end()) { + ty = sema_.GetConstType(bound.const_decl); + base_ptr = it->second; + } + } + if (!base_ptr && ctx && ctx->ID()) { + const auto name = ctx->ID()->getText(); + for (const auto& kv : var_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetVarType(kv.first); + base_ptr = kv.second; + break; + } + } + if (!base_ptr) { + for (const auto& kv : const_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetConstType(kv.first); + base_ptr = kv.second; + break; + } + } + } + if (!base_ptr) { + for (const auto& kv : param_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty 
= sema_.GetParamType(kv.first); + base_ptr = kv.second; + break; + } + } + } + if (!base_ptr) { + for (const auto& kv : global_var_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetVarType(kv.first); + base_ptr = kv.second; + break; + } + } + } + if (!base_ptr) { + for (const auto& kv : global_const_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetConstType(kv.first); + base_ptr = kv.second; + break; + } + } + } + } + if (!base_ptr || !ty) { + throw std::runtime_error(FormatError("irgen", "左值未绑定")); + } + + if (bound.kind == BoundDecl::Kind::Param && !ty->dims.empty()) { + base_ptr = builder_.CreateLoad(base_ptr, module_.GetContext().NextTemp()); + } + + std::vector indices; + const auto exps = ctx->exp(); + if (!ty->dims.empty() && !exps.empty()) { + bool need_leading_zero = base_ptr->GetType()->GetElementType()->IsArray(); + if (!ty->dims.empty() && ty->dims[0] < 0) { + need_leading_zero = false; + } + if (need_leading_zero) { + indices.push_back(builder_.CreateConstInt(0)); + } + } + for (auto* exp : exps) { + indices.push_back(CastToInt(EvalExp(exp))); + } + + if (!indices.empty()) { + return builder_.CreateGep(base_ptr, std::move(indices), + module_.GetContext().NextTemp()); + } + return base_ptr; +} + +ir::Value* IRGenImpl::LoadIfNeeded(ir::Value* addr_or_val, const TypeDesc& ty, + bool as_rvalue) { + if (!as_rvalue) { + return addr_or_val; + } + return builder_.CreateLoad(addr_or_val, module_.GetContext().NextTemp()); +} + +std::shared_ptr IRGenImpl::ToIRType(const TypeDesc& ty) { + std::shared_ptr base; + if (ty.base == BaseTypeKind::Int) base = ir::Type::GetInt32Type(); + else if (ty.base == BaseTypeKind::Float) base = ir::Type::GetFloatType(); + else base = ir::Type::GetVoidType(); + + for (auto it = ty.dims.rbegin(); it != ty.dims.rend(); ++it) { + if (*it < 0) continue; + base = ir::Type::GetArrayType(base, 
static_cast(*it)); + } + return base; +} + +std::shared_ptr IRGenImpl::ToIRParamType(const TypeDesc& ty) { + if (ty.dims.empty()) return ToIRType(ty); + TypeDesc elem = ty; + if (!elem.dims.empty() && elem.dims.front() < 0) { + elem.dims.erase(elem.dims.begin()); + } + return ir::Type::GetPointerType(ToIRType(elem)); +} + +ir::Value* IRGenImpl::DefaultValue(const TypeDesc& ty) { + if (ty.base == BaseTypeKind::Float) return builder_.CreateConstFloat(0.0f); + return builder_.CreateConstInt(0); +} + +size_t IRGenImpl::ArrayStride(const TypeDesc& ty, size_t dim) const { + size_t stride = 1; + for (size_t i = dim + 1; i < ty.dims.size(); ++i) { + stride *= static_cast(ty.dims[i]); + } + return stride; +} + +size_t IRGenImpl::ArrayTotalSize(const TypeDesc& ty) const { + size_t total = 1; + for (int d : ty.dims) total *= static_cast(d); + return total; +} + +static size_t AlignIndex(size_t index, size_t align) { + if (align == 0) return index; + return (index + align - 1) / align * align; +} + +size_t IRGenImpl::FillArrayValues(const TypeDesc& ty, + SysYParser::InitValContext* init, + std::vector& values, size_t base, + size_t idx, size_t dim) { + if (!init) return idx; + if (init->exp()) { + if (base + idx < values.size()) { + values[base + idx] = EvalExp(init->exp()); + } + return idx + 1; + } + + size_t sub_size = ArrayStride(ty, dim); + if (init->initVal().empty()) { + idx = AlignIndex(idx, sub_size); + return idx + sub_size; + } + + for (auto* child : init->initVal()) { + if (!child) continue; + if (child->exp()) { + idx = FillArrayValues(ty, child, values, base, idx, dim); + } else { + size_t aligned = AlignIndex(idx, sub_size); + idx = aligned; + idx = FillArrayValues(ty, child, values, base, idx, dim + 1); + idx = aligned + sub_size; + } + } + idx = AlignIndex(idx, sub_size); + return idx; +} + +size_t IRGenImpl::FillConstArrayValues( + const TypeDesc& ty, SysYParser::ConstInitValContext* init, + std::vector& values, size_t base, size_t idx, size_t dim) { + if 
(!init) return idx; + if (init->constExp()) { + if (base + idx < values.size()) { + values[base + idx] = AnyToValue(init->constExp()->accept(this)); + } + return idx + 1; + } + size_t sub_size = ArrayStride(ty, dim); + if (init->constInitVal().empty()) { + idx = AlignIndex(idx, sub_size); + return idx + sub_size; + } + for (auto* child : init->constInitVal()) { + if (!child) continue; + if (child->constExp()) { + idx = FillConstArrayValues(ty, child, values, base, idx, dim); + } else { + size_t aligned = AlignIndex(idx, sub_size); + idx = aligned; + idx = FillConstArrayValues(ty, child, values, base, idx, dim + 1); + idx = aligned + sub_size; + } + } + idx = AlignIndex(idx, sub_size); + return idx; +} + +void IRGenImpl::InitArray(ir::Value* base_ptr, const TypeDesc& ty, + SysYParser::InitValContext* init) { + size_t total = ArrayTotalSize(ty); + std::vector values(total, DefaultValue(ty)); + FillArrayValues(ty, init, values, 0, 0, 0); + + for (size_t idx = 0; idx < total; ++idx) { + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + size_t remain = idx; + for (size_t dim = 0; dim < ty.dims.size(); ++dim) { + size_t stride = ArrayStride(ty, dim); + size_t cur = remain / stride; + remain %= stride; + indices.push_back(builder_.CreateConstInt(static_cast(cur))); + } + ir::Value* addr = builder_.CreateGep(base_ptr, std::move(indices), + module_.GetContext().NextTemp()); + ir::Value* value = values[idx]; + if (ty.base == BaseTypeKind::Float && (value->IsInt1() || value->IsInt32())) { + value = CastToFloat(value); + } else if (ty.base == BaseTypeKind::Int && value->IsFloat()) { + value = CastToInt(value); + } + builder_.CreateStore(value, addr); + } +} + +void IRGenImpl::InitConstArray(ir::Value* base_ptr, const TypeDesc& ty, + SysYParser::ConstInitValContext* init) { + size_t total = ArrayTotalSize(ty); + std::vector values(total, DefaultValue(ty)); + FillConstArrayValues(ty, init, values, 0, 0, 0); + + for (size_t idx = 0; idx < total; ++idx) { + 
std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + size_t remain = idx; + for (size_t dim = 0; dim < ty.dims.size(); ++dim) { + size_t stride = ArrayStride(ty, dim); + size_t cur = remain / stride; + remain %= stride; + indices.push_back(builder_.CreateConstInt(static_cast(cur))); + } + ir::Value* addr = builder_.CreateGep(base_ptr, std::move(indices), + module_.GetContext().NextTemp()); + ir::Value* value = values[idx]; + if (ty.base == BaseTypeKind::Float && (value->IsInt1() || value->IsInt32())) { + value = CastToFloat(value); + } else if (ty.base == BaseTypeKind::Int && value->IsFloat()) { + value = CastToInt(value); + } + builder_.CreateStore(value, addr); + } +} + +void IRGenImpl::PushLoop(ir::BasicBlock* break_bb, ir::BasicBlock* cont_bb) { + loop_stack_.push_back({break_bb, cont_bb}); +} + +void IRGenImpl::PopLoop() { + if (!loop_stack_.empty()) loop_stack_.pop_back(); +} + +ir::BasicBlock* IRGenImpl::CurrentBreak() const { + if (loop_stack_.empty()) return nullptr; + return loop_stack_.back().first; +} + +ir::BasicBlock* IRGenImpl::CurrentContinue() const { + if (loop_stack_.empty()) return nullptr; + return loop_stack_.back().second; +} + +namespace { +struct ConstNumber { + bool is_float = false; + double f = 0.0; + long long i = 0; +}; + +ConstNumber ToConstNumber(ir::ConstantValue* v) { + ConstNumber num; + if (auto* ci = dynamic_cast(v)) { + num.is_float = false; + num.i = ci->GetValue(); + return num; + } + if (auto* cf = dynamic_cast(v)) { + num.is_float = true; + num.f = cf->GetValue(); + return num; + } + return num; } +ConstNumber PromoteToFloat(const ConstNumber& v) { + if (v.is_float) return v; + ConstNumber n; + n.is_float = true; + n.f = static_cast(v.i); + return n; +} +} // namespace -std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "非法括号表达式")); +ir::ConstantValue* IRGenImpl::EvalConstScalar(SysYParser::ExpContext* ctx) { + 
if (!ctx || !ctx->addExp()) { + throw std::runtime_error(FormatError("irgen", "非法常量表达式")); } - return EvalExpr(*ctx->exp()); + return EvalConstAdd(ctx->addExp()); } +ir::ConstantValue* IRGenImpl::EvalConstScalar(SysYParser::ConstExpContext* ctx) { + if (!ctx || !ctx->addExp()) { + throw std::runtime_error(FormatError("irgen", "非法常量表达式")); + } + return EvalConstAdd(ctx->addExp()); +} -std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); +ir::ConstantValue* IRGenImpl::EvalConstAdd(SysYParser::AddExpContext* ctx) { + auto muls = ctx->mulExp(); + if (muls.empty()) return module_.GetContext().GetConstInt(0); + ConstNumber lhs = ToConstNumber(EvalConstMul(muls[0])); + for (size_t i = 1; i < muls.size(); ++i) { + ConstNumber rhs = ToConstNumber(EvalConstMul(muls[i])); + auto* node = ctx->children.at(2 * i - 1); + std::string text = node ? node->getText() : "+"; + if (lhs.is_float || rhs.is_float) { + lhs = PromoteToFloat(lhs); + rhs = PromoteToFloat(rhs); + lhs.f = (text == "+") ? lhs.f + rhs.f : lhs.f - rhs.f; + lhs.is_float = true; + } else { + lhs.i = (text == "+") ? lhs.i + rhs.i : lhs.i - rhs.i; + } } - return static_cast( - builder_.CreateConstInt(std::stoi(ctx->number()->getText()))); + if (lhs.is_float) return module_.GetContext().GetConstFloat(static_cast(lhs.f)); + return module_.GetContext().GetConstInt(static_cast(lhs.i)); } -// 变量使用的处理流程: -// 1. 先通过语义分析结果把变量使用绑定回声明; -// 2. 再通过 storage_map_ 找到该声明对应的栈槽位; -// 3. 
最后生成 load,把内存中的值读出来。 -// -// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 -std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) { - if (!ctx || !ctx->var() || !ctx->var()->ID()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); +ir::ConstantValue* IRGenImpl::EvalConstMul(SysYParser::MulExpContext* ctx) { + auto unaries = ctx->unaryExp(); + if (unaries.empty()) return module_.GetContext().GetConstInt(0); + ConstNumber lhs = ToConstNumber(EvalConstUnary(unaries[0])); + for (size_t i = 1; i < unaries.size(); ++i) { + ConstNumber rhs = ToConstNumber(EvalConstUnary(unaries[i])); + auto* node = ctx->children.at(2 * i - 1); + std::string text = node ? node->getText() : "*"; + if (text == "%") { + if (lhs.is_float || rhs.is_float) { + throw std::runtime_error(FormatError("irgen", "const % 仅支持整数")); + } + lhs.i = lhs.i % rhs.i; + continue; + } + if (lhs.is_float || rhs.is_float) { + lhs = PromoteToFloat(lhs); + rhs = PromoteToFloat(rhs); + if (text == "*") lhs.f = lhs.f * rhs.f; + else lhs.f = lhs.f / rhs.f; + lhs.is_float = true; + } else { + if (text == "*") lhs.i = lhs.i * rhs.i; + else lhs.i = lhs.i / rhs.i; + } } - auto* decl = sema_.ResolveVarUse(ctx->var()); - if (!decl) { - throw std::runtime_error( - FormatError("irgen", - "变量使用缺少语义绑定: " + ctx->var()->ID()->getText())); + if (lhs.is_float) return module_.GetContext().GetConstFloat(static_cast(lhs.f)); + return module_.GetContext().GetConstInt(static_cast(lhs.i)); +} + +ir::ConstantValue* IRGenImpl::EvalConstUnary(SysYParser::UnaryExpContext* ctx) { + if (ctx->primaryExp()) return EvalConstPrimary(ctx->primaryExp()); + if (ctx->unaryOp() && ctx->unaryExp()) { + ConstNumber val = ToConstNumber(EvalConstUnary(ctx->unaryExp())); + std::string op = ctx->unaryOp()->getText(); + if (op == "+") { + if (val.is_float) { + return module_.GetContext().GetConstFloat(static_cast(val.f)); + } + return module_.GetContext().GetConstInt(static_cast(val.i)); + } + if (op == "-") { + if (val.is_float) { + return 
module_.GetContext().GetConstFloat(static_cast(-val.f)); + } + return module_.GetContext().GetConstInt(static_cast(-val.i)); + } + if (op == "!") { + bool is_zero = val.is_float ? (val.f == 0.0) : (val.i == 0); + return module_.GetContext().GetConstInt(is_zero ? 1 : 0); + } } - auto it = storage_map_.find(decl); - if (it == storage_map_.end()) { - throw std::runtime_error( - FormatError("irgen", - "变量声明缺少存储槽位: " + ctx->var()->ID()->getText())); + throw std::runtime_error(FormatError("irgen", "const 不支持函数调用")); +} + +ir::ConstantValue* IRGenImpl::EvalConstPrimary(SysYParser::PrimaryExpContext* ctx) { + if (ctx->exp()) return EvalConstScalar(ctx->exp()); + if (ctx->lVal()) return EvalConstLVal(ctx->lVal()); + if (ctx->number()) return EvalConstNumber(ctx->number()); + return module_.GetContext().GetConstInt(0); +} + +ir::ConstantValue* IRGenImpl::EvalConstNumber(SysYParser::NumberContext* ctx) { + if (ctx->INT_CONST()) { + const std::string text = ctx->getText(); + size_t idx = 0; + long long val = std::stoll(text, &idx, 0); + if (idx != text.size()) { + throw std::runtime_error(FormatError("irgen", "非法整数常量")); + } + return module_.GetContext().GetConstInt(static_cast(val)); } - return static_cast( - builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); + if (ctx->FLOAT_CONST()) { + return module_.GetContext().GetConstFloat(std::stof(ctx->getText())); + } + throw std::runtime_error(FormatError("irgen", "非法常量")); } +ir::ConstantValue* IRGenImpl::EvalConstLVal(SysYParser::LValContext* ctx) { + BoundDecl bound = sema_.ResolveVarUse(ctx); + if (bound.kind == BoundDecl::Kind::Const && bound.const_decl) { + auto it = global_const_storage_.find(bound.const_decl); + if (it != global_const_storage_.end()) { + auto* init = it->second->GetInitializer(); + if (init) return init; + } + } + throw std::runtime_error(FormatError("irgen", "constExp 使用了非常量")); +} + +size_t IRGenImpl::InitGlobalArray(const TypeDesc& ty, + SysYParser::InitValContext* init, + std::vector& 
values, + size_t base, size_t idx, size_t dim) { + if (!init) return idx; + if (init->exp()) { + if (base + idx < values.size()) { + auto* v = EvalConstScalar(init->exp()); + if (ty.base == BaseTypeKind::Int && dynamic_cast(v)) { + auto* cf = static_cast(v); + v = module_.GetContext().GetConstInt(static_cast(cf->GetValue())); + } else if (ty.base == BaseTypeKind::Float && + dynamic_cast(v)) { + auto* ci = static_cast(v); + v = module_.GetContext().GetConstFloat(static_cast(ci->GetValue())); + } + values[base + idx] = v; + } + return idx + 1; + } + size_t sub_size = ArrayStride(ty, dim); + if (init->initVal().empty()) { + idx = AlignIndex(idx, sub_size); + return idx + sub_size; + } + for (auto* child : init->initVal()) { + if (!child) continue; + if (child->exp()) { + idx = InitGlobalArray(ty, child, values, base, idx, dim); + } else { + size_t aligned = AlignIndex(idx, sub_size); + idx = aligned; + idx = InitGlobalArray(ty, child, values, base, idx, dim + 1); + idx = aligned + sub_size; + } + } + idx = AlignIndex(idx, sub_size); + return idx; +} -std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("irgen", "非法加法表达式")); +size_t IRGenImpl::InitGlobalConstArray(const TypeDesc& ty, + SysYParser::ConstInitValContext* init, + std::vector& values, + size_t base, size_t idx, size_t dim) { + if (!init) return idx; + if (init->constExp()) { + if (base + idx < values.size()) { + auto* v = EvalConstScalar(init->constExp()); + if (ty.base == BaseTypeKind::Int && dynamic_cast(v)) { + auto* cf = static_cast(v); + v = module_.GetContext().GetConstInt(static_cast(cf->GetValue())); + } else if (ty.base == BaseTypeKind::Float && + dynamic_cast(v)) { + auto* ci = static_cast(v); + v = module_.GetContext().GetConstFloat(static_cast(ci->GetValue())); + } + values[base + idx] = v; + } + return idx + 1; + } + size_t sub_size = ArrayStride(ty, dim); + if 
(init->constInitVal().empty()) { + idx = AlignIndex(idx, sub_size); + return idx + sub_size; + } + for (auto* child : init->constInitVal()) { + if (!child) continue; + if (child->constExp()) { + idx = InitGlobalConstArray(ty, child, values, base, idx, dim); + } else { + size_t aligned = AlignIndex(idx, sub_size); + idx = aligned; + idx = InitGlobalConstArray(ty, child, values, base, idx, dim + 1); + idx = aligned + sub_size; + } } - ir::Value* lhs = EvalExpr(*ctx->exp(0)); - ir::Value* rhs = EvalExpr(*ctx->exp(1)); - return static_cast( - builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, - module_.GetContext().NextTemp())); + idx = AlignIndex(idx, sub_size); + return idx; } diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 4912d03..60c0b68 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -6,82 +6,114 @@ #include "ir/IR.h" #include "utils/Log.h" -namespace { - -void VerifyFunctionStructure(const ir::Function& func) { - // 当前 IRGen 仍是单入口、顺序生成;这里在生成结束后补一层块终结校验。 - for (const auto& bb : func.GetBlocks()) { - if (!bb || !bb->HasTerminator()) { - throw std::runtime_error( - FormatError("irgen", "基本块未正确终结: " + - (bb ? 
bb->GetName() : std::string("")))); - } - } -} - -} // namespace - IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) : module_(module), sema_(sema), func_(nullptr), builder_(module.GetContext(), nullptr) {} -// 编译单元的 IR 生成当前只实现了最小功能: -// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容; -// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR; -// -// 当前还没有实现: -// - 多个函数定义的遍历与生成; -// - 全局变量、全局常量的 IR 生成。 std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少编译单元")); + if (!ctx) return {}; + + func_map_.clear(); + global_var_storage_.clear(); + global_const_storage_.clear(); + + func_ = nullptr; + for (auto* decl : ctx->decl()) { + if (decl) decl->accept(this); + } + for (auto* funcDef : ctx->funcDef()) { + if (!funcDef || !funcDef->ID()) continue; + const auto* fty = sema_.GetFuncType(funcDef); + if (!fty) { + throw std::runtime_error(FormatError("irgen", "缺少函数类型")); + } + std::vector> params; + for (const auto& p : fty->params) { + params.push_back(ToIRParamType(p)); + } + auto ret = ToIRType(fty->ret); + auto func_ty = ir::Type::GetFunctionType(ret, params); + auto* fn = module_.CreateFunctionWithType(funcDef->ID()->getText(), func_ty); + func_map_[funcDef] = fn; } - auto* func = ctx->funcDef(); - if (!func) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); + + auto declare_builtin = [&](const std::string& name, + std::shared_ptr ret, + std::vector> params) { + for (const auto& fn : module_.GetFunctions()) { + if (fn && fn->GetName() == name) return; + } + auto fty = ir::Type::GetFunctionType(ret, params); + module_.CreateFunctionDecl(name, fty); + }; + auto i32 = ir::Type::GetInt32Type(); + auto f32 = ir::Type::GetFloatType(); + declare_builtin("getint", i32, {}); + declare_builtin("getch", i32, {}); + declare_builtin("getarray", i32, {ir::Type::GetPointerType(i32)}); + declare_builtin("putint", ir::Type::GetVoidType(), {i32}); + declare_builtin("putch", 
ir::Type::GetVoidType(), {i32}); + declare_builtin("putarray", ir::Type::GetVoidType(), + {i32, ir::Type::GetPointerType(i32)}); + declare_builtin("getfloat", f32, {}); + declare_builtin("getfarray", i32, {ir::Type::GetPointerType(f32)}); + declare_builtin("putfloat", ir::Type::GetVoidType(), {f32}); + declare_builtin("putfarray", ir::Type::GetVoidType(), + {i32, ir::Type::GetPointerType(f32)}); + + for (auto* funcDef : ctx->funcDef()) { + if (funcDef) funcDef->accept(this); } - func->accept(this); return {}; } -// 函数 IR 生成当前实现了: -// 1. 获取函数名; -// 2. 检查函数返回类型; -// 3. 在 Module 中创建 Function; -// 4. 将 builder 插入点设置到入口基本块; -// 5. 继续生成函数体。 -// -// 当前还没有实现: -// - 通用函数返回类型处理; -// - 形参列表遍历与参数类型收集; -// - FunctionType 这样的函数类型对象; -// - Argument/形式参数 IR 对象; -// - 入口块中的参数初始化逻辑。 -// ... - -// 因此这里目前只支持最小的“无参 int 函数”生成。 std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); - } - if (!ctx->blockStmt()) { + if (!ctx || !ctx->block()) { throw std::runtime_error(FormatError("irgen", "函数体为空")); } - if (!ctx->ID()) { - throw std::runtime_error(FormatError("irgen", "缺少函数名")); + auto it = func_map_.find(ctx); + if (it == func_map_.end()) { + throw std::runtime_error(FormatError("irgen", "函数未注册")); } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数")); + func_ = it->second; + auto* entry = func_->GetEntry(); + builder_.SetInsertPoint(entry); + var_storage_.clear(); + const_storage_.clear(); + param_storage_.clear(); + loop_stack_.clear(); + + const auto* fty = sema_.GetFuncType(ctx); + if (!fty) { + throw std::runtime_error(FormatError("irgen", "缺少函数类型")); + } + if (ctx->funcFParams()) { + auto params = ctx->funcFParams()->funcFParam(); + for (size_t i = 0; i < params.size(); ++i) { + auto* param_ctx = params[i]; + auto* arg = func_->GetArg(i); + const TypeDesc* pty = sema_.GetParamType(param_ctx); + if (!pty) { + throw 
std::runtime_error(FormatError("irgen", "缺少参数类型")); + } + auto slot = builder_.CreateAlloca(ToIRParamType(*pty), + module_.GetContext().NextTemp()); + builder_.CreateStore(arg, slot); + param_storage_[param_ctx] = slot; + } } - func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); - builder_.SetInsertPoint(func_->GetEntry()); - storage_map_.clear(); + ctx->block()->accept(this); - ctx->blockStmt()->accept(this); - // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 - VerifyFunctionStructure(*func_); + if (!builder_.GetInsertBlock()->HasTerminator()) { + if (func_->GetReturnType()->IsVoid()) { + builder_.CreateRetVoid(); + } else { + TypeDesc ret = fty->ret; + builder_.CreateRet(DefaultValue(ret)); + } + } return {}; -} +} \ No newline at end of file diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 751550c..80c4816 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -1,39 +1,162 @@ #include "irgen/IRGen.h" +#include #include #include "SysYParser.h" #include "ir/IR.h" #include "utils/Log.h" -// 语句生成当前只实现了最小子集。 -// 目前支持: -// - return ; -// -// 还未支持: -// - 赋值语句 -// - if / while 等控制流 -// - 空语句、块语句嵌套分发之外的更多语句形态 - std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少语句")); + if (!ctx) return {}; + if (ctx->lVal() && ctx->ASSIGN()) { + ir::Value* addr = GetLValAddress(ctx->lVal()); + ir::Value* val = EvalExp(ctx->exp()); + BoundDecl bound = sema_.ResolveVarUse(ctx->lVal()); + const TypeDesc* ty = nullptr; + if (bound.kind == BoundDecl::Kind::Var && bound.var_decl) { + ty = sema_.GetVarType(bound.var_decl); + } else if (bound.kind == BoundDecl::Kind::Param && bound.param_decl) { + ty = sema_.GetParamType(bound.param_decl); + } else if (bound.kind == BoundDecl::Kind::Const) { + throw std::runtime_error(FormatError("irgen", "不能给常量赋值")); + } + if (!ty && ctx->lVal()->ID()) { + const auto name = ctx->lVal()->ID()->getText(); + for (const auto& kv : 
var_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetVarType(kv.first); + break; + } + } + if (!ty) { + for (const auto& kv : const_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + throw std::runtime_error(FormatError("irgen", "不能给常量赋值")); + } + } + } + if (!ty) { + for (const auto& kv : param_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetParamType(kv.first); + break; + } + } + } + if (!ty) { + for (const auto& kv : global_var_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + ty = sema_.GetVarType(kv.first); + break; + } + } + } + if (!ty) { + for (const auto& kv : global_const_storage_) { + auto* def = const_cast(kv.first); + if (def && def->ID() && def->ID()->getText() == name) { + throw std::runtime_error(FormatError("irgen", "不能给常量赋值")); + } + } + } + } + if (!ty) { + throw std::runtime_error(FormatError("irgen", "无法解析赋值类型")); + } + if (ty->base == BaseTypeKind::Float && val->IsInt32()) { + val = CastToFloat(val); + } else if (ty->base == BaseTypeKind::Int && val->IsFloat()) { + val = CastToInt(val); + } + builder_.CreateStore(val, addr); + return BlockFlow::Continue; + } + if (ctx->block()) { + return ctx->block()->accept(this); } - if (ctx->returnStmt()) { - return ctx->returnStmt()->accept(this); + if (ctx->IF()) { + auto* then_bb = func_->CreateBlock("if.then"); + auto* else_bb = func_->CreateBlock("if.else"); + auto* merge_bb = func_->CreateBlock("if.end"); + EmitCondBr(ctx->cond(), then_bb, else_bb); + + builder_.SetInsertPoint(then_bb); + auto then_flow = std::any_cast(ctx->stmt(0)->accept(this)); + if (then_flow != BlockFlow::Terminated) { + builder_.CreateBr(merge_bb); + } + + builder_.SetInsertPoint(else_bb); + if (ctx->stmt(1)) { + auto else_flow = 
std::any_cast(ctx->stmt(1)->accept(this)); + if (else_flow != BlockFlow::Terminated) { + builder_.CreateBr(merge_bb); + } + } else { + builder_.CreateBr(merge_bb); + } + + builder_.SetInsertPoint(merge_bb); + return BlockFlow::Continue; } - throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); -} + if (ctx->WHILE()) { + auto* cond_bb = func_->CreateBlock("while.cond"); + auto* body_bb = func_->CreateBlock("while.body"); + auto* end_bb = func_->CreateBlock("while.end"); + builder_.CreateBr(cond_bb); + builder_.SetInsertPoint(cond_bb); + EmitCondBr(ctx->cond(), body_bb, end_bb); -std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); + builder_.SetInsertPoint(body_bb); + PushLoop(end_bb, cond_bb); + auto body_flow = std::any_cast(ctx->stmt(0)->accept(this)); + PopLoop(); + if (body_flow != BlockFlow::Terminated) { + builder_.CreateBr(cond_bb); + } + + builder_.SetInsertPoint(end_bb); + return BlockFlow::Continue; + } + if (ctx->BREAK()) { + auto* target = CurrentBreak(); + if (!target) { + throw std::runtime_error(FormatError("irgen", "break 不在循环内")); + } + builder_.CreateBr(target); + return BlockFlow::Terminated; + } + if (ctx->CONTINUE()) { + auto* target = CurrentContinue(); + if (!target) { + throw std::runtime_error(FormatError("irgen", "continue 不在循环内")); + } + builder_.CreateBr(target); + return BlockFlow::Terminated; + } + if (ctx->RETURN()) { + if (!ctx->exp()) { + builder_.CreateRetVoid(); + return BlockFlow::Terminated; + } + ir::Value* v = EvalExp(ctx->exp()); + auto ret_ty = func_->GetReturnType(); + if (ret_ty->IsFloat() && v->IsInt32()) { + v = CastToFloat(v); + } else if (ret_ty->IsInt32() && v->IsFloat()) { + v = CastToInt(v); + } + builder_.CreateRet(v); + return BlockFlow::Terminated; } - if (!ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); + if (ctx->exp()) { + EvalExp(ctx->exp()); } - ir::Value* v = 
EvalExpr(*ctx->exp()); - builder_.CreateRet(v); - return BlockFlow::Terminated; -} + return BlockFlow::Continue; +} \ No newline at end of file diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 745374c..f135a5b 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "SysYBaseVisitor.h" #include "sem/SymbolTable.h" @@ -10,185 +11,494 @@ namespace { -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("sema", "非法左值")); +static BaseTypeKind BaseTypeFromBType(SysYParser::BTypeContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "缺少 bType")); } - return lvalue.ID()->getText(); + if (ctx->INT()) return BaseTypeKind::Int; + if (ctx->FLOAT()) return BaseTypeKind::Float; + throw std::runtime_error(FormatError("sema", "未知基础类型")); } +static BaseTypeKind BaseTypeFromFuncType(SysYParser::FuncTypeContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "缺少 funcType")); + } + if (ctx->VOID()) return BaseTypeKind::Void; + if (ctx->INT()) return BaseTypeKind::Int; + if (ctx->FLOAT()) return BaseTypeKind::Float; + throw std::runtime_error(FormatError("sema", "未知函数返回类型")); +} + +class ConstEvalVisitor final : public SysYBaseVisitor { + public: + explicit ConstEvalVisitor(const SymbolTable& table) : table_(table) {} + + std::any visitConstExp(SysYParser::ConstExpContext* ctx) override { + return visitAddExp(ctx->addExp()); + } + + std::any visitAddExp(SysYParser::AddExpContext* ctx) override { + auto muls = ctx->mulExp(); + if (muls.empty()) return 0; + int value = std::any_cast(muls[0]->accept(this)); + for (size_t i = 1; i < muls.size(); ++i) { + int rhs = std::any_cast(muls[i]->accept(this)); + auto* node = ctx->children.at(2 * i - 1); + auto text = node ? 
node->getText() : "+"; + if (text == "+") { + value += rhs; + } else if (text == "-") { + value -= rhs; + } else { + throw std::runtime_error(FormatError("sema", "非法加法运算符")); + } + } + return value; + } + + std::any visitMulExp(SysYParser::MulExpContext* ctx) override { + auto unaries = ctx->unaryExp(); + if (unaries.empty()) return 0; + int value = std::any_cast(unaries[0]->accept(this)); + for (size_t i = 1; i < unaries.size(); ++i) { + int rhs = std::any_cast(unaries[i]->accept(this)); + auto* node = ctx->children.at(2 * i - 1); + auto text = node ? node->getText() : "*"; + if (text == "*") { + value *= rhs; + } else if (text == "/") { + value /= rhs; + } else if (text == "%") { + value %= rhs; + } else { + throw std::runtime_error(FormatError("sema", "非法乘法运算符")); + } + } + return value; + } + + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override { + if (ctx->primaryExp()) return ctx->primaryExp()->accept(this); + if (ctx->unaryOp() && ctx->unaryExp()) { + int val = std::any_cast(ctx->unaryExp()->accept(this)); + auto op = ctx->unaryOp()->getText(); + if (op == "+") return val; + if (op == "-") return -val; + throw std::runtime_error(FormatError("sema", "constExp 不支持 !")); + } + throw std::runtime_error(FormatError("sema", "constExp 不支持函数调用")); + } + + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override { + if (ctx->exp()) return ctx->exp()->accept(this); + if (ctx->lVal()) return ctx->lVal()->accept(this); + if (ctx->number()) return ctx->number()->accept(this); + return 0; + } + + std::any visitNumber(SysYParser::NumberContext* ctx) override { + if (ctx->INT_CONST()) { + const std::string text = ctx->getText(); + size_t idx = 0; + long long val = std::stoll(text, &idx, 0); + if (idx != text.size()) { + throw std::runtime_error(FormatError("sema", "非法整数常量")); + } + return static_cast(val); + } + if (ctx->FLOAT_CONST()) { + return static_cast(std::stof(ctx->getText())); + } + throw std::runtime_error(FormatError("sema", "constExp 
仅支持整数")); + } + + std::any visitLVal(SysYParser::LValContext* ctx) override { + if (!ctx || !ctx->ID()) { + throw std::runtime_error(FormatError("sema", "constExp 非法变量")); + } + const auto* entry = table_.Lookup(ctx->ID()->getText()); + if (!entry || !entry->is_const || !entry->const_value.has_value()) { + throw std::runtime_error(FormatError("sema", "constExp 使用了非常量")); + } + return entry->const_value.value(); + } + + private: + const SymbolTable& table_; +}; + class SemaVisitor final : public SysYBaseVisitor { public: std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { if (!ctx) { throw std::runtime_error(FormatError("sema", "缺少编译单元")); } - auto* func = ctx->funcDef(); - if (!func || !func->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + for (auto* func : ctx->funcDef()) { + if (!func || !func->ID()) continue; + std::string name = func->ID()->getText(); + if (func_table_.find(name) != func_table_.end()) { + throw std::runtime_error(FormatError("sema", "重复定义函数: " + name)); + } + func_table_[name] = func; } - if (!func->ID() || func->ID()->getText() != "main") { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + + for (auto* decl : ctx->decl()) { + if (decl) decl->accept(this); } - func->accept(this); - if (!seen_return_) { - throw std::runtime_error( - FormatError("sema", "main 函数必须包含 return 语句")); + for (auto* func : ctx->funcDef()) { + if (func) func->accept(this); + } + + if (func_table_.find("main") == func_table_.end()) { + throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); } return {}; } std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { - if (!ctx || !ctx->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + if (!ctx || !ctx->block()) { + throw std::runtime_error(FormatError("sema", "函数体为空")); } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); + if (!ctx->ID()) { + 
throw std::runtime_error(FormatError("sema", "缺少函数名")); + } + FuncTypeDesc fty; + fty.ret.base = BaseTypeFromFuncType(ctx->funcType()); + if (ctx->funcFParams()) { + for (auto* param : ctx->funcFParams()->funcFParam()) { + fty.params.push_back(BuildParamType(param)); + } + } + sema_.RegisterFunc(ctx, fty); + + current_ret_ = fty.ret.base; + seen_return_ = false; + + table_.EnterScope(); + if (ctx->funcFParams()) { + for (auto* param : ctx->funcFParams()->funcFParam()) { + RegisterParam(param); + } } - const auto& items = ctx->blockStmt()->blockItem(); - if (items.empty()) { - throw std::runtime_error( - FormatError("sema", "main 函数不能为空,且必须以 return 结束")); + ctx->block()->accept(this); + table_.ExitScope(); + + if (current_ret_ != BaseTypeKind::Void && !seen_return_) { + throw std::runtime_error(FormatError("sema", "非 void 函数缺少 return")); } - ctx->blockStmt()->accept(this); return {}; } - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少语句块")); - } - const auto& items = ctx->blockItem(); - for (size_t i = 0; i < items.size(); ++i) { - auto* item = items[i]; - if (!item) { - continue; - } - if (seen_return_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } - current_item_index_ = i; - total_items_ = items.size(); - item->accept(this); + std::any visitBlock(SysYParser::BlockContext* ctx) override { + if (!ctx) return {}; + table_.EnterScope(); + for (auto* item : ctx->blockItem()) { + if (item) item->accept(this); } + table_.ExitScope(); return {}; } std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - if (ctx->decl()) { - ctx->decl()->accept(this); - return {}; + if (!ctx) return {}; + if (ctx->decl()) return ctx->decl()->accept(this); + if (ctx->stmt()) return ctx->stmt()->accept(this); + return {}; + } + + std::any 
visitDecl(SysYParser::DeclContext* ctx) override { + if (!ctx) return {}; + if (auto* c = ctx->constDecl()) return c->accept(this); + if (auto* v = ctx->varDecl()) return v->accept(this); + return {}; + } + + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override { + if (!ctx || !ctx->bType()) return {}; + BaseTypeKind base = BaseTypeFromBType(ctx->bType()); + for (auto* def : ctx->constDef()) { + RegisterConst(def, base); } - if (ctx->stmt()) { - ctx->stmt()->accept(this); - return {}; + return {}; + } + + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override { + if (!ctx || !ctx->bType()) return {}; + BaseTypeKind base = BaseTypeFromBType(ctx->bType()); + for (auto* def : ctx->varDef()) { + RegisterVar(def, base); } - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + return {}; } - std::any visitDecl(SysYParser::DeclContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); + std::any visitStmt(SysYParser::StmtContext* ctx) override { + if (!ctx) return {}; + if (ctx->lVal() && ctx->ASSIGN()) { + ctx->lVal()->accept(this); + if (ctx->exp()) ctx->exp()->accept(this); + return {}; + } + if (ctx->block()) return ctx->block()->accept(this); + if (ctx->IF()) { + if (ctx->cond()) ctx->cond()->accept(this); + if (ctx->stmt(0)) ctx->stmt(0)->accept(this); + if (ctx->stmt(1)) ctx->stmt(1)->accept(this); + return {}; } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); + if (ctx->WHILE()) { + loop_depth_++; + if (ctx->cond()) ctx->cond()->accept(this); + if (ctx->stmt(0)) ctx->stmt(0)->accept(this); + loop_depth_--; + return {}; } - auto* var_def = ctx->varDef(); - if (!var_def || !var_def->lValue()) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); + if (ctx->BREAK()) { + if (loop_depth_ == 0) { + throw std::runtime_error(FormatError("sema", "break 不在循环内")); + } + return {}; } - const std::string name = 
GetLValueName(*var_def->lValue()); - if (table_.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + if (ctx->CONTINUE()) { + if (loop_depth_ == 0) { + throw std::runtime_error(FormatError("sema", "continue 不在循环内")); + } + return {}; } - if (auto* init = var_def->initValue()) { - if (!init->exp()) { - throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化")); + if (ctx->RETURN()) { + if (ctx->exp()) ctx->exp()->accept(this); + if (current_ret_ == BaseTypeKind::Void && ctx->exp()) { + throw std::runtime_error(FormatError("sema", "void 函数不能返回值")); } - init->exp()->accept(this); + if (current_ret_ != BaseTypeKind::Void && !ctx->exp()) { + throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值")); + } + seen_return_ = true; + return {}; } - table_.Add(name, var_def); + if (ctx->exp()) ctx->exp()->accept(this); return {}; } - std::any visitStmt(SysYParser::StmtContext* ctx) override { - if (!ctx || !ctx->returnStmt()) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - ctx->returnStmt()->accept(this); + std::any visitExp(SysYParser::ExpContext* ctx) override { + if (ctx->addExp()) return ctx->addExp()->accept(this); return {}; } - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "return 缺少表达式")); - } - ctx->exp()->accept(this); - seen_return_ = true; - if (current_item_index_ + 1 != total_items_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } + std::any visitCond(SysYParser::CondContext* ctx) override { + if (ctx->lOrExp()) return ctx->lOrExp()->accept(this); return {}; } - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "非法括号表达式")); - } - ctx->exp()->accept(this); + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override { + for (auto* e : ctx->lAndExp()) 
e->accept(this); return {}; } - std::any visitVarExp(SysYParser::VarExpContext* ctx) override { - if (!ctx || !ctx->var()) { - throw std::runtime_error(FormatError("sema", "非法变量表达式")); - } - ctx->var()->accept(this); + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override { + for (auto* e : ctx->eqExp()) e->accept(this); + return {}; + } + + std::any visitEqExp(SysYParser::EqExpContext* ctx) override { + for (auto* e : ctx->relExp()) e->accept(this); return {}; } - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); + std::any visitRelExp(SysYParser::RelExpContext* ctx) override { + for (auto* e : ctx->addExp()) e->accept(this); + return {}; + } + + std::any visitAddExp(SysYParser::AddExpContext* ctx) override { + for (auto* mul : ctx->mulExp()) mul->accept(this); + return {}; + } + + std::any visitMulExp(SysYParser::MulExpContext* ctx) override { + for (auto* unary : ctx->unaryExp()) unary->accept(this); + return {}; + } + + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override { + if (ctx->primaryExp()) return ctx->primaryExp()->accept(this); + if (ctx->ID() && ctx->LPAREN()) { + std::string name = ctx->ID()->getText(); + auto it = func_table_.find(name); + if (it == func_table_.end()) { + if (builtin_funcs_.find(name) == builtin_funcs_.end()) { + throw std::runtime_error(FormatError("sema", "未定义的函数: " + name)); + } + } else { + sema_.BindFuncCall(ctx, it->second); + } + if (ctx->funcRParams()) ctx->funcRParams()->accept(this); + return {}; } + if (ctx->unaryExp()) return ctx->unaryExp()->accept(this); + return {}; + } + + std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override { + for (auto* e : ctx->exp()) e->accept(this); return {}; } - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw 
std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override { + if (ctx->exp()) return ctx->exp()->accept(this); + if (ctx->lVal()) return ctx->lVal()->accept(this); + if (ctx->number()) return ctx->number()->accept(this); + return {}; + } + + std::any visitNumber(SysYParser::NumberContext* ctx) override { + if (!ctx->INT_CONST() && !ctx->FLOAT_CONST()) { + throw std::runtime_error(FormatError("sema", "非法常量")); } - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); return {}; } - std::any visitVar(SysYParser::VarContext* ctx) override { + std::any visitLVal(SysYParser::LValContext* ctx) override { if (!ctx || !ctx->ID()) { throw std::runtime_error(FormatError("sema", "非法变量引用")); } - const std::string name = ctx->ID()->getText(); - auto* decl = table_.Lookup(name); - if (!decl) { + std::string name = ctx->ID()->getText(); + const SymbolEntry* entry = table_.Lookup(name); + if (!entry) { throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); } - sema_.BindVarUse(ctx, decl); + BoundDecl bound; + if (entry->kind == SymbolKind::Var) { + bound.kind = BoundDecl::Kind::Var; + bound.var_decl = entry->var_decl; + } else if (entry->kind == SymbolKind::Const) { + bound.kind = BoundDecl::Kind::Const; + bound.const_decl = entry->const_decl; + } else { + bound.kind = BoundDecl::Kind::Param; + bound.param_decl = entry->param_decl; + } + sema_.BindVarUse(ctx, bound); return {}; } SemanticContext TakeSemanticContext() { return std::move(sema_); } + private: + TypeDesc BuildParamType(SysYParser::FuncFParamContext* ctx) { + if (!ctx || !ctx->bType()) { + throw std::runtime_error(FormatError("sema", "非法参数")); + } + TypeDesc ty; + ty.base = BaseTypeFromBType(ctx->bType()); + if (ctx->LBRACK().size() > 0) { + ty.dims.push_back(-1); + for (auto* exp : ctx->exp()) { + ty.dims.push_back(EvalConstExp(exp)); + } + } + return ty; + } + + void RegisterParam(SysYParser::FuncFParamContext* ctx) { + if (!ctx || 
!ctx->ID()) { + throw std::runtime_error(FormatError("sema", "参数缺少名称")); + } + std::string name = ctx->ID()->getText(); + if (table_.ContainsInCurrentScope(name)) { + throw std::runtime_error(FormatError("sema", "重复定义参数: " + name)); + } + TypeDesc ty = BuildParamType(ctx); + SymbolEntry entry; + entry.kind = SymbolKind::Param; + entry.param_decl = ctx; + entry.is_const = false; + entry.type = ty; + table_.Add(name, entry); + sema_.RegisterParam(ctx, ty); + } + + void RegisterVar(SysYParser::VarDefContext* ctx, BaseTypeKind base) { + if (!ctx || !ctx->ID()) { + throw std::runtime_error(FormatError("sema", "变量声明缺少名称")); + } + std::string name = ctx->ID()->getText(); + if (table_.ContainsInCurrentScope(name)) { + throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + } + TypeDesc ty; + ty.base = base; + for (auto* dim : ctx->constExp()) { + ty.dims.push_back(EvalConstExp(dim)); + } + SymbolEntry entry; + entry.kind = SymbolKind::Var; + entry.var_decl = ctx; + entry.is_const = false; + entry.type = ty; + table_.Add(name, entry); + sema_.RegisterVarDecl(ctx, ty); + + if (auto* init = ctx->initVal()) { + init->accept(this); + } + } + + void RegisterConst(SysYParser::ConstDefContext* ctx, BaseTypeKind base) { + if (!ctx || !ctx->ID()) { + throw std::runtime_error(FormatError("sema", "常量声明缺少名称")); + } + std::string name = ctx->ID()->getText(); + if (table_.ContainsInCurrentScope(name)) { + throw std::runtime_error(FormatError("sema", "重复定义常量: " + name)); + } + TypeDesc ty; + ty.base = base; + ty.is_const = true; + for (auto* dim : ctx->constExp()) { + ty.dims.push_back(EvalConstExp(dim)); + } + SymbolEntry entry; + entry.kind = SymbolKind::Const; + entry.const_decl = ctx; + entry.is_const = true; + entry.type = ty; + if (ctx->constInitVal() && ty.dims.empty() && ty.base == BaseTypeKind::Int) { + if (auto* exp = ctx->constInitVal()->constExp()) { + entry.const_value = EvalConstExp(exp); + } + } + table_.Add(name, entry); + sema_.RegisterConstDecl(ctx, ty); + + 
if (auto* init = ctx->constInitVal()) { + init->accept(this); + } + } + + int EvalConstExp(SysYParser::ConstExpContext* ctx) { + ConstEvalVisitor visitor(table_); + return std::any_cast(ctx->accept(&visitor)); + } + + int EvalConstExp(SysYParser::ExpContext* ctx) { + if (!ctx || !ctx->addExp()) { + throw std::runtime_error(FormatError("sema", "非法常量表达式")); + } + ConstEvalVisitor visitor(table_); + return std::any_cast(ctx->addExp()->accept(&visitor)); + } + private: SymbolTable table_; SemanticContext sema_; + std::unordered_map func_table_; + const std::unordered_set builtin_funcs_ = { + "getint", "getch", "getarray", "putint", "putch", "putarray", + "getfloat", "getfarray", "putfloat", "putfarray"}; + BaseTypeKind current_ret_ = BaseTypeKind::Void; bool seen_return_ = false; - size_t current_item_index_ = 0; - size_t total_items_ = 0; + int loop_depth_ = 0; }; } // namespace @@ -197,4 +507,4 @@ SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { SemaVisitor visitor; comp_unit.accept(&visitor); return visitor.TakeSemanticContext(); -} +} \ No newline at end of file diff --git a/src/sem/SymbolTable.cpp b/src/sem/SymbolTable.cpp index ffeea89..3b37cd6 100644 --- a/src/sem/SymbolTable.cpp +++ b/src/sem/SymbolTable.cpp @@ -2,16 +2,34 @@ #include "sem/SymbolTable.h" -void SymbolTable::Add(const std::string& name, - SysYParser::VarDefContext* decl) { - table_[name] = decl; +void SymbolTable::EnterScope() { scopes_.emplace_back(); } + +void SymbolTable::ExitScope() { + if (!scopes_.empty()) { + scopes_.pop_back(); + } +} + +bool SymbolTable::ContainsInCurrentScope(const std::string& name) const { + if (scopes_.empty()) { + return false; + } + return scopes_.back().find(name) != scopes_.back().end(); } -bool SymbolTable::Contains(const std::string& name) const { - return table_.find(name) != table_.end(); +void SymbolTable::Add(const std::string& name, const SymbolEntry& entry) { + if (scopes_.empty()) { + EnterScope(); + } + scopes_.back()[name] = entry; } 
-SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { - auto it = table_.find(name); - return it == table_.end() ? nullptr : it->second; +const SymbolEntry* SymbolTable::Lookup(const std::string& name) const { + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + auto found = it->find(name); + if (found != it->end()) { + return &found->second; + } + } + return nullptr; } diff --git a/sylib/sylib.c b/sylib/sylib.c index 7f26d0b..d968ce0 100644 --- a/sylib/sylib.c +++ b/sylib/sylib.c @@ -1,4 +1,54 @@ -// SysY 运行库实现: -// - 按实验/评测规范提供 I/O 等函数实现 -// - 与编译器生成的目标代码链接,支撑运行时行为 +#include + +int getint() { + int v = 0; + if (scanf("%d", &v) != 1) return 0; + return v; +} + +int getch() { return getchar(); } + +int getarray(int a[]) { + int n = 0; + if (scanf("%d", &n) != 1) return 0; + for (int i = 0; i < n; ++i) { + scanf("%d", &a[i]); + } + return n; +} + +void putint(int x) { printf("%d", x); } + +void putch(int x) { putchar(x); } + +void putarray(int n, int a[]) { + printf("%d", n); + for (int i = 0; i < n; ++i) { + printf(" %d", a[i]); + } +} + +float getfloat() { + float v = 0.0f; + if (scanf("%f", &v) != 1) return 0.0f; + return v; +} + +int getfarray(float a[]) { + int n = 0; + if (scanf("%d", &n) != 1) return 0; + for (int i = 0; i < n; ++i) { + scanf("%f", &a[i]); + } + return n; +} + +void putfloat(float x) { printf("%a", x); } + +void putfarray(int n, float a[]) { + printf("%d:", n); + for (int i = 0; i < n; ++i) { + printf(" %a", a[i]); + } +}