diff --git a/.gitignore b/.gitignore index 1ee33a1..53075fa 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,9 @@ Thumbs.db # Project outputs # ========================= test/test_result/ + +# ========================= +# mxr +# ========================= +result.txt +build.sh \ No newline at end of file diff --git a/include/ir/IR.h b/include/ir/IR.h index b961192..ea62b05 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -41,10 +41,16 @@ namespace ir { class Type; +class ArrayType; +class FunctionType; class Value; class User; class ConstantValue; class ConstantInt; +class ConstantFloat; +class ConstantArray; +class ConstantZero; +class ConstantAggregateZero; class GlobalValue; class Instruction; class BasicBlock; @@ -77,39 +83,141 @@ class Use { }; // IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。 +// ir/IR.h - 修改 Context 类定义 + class Context { public: Context() = default; ~Context(); - // 去重创建 i32 常量。 + + // 去重创建 i32 常量 ConstantInt* GetConstInt(int v); + + // 去重创建浮点常量 + ConstantFloat* GetConstFloat(float v); + + // 创建数组常量(不去重,因为数组常量通常比较复杂且组合多样) + ConstantArray* GetConstArray(std::shared_ptr ty, + std::vector elements); + + // 获取零常量(按类型缓存) + ConstantZero* GetZeroConstant(std::shared_ptr ty); + + // 获取聚合类型的零常量 + ConstantAggregateZero* GetAggregateZero(std::shared_ptr ty); std::string NextTemp(); private: + // 数组常量缓存需要添加到类中 + struct ArrayKey { + std::shared_ptr type; + std::vector elements; + + bool operator==(const ArrayKey& other) const; + }; + + struct ArrayKeyHash { + size_t operator()(const ArrayKey& key) const; + }; + + std::unordered_map, ArrayKeyHash> array_cache_; + + // 其他现有成员... 
std::unordered_map> const_ints_; + std::unordered_map> const_floats_; + std::unordered_map> zero_constants_; + std::unordered_map> aggregate_zeros_; int temp_index_ = -1; }; class Type { public: - enum class Kind { Void, Int32, PtrInt32 }; - explicit Type(Kind k); + enum class Kind { Void, Int32, Float, PtrInt32, PtrFloat, Label, Array, Function, + Int1, PtrInt1}; + + virtual ~Type() = default; + // 使用静态共享对象获取类型。 // 同一类型可直接比较返回值是否相等,例如: // Type::GetInt32Type() == Type::GetInt32Type() static const std::shared_ptr& GetVoidType(); static const std::shared_ptr& GetInt32Type(); + static const std::shared_ptr& GetFloatType(); static const std::shared_ptr& GetPtrInt32Type(); - Kind GetKind() const; - bool IsVoid() const; - bool IsInt32() const; - bool IsPtrInt32() const; + static const std::shared_ptr& GetPtrFloatType(); + static const std::shared_ptr& GetLabelType(); + static std::shared_ptr GetArrayType(std::shared_ptr elem, std::vector dims); + static std::shared_ptr GetFunctionType(std::shared_ptr ret, std::vector> params); + static const std::shared_ptr& GetInt1Type(); + static const std::shared_ptr& GetPtrInt1Type(); + + + // 类型判断 + Kind GetKind() const { return kind_; } + bool IsVoid() const { return kind_ == Kind::Void; } + bool IsInt32() const { return kind_ == Kind::Int32; } + bool IsFloat() const { return kind_ == Kind::Float; } + bool IsPtrInt32() const { return kind_ == Kind::PtrInt32; } + bool IsPtrFloat() const { return kind_ == Kind::PtrFloat; } + bool IsLabel() const { return kind_ == Kind::Label; } + bool IsArray() const { return kind_ == Kind::Array; } + bool IsFunction() const { return kind_ == Kind::Function; } + bool IsInt1() const { return kind_ == Kind::Int1; } + bool IsPtrInt1() const { return kind_ == Kind::PtrInt1; } + + // 类型属性 + virtual size_t Size() const; // 字节大小 + virtual size_t Alignment() const; // 对齐要求 + virtual bool IsComplete() const; // 是否为完整类型(非 void,数组维度已知等) + +protected: + explicit Type(Kind k); // 构造函数 protected,只能由工厂和派生类调用 
private: Kind kind_; }; +// 数组类型 +class ArrayType : public Type { +public: + // 获取元素类型和维度 + const std::shared_ptr& GetElementType() const { return elem_type_; } + const std::vector& GetDimensions() const { return dims_; } + size_t GetElementCount() const; // 总元素个数 + + size_t Size() const override; + size_t Alignment() const override; + bool IsComplete() const override; + +protected: + ArrayType(std::shared_ptr elem, std::vector dims); + friend class Type; // 允许 Type::GetArrayType 构造 + +private: + std::shared_ptr elem_type_; + std::vector dims_; +}; + +// 函数类型 +class FunctionType : public Type { +public: + const std::shared_ptr& GetReturnType() const { return return_type_; } + const std::vector>& GetParamTypes() const { return param_types_; } + + size_t Size() const override; // 函数类型没有大小,通常返回 0 + size_t Alignment() const override; // 无对齐要求 + bool IsComplete() const override; // 函数类型视为完整 + +protected: + FunctionType(std::shared_ptr ret, std::vector> params); + friend class Type; + +private: + std::shared_ptr return_type_; + std::vector> param_types_; +}; + class Value { public: Value(std::shared_ptr ty, std::string name); @@ -151,8 +259,82 @@ class ConstantInt : public ConstantValue { int value_{}; }; -// 后续还需要扩展更多指令类型。 -enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; +// 在 ConstantInt 类之后添加以下类 + +// ConstantFloat - 浮点常量 +class ConstantFloat : public ConstantValue { + public: + ConstantFloat(std::shared_ptr ty, float v); + float GetValue() const { return value_; } + + private: + float value_{}; +}; + +// ConstantArray - 数组常量 +class ConstantArray : public ConstantValue { + public: + // 构造函数:接收数组类型和常量元素列表 + ConstantArray(std::shared_ptr ty, std::vector elements); + + // 获取元素数量 + size_t GetNumElements() const { return elements_.size(); } + + // 获取指定索引的元素 + ConstantValue* GetElement(size_t index) const { return elements_[index]; } + + // 获取所有元素 + const std::vector& GetElements() const { return elements_; } + + // 验证常量数组的类型是否正确 + bool IsValid() const; + + 
private: + std::vector elements_; +}; + +// ConstantZero - 零常量(用于零初始化) +class ConstantZero : public ConstantValue { + public: + explicit ConstantZero(std::shared_ptr ty); + + // 工厂方法:创建特定类型的零常量 + static std::unique_ptr GetZero(std::shared_ptr ty); +}; + +// ConstantAggregateZero - 聚合类型的零常量(数组、结构体等) +class ConstantAggregateZero : public ConstantValue { + public: + explicit ConstantAggregateZero(std::shared_ptr ty); + + // 获取聚合类型 + std::shared_ptr GetAggregateType() const { return GetType(); } + + // 工厂方法:创建聚合类型的零常量 + static std::unique_ptr GetZero(std::shared_ptr ty); +}; +//function 参数占位类,目前仅保存类型和名字,后续可扩展更多属性(例如是否为数组参数、数组维度等)。 +class Argument : public Value { + public: + Argument(std::shared_ptr ty, std::string name); +}; + +// 后续还需要扩展更多指令类型。add call instruction 只是最小占位,后续可以继续补 sub/mul/div/rem、br/condbr、phi、gep 等指令。 +enum class Opcode { + Add, Sub, Mul, + Alloca, Load, Store, Ret, Call, + Br, CondBr, Icmp, ZExt, Trunc, + Div, Mod, + And, Or, Not, + GEP, + FAdd, FSub, FMul, FDiv, + FCmp, + SIToFP, // 整数转浮点 + FPToSI, // 浮点转整数 + FPExt, // 浮点扩展 + FPTrunc, // 浮点截断 + }; +// ZExt 和 Trunc 是零扩展和截断指令,SysY 的 int (i32) vs LLVM IR 的比较结果 (i1)。 // User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 // 当前实现中只有 Instruction 继承自 User。 @@ -162,7 +344,16 @@ class User : public Value { size_t GetNumOperands() const; Value* GetOperand(size_t index) const; void SetOperand(size_t index, Value* value); - + // 添加模板方法,支持派生类自动转换 + template + void SetOperand(size_t index, T* value) { + SetOperand(index, static_cast(value)); + } + + template + void AddOperand(T* value) { + AddOperand(static_cast(value)); + } protected: // 统一的 operand 入口。 void AddOperand(Value* value); @@ -173,9 +364,53 @@ class User : public Value { // GlobalValue 是全局值/全局变量体系的空壳占位类。 // 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。 +// ir/IR.h - GlobalValue 类定义需要添加这些方法 + class GlobalValue : public User { - public: +private: + std::vector initializer_; + bool is_constant_ = false; + bool is_extern_ = false; + +public: GlobalValue(std::shared_ptr ty, 
std::string name); + + // 初始化器相关 + void SetInitializer(ConstantValue* init); + void SetInitializer(const std::vector& init); + const std::vector& GetInitializer() const { return initializer_; } + bool HasInitializer() const { return !initializer_.empty(); } + + // 常量属性 + void SetConstant(bool is_const) { is_constant_ = is_const; } + bool IsConstant() const { return is_constant_; } + + // 外部声明 + void SetExtern(bool is_extern) { is_extern_ = is_extern; } + bool IsExtern() const { return is_extern_; } + + // 类型判断 + bool IsArray() const { return GetType()->IsArray(); } + bool IsScalar() const { return GetType()->IsInt32() || GetType()->IsFloat(); } + + // 数组常量相关方法 + bool IsArrayConstant() const; + ConstantValue* GetArrayElement(size_t index) const; + size_t GetArraySize() const; + + // 获取数组大小(如果是数组类型) + int GetArraySizeInElements() const { + if (auto* array_ty = dynamic_cast(GetType().get())) { + return array_ty->GetElementCount(); + } + return 0; + } + +private: + // 辅助方法 + std::shared_ptr GetValueType() const; + bool CheckTypeCompatibility(std::shared_ptr value_type, + ConstantValue* init) const; }; class Instruction : public User { @@ -223,6 +458,284 @@ class StoreInst : public Instruction { Value* GetPtr() const; }; +// 在 IR.h 中修改 BranchInst 类定义 + +class BranchInst : public Instruction { + public: + // 无条件跳转构造函数 + BranchInst(std::shared_ptr void_ty, BasicBlock* target) + : Instruction(Opcode::Br, void_ty, ""), + is_conditional_(false), + cond_(nullptr), + target_(target), + true_target_(nullptr), + false_target_(nullptr) {} + + // 条件跳转构造函数 + BranchInst(std::shared_ptr void_ty, Value* cond, + BasicBlock* true_target, BasicBlock* false_target) + : Instruction(Opcode::CondBr, void_ty, ""), + is_conditional_(true), + cond_(cond), + target_(nullptr), + true_target_(true_target), + false_target_(false_target) { + // 添加操作数以便维护 def-use 关系 + AddOperand(cond); + // 注意:BasicBlock 也是 Value,也需要添加到操作数中 + // 但 BasicBlock 继承自 Value,所以可以添加 + AddOperand(true_target); + 
AddOperand(false_target); + } + + // 判断是否为条件跳转 + bool IsConditional() const { return is_conditional_; } + + // 获取无条件跳转的目标(仅适用于无条件跳转) + BasicBlock* GetTarget() const { + if (is_conditional_) { + throw std::runtime_error("GetTarget called on conditional branch"); + } + return target_; + } + + // 获取条件值(仅适用于条件跳转) + Value* GetCondition() const { + if (!is_conditional_) { + throw std::runtime_error("GetCondition called on unconditional branch"); + } + return cond_; + } + + // 获取真分支目标(仅适用于条件跳转) + BasicBlock* GetTrueTarget() const { + if (!is_conditional_) { + throw std::runtime_error("GetTrueTarget called on unconditional branch"); + } + return true_target_; + } + + // 获取假分支目标(仅适用于条件跳转) + BasicBlock* GetFalseTarget() const { + if (!is_conditional_) { + throw std::runtime_error("GetFalseTarget called on unconditional branch"); + } + return false_target_; + } + + // 设置无条件跳转目标 + void SetTarget(BasicBlock* target) { + if (is_conditional_) { + throw std::runtime_error("SetTarget called on conditional branch"); + } + target_ = target; + } + + // 设置条件跳转的分支目标 + void SetTrueTarget(BasicBlock* target) { + if (!is_conditional_) { + throw std::runtime_error("SetTrueTarget called on unconditional branch"); + } + true_target_ = target; + // 更新操作数 + SetOperand(1, target); + } + + void SetFalseTarget(BasicBlock* target) { + if (!is_conditional_) { + throw std::runtime_error("SetFalseTarget called on unconditional branch"); + } + false_target_ = target; + // 更新操作数 + SetOperand(2, target); + } + + void SetCondition(Value* cond) { + if (!is_conditional_) { + throw std::runtime_error("SetCondition called on unconditional branch"); + } + cond_ = cond; + // 更新操作数 + SetOperand(0, cond); + } + + private: + bool is_conditional_; + Value* cond_; // 条件值(条件跳转使用) + BasicBlock* target_; // 无条件跳转目标 + BasicBlock* true_target_; // 真分支目标(条件跳转使用) + BasicBlock* false_target_; // 假分支目标(条件跳转使用) +}; + +// 创建整数比较指令 + class IcmpInst : public Instruction { + public: + enum class Predicate { + EQ, // equal + NE, 
// not equal + LT, // less than + LE, // less than or equal + GT, // greater than + GE // greater than or equal + }; + + IcmpInst(Predicate pred, Value* lhs, Value* rhs, std::shared_ptr i1_ty, std::string name) + : Instruction(Opcode::Icmp, i1_ty, name), pred_(pred) { + AddOperand(lhs); + AddOperand(rhs); + } + + Predicate GetPredicate() const { return pred_; } + Value* GetLhs() const { return GetOperand(0); } + Value* GetRhs() const { return GetOperand(1); } + + private: + Predicate pred_; + }; + +class FcmpInst : public Instruction { + public: + enum class Predicate { + FALSE, // Always false + OEQ, // Ordered and equal + OGT, // Ordered and greater than + OGE, // Ordered and greater than or equal + OLT, // Ordered and less than + OLE, // Ordered and less than or equal + ONE, // Ordered and not equal + ORD, // Ordered (no nans) + UNO, // Unordered (isnan(x) || isnan(y)) + UEQ, // Unordered or equal + UGT, // Unordered or greater than + UGE, // Unordered or greater than or equal + ULT, // Unordered or less than + ULE, // Unordered or less than or equal + UNE, // Unordered or not equal + TRUE // Always true + }; + + FcmpInst(Predicate pred, Value* lhs, Value* rhs, + std::shared_ptr i1_ty, std::string name) + : Instruction(Opcode::FCmp, i1_ty, name), pred_(pred) { + AddOperand(lhs); + AddOperand(rhs); + } + + Predicate GetPredicate() const { return pred_; } + Value* GetLhs() const { return GetOperand(0); } + Value* GetRhs() const { return GetOperand(1); } + + private: + Predicate pred_; +}; + + // ZExtInst - 零扩展指令 +class ZExtInst : public Instruction { + public: + ZExtInst(Value* value, std::shared_ptr target_ty, std::string name = "") + : Instruction(Opcode::ZExt, target_ty, name) { + AddOperand(value); + } + + // 获取被扩展的值 + Value* GetValue() const { + return GetOperand(0); + } + + // 获取源类型 + std::shared_ptr GetSourceType() const { + return GetValue()->GetType(); + } + + // 获取目标类型 + std::shared_ptr GetTargetType() const { + return GetType(); + } + + // 设置被扩展的值 + 
void SetValue(Value* value) { + SetOperand(0, value); + } +}; + +// TruncInst - 截断指令 +class TruncInst : public Instruction { + public: + TruncInst(Value* value, std::shared_ptr target_ty, std::string name = "") + : Instruction(Opcode::Trunc, target_ty, name) { + AddOperand(value); + } + + // 获取被截断的值 + Value* GetValue() const { + return GetOperand(0); + } + + // 获取源类型 + std::shared_ptr GetSourceType() const { + return GetValue()->GetType(); + } + + // 获取目标类型 + std::shared_ptr GetTargetType() const { + return GetType(); + } + + // 设置被截断的值 + void SetValue(Value* value) { + SetOperand(0, value); + } +}; + +class GEPInst : public Instruction { + public: + GEPInst(std::shared_ptr ptr_ty, + Value* base, + const std::vector& indices, + const std::string& name); + Value* GetBase() const; + const std::vector& GetIndices() const; +}; + +// Function 当前也采用了最小实现。 +// 需要特别注意:由于项目里还没有单独的 FunctionType, +// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”, +// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。 +// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、 +// 形参和调用,通常需要引入专门的函数类型表示。 +class Function : public Value { + public: + // 当前构造函数接收完整的 FunctionType。 + Function(std::string name, std::shared_ptr func_type); + BasicBlock* CreateBlock(const std::string& name); + BasicBlock* GetEntry(); + const BasicBlock* GetEntry() const; + const std::vector>& GetBlocks() const; + // 函数增加参数的接口,目前仅保存参数类型和名字,后续可扩展更多属性(例如是否为数组参数、数组维度等)。注意这里是直接在 Function 上管理参数列表,而不是通过一个单独的 FunctionType 来表达完整函数签名,这也是当前最小实现的一个简化点。 + Argument* AddArgument(std::unique_ptr arg); + const std::vector>& GetArguments() const; + + private: + BasicBlock* entry_ = nullptr; + std::vector> blocks_; + std::vector> arguments_; +}; + + +class CallInst : public Instruction { + public: + CallInst(std::shared_ptr ret_ty, + Function* callee, + const std::vector& args, + const std::string& name); + Function* GetCallee() const; + const std::vector& GetArgs() const; + + private: + Function* callee_; + std::vector args_; +}; + // BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。 
// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。 class BasicBlock : public Value { @@ -247,6 +760,20 @@ class BasicBlock : public Value { return ptr; } + template + T* InsertBeforeTerminator(Args&&... args) { + auto inst = std::make_unique(std::forward(args)...); + auto* ptr = inst.get(); + ptr->SetParent(this); + + auto pos = instructions_.end(); + if (HasTerminator()) { + pos = instructions_.end() - 1; + } + instructions_.insert(pos, std::move(inst)); + return ptr; + } + private: Function* parent_ = nullptr; std::vector> instructions_; @@ -254,39 +781,51 @@ class BasicBlock : public Value { std::vector successors_; }; -// Function 当前也采用了最小实现。 -// 需要特别注意:由于项目里还没有单独的 FunctionType, -// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”, -// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。 -// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、 -// 形参和调用,通常需要引入专门的函数类型表示。 -class Function : public Value { - public: - // 当前构造函数接收的也是返回类型,而不是完整函数类型。 - Function(std::string name, std::shared_ptr ret_type); - BasicBlock* CreateBlock(const std::string& name); - BasicBlock* GetEntry(); - const BasicBlock* GetEntry() const; - const std::vector>& GetBlocks() const; - - private: - BasicBlock* entry_ = nullptr; - std::vector> blocks_; -}; class Module { public: Module() = default; Context& GetContext(); const Context& GetContext() const; - // 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。 + // 创建函数时传入完整的 FunctionType。 Function* CreateFunction(const std::string& name, - std::shared_ptr ret_type); + std::shared_ptr func_type); + GlobalValue* CreateGlobal(const std::string& name, + std::shared_ptr ty); + Function* FindFunction(const std::string& name) const; const std::vector>& GetFunctions() const; + const std::vector>& GetGlobals() const; private: Context context_; std::vector> functions_; + std::vector> globals_; +}; + +// SIToFP - 整数转浮点 +class SIToFPInst : public Instruction { + public: + SIToFPInst(Value* value, std::shared_ptr target_ty, std::string name = "") + : Instruction(Opcode::SIToFP, target_ty, name) { + AddOperand(value); 
+ } + + Value* GetValue() const { + return GetOperand(0); + } +}; + +// FPToSI - 浮点转整数 +class FPToSIInst : public Instruction { + public: + FPToSIInst(Value* value, std::shared_ptr target_ty, std::string name = "") + : Instruction(Opcode::FPToSI, target_ty, name) { + AddOperand(value); + } + + Value* GetValue() const { + return GetOperand(0); + } }; class IRBuilder { @@ -297,14 +836,81 @@ class IRBuilder { // 构造常量、二元运算、返回指令的最小集合。 ConstantInt* CreateConstInt(int v); + ConstantFloat* CreateConstFloat(float v); // 新增 + ConstantArray* CreateConstArray(std::shared_ptr ty, + std::vector elements); // 新增 + ConstantZero* CreateZeroConstant(std::shared_ptr ty); // 新增 + BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); + AllocaInst* CreateAlloca(std::shared_ptr ty, const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name); + AllocaInst* CreateAllocaFloat(const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name); StoreInst* CreateStore(Value* val, Value* ptr); ReturnInst* CreateRet(Value* v); - + CallInst* CreateCall(Function* callee, const std::vector& args, + const std::string& name); + + // 创建无条件跳转 + BranchInst* CreateBr(BasicBlock* target); + + // 创建条件跳转 + BranchInst* CreateCondBr(Value* cond, BasicBlock* true_target, + BasicBlock* false_target); + + // 创建整数比较指令 + IcmpInst* CreateICmpEQ(Value* lhs, Value* rhs, const std::string& name); + IcmpInst* CreateICmpNE(Value* lhs, Value* rhs, const std::string& name); + IcmpInst* CreateICmpLT(Value* lhs, Value* rhs, const std::string& name); + IcmpInst* CreateICmpLE(Value* lhs, Value* rhs, const std::string& name); + IcmpInst* CreateICmpGT(Value* lhs, Value* rhs, const std::string& name); + IcmpInst* CreateICmpGE(Value* lhs, Value* rhs, const std::string& name); + + // 创建类型转换指令 + ZExtInst* CreateZExt(Value* value, std::shared_ptr target_ty, + const std::string& name = ""); + 
TruncInst* CreateTrunc(Value* value, std::shared_ptr target_ty, + const std::string& name = ""); + + // 便捷方法 + ZExtInst* CreateZExtI1ToI32(Value* value, const std::string& name = "zext"); + TruncInst* CreateTruncI32ToI1(Value* value, const std::string& name = "trunc"); + + + + BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateMod(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name); + + // 比较运算接口 + BinaryInst* CreateAnd(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateOr(Value* lhs, Value* rhs, const std::string& name); + IcmpInst* CreateNot(Value* val, const std::string& name); + + GEPInst* CreateGEP(Value* base, const std::vector& indices, const std::string& name); + + // 浮点运算 + BinaryInst* CreateFAdd(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFSub(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFMul(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFDiv(Value* lhs, Value* rhs, const std::string& name); + + // 浮点比较 + FcmpInst* CreateFCmpOEQ(Value* lhs, Value* rhs, const std::string& name); + FcmpInst* CreateFCmpONE(Value* lhs, Value* rhs, const std::string& name); + FcmpInst* CreateFCmpOLT(Value* lhs, Value* rhs, const std::string& name); + FcmpInst* CreateFCmpOLE(Value* lhs, Value* rhs, const std::string& name); + FcmpInst* CreateFCmpOGT(Value* lhs, Value* rhs, const std::string& name); + FcmpInst* CreateFCmpOGE(Value* lhs, Value* rhs, const std::string& name); + + // 类型转换 + SIToFPInst* CreateSIToFP(Value* value, std::shared_ptr target_ty, + const std::string& name = ""); + FPToSIInst* CreateFPToSI(Value* value, std::shared_ptr target_ty, + const std::string& name = ""); private: Context& ctx_; BasicBlock* insert_block_; @@ -313,6 +919,9 @@ class IRBuilder { class IRPrinter { public: void 
Print(const Module& module, std::ostream& os); + + private: + void PrintConstant(const ConstantValue* constant, std::ostream& os); }; } // namespace ir diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 231ba90..89b4e56 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -7,12 +7,23 @@ #include #include #include +#include +#include #include "SysYBaseVisitor.h" #include "SysYParser.h" #include "ir/IR.h" #include "sem/Sema.h" +//#define DEBUG_IRGen + +#ifdef DEBUG_IRGen +#include +#define DEBUG_MSG(msg) std::cerr << "[IRGen Debug] " << msg << std::endl +#else +#define DEBUG_MSG(msg) +#endif + namespace ir { class Module; class Function; @@ -21,38 +32,151 @@ class Value; } class IRGenImpl final : public SysYBaseVisitor { - public: - IRGenImpl(ir::Module& module, const SemanticContext& sema); +public: + // 修改构造函数,添加 SymbolTable 参数 + IRGenImpl(ir::Module& module, + const SemanticContext& sema, + const SymbolTable& sym_table); // 新增 + // 顶层 std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + + // 块 + std::any visitBlock(SysYParser::BlockContext* ctx) override; // 注意:规则名为 Block std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; + + // 声明 std::any visitDecl(SysYParser::DeclContext* ctx) override; - std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override; - std::any visitVarExp(SysYParser::VarExpContext* ctx) override; - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override; - - private: - enum class BlockFlow { + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; + 
std::any visitConstDef(SysYParser::ConstDefContext* ctx) override; + std::any visitInitVal(SysYParser::InitValContext* ctx) override; + std::any visitConstInitVal(SysYParser::ConstInitValContext* ctx) override; + + // 语句 + std::any visitStmt(SysYParser::StmtContext* ctx) override; + + // 表达式 + // 基本表达式(变量、常量、括号表达式)直接翻译为 IR 中的值;函数调用和一元运算需要特殊处理。 + std::any visitExp(SysYParser::ExpContext* ctx) override; + std::any visitCond(SysYParser::CondContext* ctx) override; + std::any visitLVal(SysYParser::LValContext* ctx) override; + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; + + + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; + std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + // 加法表达式、乘法表达式、关系表达式、相等表达式、条件表达式分别对应不同的访问函数,按照优先级分层访问,最终调用 visitLVal 来处理变量访问 + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; + std::any visitConstExp(SysYParser::ConstExpContext* ctx) override; + + // 辅助函数 + ir::Value* EvalExpr(SysYParser::ExpContext& expr); // 只保留一处 + ir::Value* EvalCond(SysYParser::CondContext& cond); + std::any visitCallExp(SysYParser::UnaryExpContext* ctx); + std::vector ProcessNestedInitVals(SysYParser::InitValContext* ctx); + // 带维度感知的展平:按 C 语言花括号对齐规则填充 total_size 个槽位 + // dims[0] 是最外层维度,dims.back() 是最内层维度(元素层) + // 返回已展平并补零的 total_size 大小的向量 + std::vector FlattenInitVal(SysYParser::InitValContext* ctx, + const std::vector& dims, + bool is_float); + int TryEvaluateConstInt(SysYParser::ConstExpContext* ctx); + void AddRuntimeFunctions(); + ir::Function* CreateRuntimeFunctionDecl(const std::string& funcName); + ir::BasicBlock* EnsureCleanupBlock(); + void 
RegisterCleanup(ir::Function* free_func, ir::Value* ptr); + ir::AllocaInst* CreateEntryAlloca(std::shared_ptr ty, + const std::string& name); + ir::AllocaInst* CreateEntryAllocaI32(const std::string& name); + ir::AllocaInst* CreateEntryAllocaFloat(const std::string& name); +private: + // 辅助函数声明 + enum class BlockFlow{ Continue, Terminated, }; BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); - ir::Value* EvalExpr(SysYParser::ExpContext& expr); + + BlockFlow HandleReturnStmt(SysYParser::StmtContext* ctx); + BlockFlow HandleIfStmt(SysYParser::StmtContext* ctx); + BlockFlow HandleWhileStmt(SysYParser::StmtContext* ctx); + BlockFlow HandleBreakStmt(SysYParser::StmtContext* ctx); + BlockFlow HandleContinueStmt(SysYParser::StmtContext* ctx); + BlockFlow HandleAssignStmt(SysYParser::StmtContext* ctx); + + // 完全正确的左值判断函数 + bool IsLValueExpression(SysYParser::ExpContext* ctx); + bool IsLValueInAddExp(SysYParser::AddExpContext* ctx); + bool IsLValueInMulExp(SysYParser::MulExpContext* ctx); + bool IsLValueInUnaryExp(SysYParser::UnaryExpContext* ctx); + bool IsLValueInPrimaryExp(SysYParser::PrimaryExpContext* ctx); + // 循环上下文结构 + struct LoopContext { + ir::BasicBlock* condBlock; + ir::BasicBlock* bodyBlock; + ir::BasicBlock* exitBlock; + }; + + struct ArrayInfo { + std::vector elements; + std::vector dimensions; + }; + + std::vector loopStack_; ir::Module& module_; const SemanticContext& sema_; + const SymbolTable& symbol_table_; // 新增成员 ir::Function* func_; ir::IRBuilder builder_; - // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 + ir::Value* EvalAssign(SysYParser::StmtContext* ctx); + + // 按 VarDefContext 查找存储位置(用于数组访问等场景) std::unordered_map storage_map_; + + // 按变量名快速查找(用于 LVal 等场景) + std::unordered_map local_var_map_; // 局部变量 + std::unordered_map global_map_; // 全局变量 + std::unordered_map param_map_; // 函数参数 + std::unordered_set pointer_param_names_; // 指针/数组形参名 + std::unordered_set heap_local_array_names_; // 堆分配的局部数组名 + + // 常量映射:常量名 -> 常量值(标量常量) + 
std::unordered_map const_value_map_; + + // 全局常量映射:常量名 -> 全局变量(数组常量) + std::unordered_map const_global_map_; + + // 原有的常量存储映射(用于兼容) + std::unordered_map const_storage_map_; + + std::unordered_map array_info_map_; + + std::string current_function_name_; + bool current_function_is_recursive_ = false; + ir::AllocaInst* function_return_slot_ = nullptr; + ir::BasicBlock* function_cleanup_block_ = nullptr; + std::vector> function_cleanup_actions_; + + // 新增:处理全局和局部变量的辅助函数 + // 修改处理函数的签名,使用 Symbol* 参数 + std::any HandleGlobalVariable(SysYParser::VarDefContext* ctx, + const std::string& varName, + const Symbol* sym); + + std::any HandleLocalVariable(SysYParser::VarDefContext* ctx, + const std::string& varName, + const Symbol* sym); }; +// 修改 GenerateIR 函数签名 std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, - const SemanticContext& sema); + const SemaResult& sema_result); \ No newline at end of file diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 9ac057b..c053428 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -2,29 +2,101 @@ #pragma once #include +#include +#include #include "SysYParser.h" +#include "ir/IR.h" +#include "sem/SymbolTable.h" +// 表达式信息结构 +struct ExprInfo { + std::shared_ptr type = nullptr; + bool is_lvalue = false; + bool is_const = false; + bool is_const_int = false; // 是否是整型常量 + int const_int_value = 0; + float const_float_value = 0.0f; + antlr4::ParserRuleContext* node = nullptr; // 对应的语法树节点 +}; +// 语义分析上下文:存储分析过程中产生的信息 class SemanticContext { - public: - void BindVarUse(SysYParser::VarContext* use, - SysYParser::VarDefContext* decl) { - var_uses_[use] = decl; - } - - SysYParser::VarDefContext* ResolveVarUse( - const SysYParser::VarContext* use) const { - auto it = var_uses_.find(use); - return it == var_uses_.end() ? 
nullptr : it->second; - } - - private: - std::unordered_map - var_uses_; +public: + // ----- 变量使用绑定(使用 LValContext 而不是 VarContext)----- + void BindVarUse(SysYParser::LValContext* use, + SysYParser::VarDefContext* decl) { + var_uses_[use] = decl; + } + + SysYParser::VarDefContext* ResolveVarUse( + const SysYParser::LValContext* use) const { + auto it = var_uses_.find(use); + return it == var_uses_.end() ? nullptr : it->second; + } + + void BindConstUse(SysYParser::LValContext* use, SysYParser::ConstDefContext* decl) { + const_uses_[use] = decl; + } + SysYParser::ConstDefContext* ResolveConstUse(const SysYParser::LValContext* use) const { + auto it = const_uses_.find(use); + return it == const_uses_.end() ? nullptr : it->second; + } + + // ----- 表达式类型信息存储 ----- + void SetExprType(antlr4::ParserRuleContext* node, const ExprInfo& info) { + ExprInfo copy = info; + copy.node = node; + expr_types_[node] = copy; + } + + ExprInfo* GetExprType(antlr4::ParserRuleContext* node) { + auto it = expr_types_.find(node); + return it == expr_types_.end() ? nullptr : &it->second; + } + + const ExprInfo* GetExprType(antlr4::ParserRuleContext* node) const { + auto it = expr_types_.find(node); + return it == expr_types_.end() ? 
nullptr : &it->second; + } + + // ----- 隐式转换标记(供 IR 生成使用)----- + struct ConversionInfo { + antlr4::ParserRuleContext* node; + std::shared_ptr from_type; + std::shared_ptr to_type; + }; + + void AddConversion(antlr4::ParserRuleContext* node, + std::shared_ptr from, + std::shared_ptr to) { + conversions_.push_back({node, from, to}); + } + + const std::vector& GetConversions() const { return conversions_; } + +private: + // 变量使用映射 - 使用 LValContext 作为键 + std::unordered_map var_uses_; + + // 表达式类型映射 + std::unordered_map expr_types_; + + // 隐式转换列表 + std::vector conversions_; + + std::unordered_map const_uses_; }; // 目前仅检查: // - 变量先声明后使用 // - 局部变量不允许重复定义 -SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); +// SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); +// 新增:语义分析结果结构体 +struct SemaResult { + SemanticContext context; + SymbolTable symbol_table; +}; + +// 修改 RunSema 的返回类型 +SemaResult RunSema(SysYParser::CompUnitContext& comp_unit); \ No newline at end of file diff --git a/include/sem/SymbolTable.h b/include/sem/SymbolTable.h index c9396dd..ed986f1 100644 --- a/include/sem/SymbolTable.h +++ b/include/sem/SymbolTable.h @@ -1,17 +1,187 @@ // 极简符号表:记录局部变量定义点。 #pragma once +#include #include #include +#include +#include #include "SysYParser.h" +#include "ir/IR.h" +// 符号种类 +enum class SymbolKind { + Variable, + Function, + Parameter, + Constant +}; + +// 符号条目 +// 符号条目 +struct Symbol { + // 基本信息 + std::string name; + SymbolKind kind; + std::shared_ptr type; + int scope_level = 0; + int stack_offset = -1; + bool is_initialized = false; + bool is_builtin = false; + + // 数组参数相关 + std::vector array_dims; + bool is_array_param = false; + + // 函数相关 + std::vector> param_types; + + // 常量值存储 + union ConstantValue { + int i32; + float f32; + }; + + // 标量常量 + bool is_int_const = true; + ConstantValue const_value; + + // 数组常量(扁平化存储) + bool is_array_const = false; + std::vector array_const_values; + + // 语法树节点 + SysYParser::VarDefContext* var_def_ctx = 
nullptr; + SysYParser::ConstDefContext* const_def_ctx = nullptr; + SysYParser::FuncFParamContext* param_def_ctx = nullptr; + SysYParser::FuncDefContext* func_def_ctx = nullptr; + + // 辅助方法 + bool IsScalarConstant() const { + return kind == SymbolKind::Constant && !type->IsArray(); + } + + bool IsArrayConstant() const { + return kind == SymbolKind::Constant && type->IsArray(); + } + + int GetIntConstant() const { + if (!IsScalarConstant()) { + throw std::runtime_error("不是标量常量"); + } + if (!is_int_const) { + throw std::runtime_error("不是整型常量"); + } + return const_value.i32; + } + + float GetFloatConstant() const { + if (!IsScalarConstant()) { + throw std::runtime_error("不是标量常量"); + } + if (is_int_const) { + return static_cast(const_value.i32); + } + return const_value.f32; + } + + ConstantValue GetArrayElement(size_t index) const { + if (!IsArrayConstant()) { + throw std::runtime_error("不是数组常量"); + } + if (index >= array_const_values.size()) { + throw std::runtime_error("数组下标越界"); + } + return array_const_values[index]; + } + + size_t GetArraySize() const { + if (!IsArrayConstant()) return 0; + return array_const_values.size(); + } +}; class SymbolTable { - public: - void Add(const std::string& name, SysYParser::VarDefContext* decl); - bool Contains(const std::string& name) const; - SysYParser::VarDefContext* Lookup(const std::string& name) const; + public: + SymbolTable(); + ~SymbolTable() = default; + // 添加调试方法 + size_t getScopeCount() const { return active_scope_stack_.size(); } + + void dump() const { + std::cerr << "=== SymbolTable Dump ===" << std::endl; + for (size_t i = 0; i < scopes_.size(); ++i) { + std::cerr << "Scope " << i << " (depth=" << i << ")"; + bool active = std::find(active_scope_stack_.begin(), active_scope_stack_.end(), i) != active_scope_stack_.end(); + std::cerr << (active ? 
" [active]" : " [inactive]") << std::endl; + for (const auto& [name, sym] : scopes_[i]) { + std::cerr << " " << name + << " (kind=" << (int)sym.kind + << ", level=" << sym.scope_level << ")" << std::endl; + } + } + } + // ----- 作用域管理 ----- + void enterScope(); // 进入新作用域 + void exitScope(); // 退出当前作用域 + int currentScopeLevel() const { return static_cast(active_scope_stack_.size()) - 1; } + + // ----- 符号操作(推荐使用)----- + bool addSymbol(const Symbol& sym); // 添加符号到当前作用域 + Symbol* lookup(const std::string& name); // 从当前作用域向外查找 + Symbol* lookupCurrent(const std::string& name); // 仅在当前作用域查找 + const Symbol* lookup(const std::string& name) const; + const Symbol* lookupCurrent(const std::string& name) const; + const Symbol* lookupAll(const std::string& name) const; // 所有作用域查找,包括已结束的作用域 + const Symbol* lookupByVarDef(const SysYParser::VarDefContext* decl) const; // 通过定义节点查找符号 + const Symbol* lookupByConstDef(const SysYParser::ConstDefContext* decl) const; // 通过常量定义节点查找符号 + + // ----- 与原接口兼容(保留原有功能)----- + void Add(const std::string& name, SysYParser::VarDefContext* decl); + bool Contains(const std::string& name) const; + SysYParser::VarDefContext* Lookup(const std::string& name) const; + + // ----- 辅助函数:从语法树节点构造 Type ----- + static std::shared_ptr getTypeFromFuncDef(SysYParser::FuncDefContext* ctx); + + void registerBuiltinFunctions(); + + // 对常量表达式求值(返回整数值,用于数组维度等) + int EvaluateConstExp(SysYParser::ConstExpContext* ctx) const; + + // 对常量表达式求值(返回浮点值,用于全局初始化) + float EvaluateConstExpFloat(SysYParser::ConstExpContext* ctx) const; + + // 对常量初始化列表求值,返回一系列常量值(扁平化) + struct ConstValue { + enum Kind { INT, FLOAT }; + Kind kind; + union { + int int_val; + float float_val; + }; + }; + void flattenInit(SysYParser::ConstInitValContext* ctx, + std::vector& out, + std::shared_ptr base_type) const; + std::vector EvaluateConstInitVal( + SysYParser::ConstInitValContext* ctx, + const std::vector& dims, + std::shared_ptr base_type) const; + + int EvaluateConstExpression(SysYParser::ExpContext* 
ctx) const; + + private: + // 作用域栈:每个元素是一个从名字到符号的映射 + std::vector> scopes_; + std::vector active_scope_stack_; + + static constexpr int GLOBAL_SCOPE = 0; // 全局作用域索引 + + ConstValue EvaluateAddExp(SysYParser::AddExpContext* ctx) const; + ConstValue EvaluateMulExp(SysYParser::MulExpContext* ctx) const; + ConstValue EvaluateUnaryExp(SysYParser::UnaryExpContext* ctx) const; + ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext* ctx) const; - private: - std::unordered_map table_; + std::shared_ptr getTypeFromVarDef(SysYParser::VarDefContext* ctx) const; }; diff --git a/optimized.bc b/optimized.bc new file mode 100644 index 0000000..479f919 Binary files /dev/null and b/optimized.bc differ diff --git a/optimized.ll b/optimized.ll new file mode 100644 index 0000000..7f41f77 --- /dev/null +++ b/optimized.ll @@ -0,0 +1,492 @@ +; ModuleID = 'optimized.bc' +source_filename = "./build/test_compiler/performance/03_sort1.ll" + +@a = global [30000010 x i32] zeroinitializer +@ans = local_unnamed_addr global i32 0 + +declare i32 @getarray(ptr) local_unnamed_addr + +declare void @putint(i32) local_unnamed_addr + +declare void @putch(i32) local_unnamed_addr + +declare void @starttime() local_unnamed_addr + +declare void @stoptime() local_unnamed_addr + +declare ptr @sysy_alloc_i32(i32) local_unnamed_addr + +declare void @sysy_free_i32(ptr) local_unnamed_addr + +declare void @sysy_zero_i32(ptr, i32) local_unnamed_addr + +; Function Attrs: nofree norecurse nosync nounwind memory(argmem: read) +define i32 @getMaxNum(i32 %n, ptr nocapture readonly %arr) local_unnamed_addr #0 { +entry: + %t95 = icmp sgt i32 %n, 0 + br i1 %t95, label %while.body.t5, label %while.exit.t6 + +while.body.t5: ; preds = %entry, %while.body.t5 + %t3_i.07 = phi i32 [ %t21, %while.body.t5 ], [ 0, %entry ] + %t2_ret.06 = phi i32 [ %spec.select, %while.body.t5 ], [ 0, %entry ] + %0 = zext nneg i32 %t3_i.07 to i64 + %t13 = getelementptr i32, ptr %arr, i64 %0 + %t14 = load i32, ptr %t13, align 4 + %spec.select = 
tail call i32 @llvm.smax.i32(i32 %t14, i32 %t2_ret.06) + %t21 = add nuw nsw i32 %t3_i.07, 1 + %t9 = icmp slt i32 %t21, %n + br i1 %t9, label %while.body.t5, label %while.exit.t6 + +while.exit.t6: ; preds = %while.body.t5, %entry + %t2_ret.0.lcssa = phi i32 [ 0, %entry ], [ %spec.select, %while.body.t5 ] + ret i32 %t2_ret.0.lcssa +} + +; Function Attrs: nofree norecurse nosync nounwind memory(none) +define i32 @getNumPos(i32 %num, i32 %pos) local_unnamed_addr #1 { +entry: + %t333 = icmp sgt i32 %pos, 0 + br i1 %t333, label %while.body.t29, label %while.exit.t30 + +while.body.t29: ; preds = %entry, %while.body.t29 + %t27_i.05 = phi i32 [ %t37, %while.body.t29 ], [ 0, %entry ] + %t24.04 = phi i32 [ %t35, %while.body.t29 ], [ %num, %entry ] + %t35 = sdiv i32 %t24.04, 16 + %t37 = add nuw nsw i32 %t27_i.05, 1 + %t33 = icmp slt i32 %t37, %pos + br i1 %t33, label %while.body.t29, label %while.exit.t30 + +while.exit.t30: ; preds = %while.body.t29, %entry + %t24.0.lcssa = phi i32 [ %num, %entry ], [ %t35, %while.body.t29 ] + %t39 = srem i32 %t24.0.lcssa, 16 + ret i32 %t39 +} + +define void @radixSort(i32 %bitround, ptr nocapture %a, i32 %l, i32 %r) local_unnamed_addr { +entry: + %t43 = tail call ptr @sysy_alloc_i32(i32 16) + tail call void @sysy_zero_i32(ptr %t43, i32 16) + %t46 = tail call ptr @sysy_alloc_i32(i32 16) + tail call void @sysy_zero_i32(ptr %t46, i32 16) + %t48 = tail call ptr @sysy_alloc_i32(i32 16) + tail call void @sysy_zero_i32(ptr %t48, i32 16) + %t54 = icmp eq i32 %bitround, -1 + %t56 = add i32 %l, 1 + %t58 = icmp sge i32 %t56, %r + %t59 = or i1 %t54, %t58 + br i1 %t59, label %cleanup.t44, label %while.cond.t62.preheader + +while.cond.t62.preheader: ; preds = %entry + %t6796 = icmp slt i32 %l, %r + br i1 %t6796, label %while.body.t63.lr.ph, label %while.exit.t64 + +while.body.t63.lr.ph: ; preds = %while.cond.t62.preheader + %t333.i = icmp sgt i32 %bitround, 0 + br label %while.body.t63 + +cleanup.t44: ; preds = %merge.t196, %entry + tail call void 
@sysy_free_i32(ptr %t48) + tail call void @sysy_free_i32(ptr %t46) + tail call void @sysy_free_i32(ptr %t43) + ret void + +while.body.t63: ; preds = %while.body.t63.lr.ph, %getNumPos.exit28 + %storemerge97 = phi i32 [ %l, %while.body.t63.lr.ph ], [ %t83, %getNumPos.exit28 ] + %0 = sext i32 %storemerge97 to i64 + %t69 = getelementptr i32, ptr %a, i64 %0 + %t70 = load i32, ptr %t69, align 4 + br i1 %t333.i, label %while.body.t29.i, label %getNumPos.exit.thread + +getNumPos.exit.thread: ; preds = %while.body.t63 + %t39.i80 = srem i32 %t70, 16 + %1 = sext i32 %t39.i80 to i64 + %t7381 = getelementptr i32, ptr %t48, i64 %1 + %t7482 = load i32, ptr %t7381, align 4 + br label %getNumPos.exit28 + +while.body.t29.i: ; preds = %while.body.t63, %while.body.t29.i + %t27_i.05.i = phi i32 [ %t37.i, %while.body.t29.i ], [ 0, %while.body.t63 ] + %t24.04.i = phi i32 [ %t35.i, %while.body.t29.i ], [ %t70, %while.body.t63 ] + %t35.i = sdiv i32 %t24.04.i, 16 + %t37.i = add nuw nsw i32 %t27_i.05.i, 1 + %t33.i = icmp slt i32 %t37.i, %bitround + br i1 %t33.i, label %while.body.t29.i, label %getNumPos.exit + +getNumPos.exit: ; preds = %while.body.t29.i + %t39.i = srem i32 %t35.i, 16 + %2 = sext i32 %t39.i to i64 + %t73 = getelementptr i32, ptr %t48, i64 %2 + %t74 = load i32, ptr %t73, align 4 + br label %while.body.t29.i22 + +while.body.t29.i22: ; preds = %getNumPos.exit, %while.body.t29.i22 + %t27_i.05.i23 = phi i32 [ %t37.i26, %while.body.t29.i22 ], [ 0, %getNumPos.exit ] + %t24.04.i24 = phi i32 [ %t35.i25, %while.body.t29.i22 ], [ %t70, %getNumPos.exit ] + %t35.i25 = sdiv i32 %t24.04.i24, 16 + %t37.i26 = add nuw nsw i32 %t27_i.05.i23, 1 + %t33.i27 = icmp slt i32 %t37.i26, %bitround + br i1 %t33.i27, label %while.body.t29.i22, label %getNumPos.exit28.loopexit + +getNumPos.exit28.loopexit: ; preds = %while.body.t29.i22 + %.pre114 = srem i32 %t35.i25, 16 + %.pre115 = sext i32 %.pre114 to i64 + br label %getNumPos.exit28 + +getNumPos.exit28: ; preds = %getNumPos.exit28.loopexit, 
%getNumPos.exit.thread + %.pre-phi116 = phi i64 [ %.pre115, %getNumPos.exit28.loopexit ], [ %1, %getNumPos.exit.thread ] + %t7584.in = phi i32 [ %t74, %getNumPos.exit28.loopexit ], [ %t7482, %getNumPos.exit.thread ] + %t7584 = add i32 %t7584.in, 1 + %t81 = getelementptr i32, ptr %t48, i64 %.pre-phi116 + store i32 %t7584, ptr %t81, align 4 + %t83 = add nsw i32 %storemerge97, 1 + %t67 = icmp slt i32 %t83, %r + br i1 %t67, label %while.body.t63, label %while.exit.t64 + +while.exit.t64: ; preds = %getNumPos.exit28, %while.cond.t62.preheader + store i32 %l, ptr %t43, align 4 + %t88 = load i32, ptr %t48, align 4 + %t89 = add i32 %t88, %l + store i32 %t89, ptr %t46, align 4 + %invariant.gep = getelementptr i32, ptr %t46, i64 -1 + %t101 = getelementptr i32, ptr %t43, i64 1 + store i32 %t89, ptr %t101, align 4 + %t106 = getelementptr i32, ptr %t48, i64 1 + %t107 = load i32, ptr %t106, align 4 + %t108 = add i32 %t107, %t89 + %t110 = getelementptr i32, ptr %t46, i64 1 + store i32 %t108, ptr %t110, align 4 + %t101.1 = getelementptr i32, ptr %t43, i64 2 + store i32 %t108, ptr %t101.1, align 4 + %t106.1 = getelementptr i32, ptr %t48, i64 2 + %t107.1 = load i32, ptr %t106.1, align 4 + %t108.1 = add i32 %t107.1, %t108 + %t110.1 = getelementptr i32, ptr %t46, i64 2 + store i32 %t108.1, ptr %t110.1, align 4 + %t101.2 = getelementptr i32, ptr %t43, i64 3 + store i32 %t108.1, ptr %t101.2, align 4 + %t106.2 = getelementptr i32, ptr %t48, i64 3 + %t107.2 = load i32, ptr %t106.2, align 4 + %t108.2 = add i32 %t107.2, %t108.1 + %t110.2 = getelementptr i32, ptr %t46, i64 3 + store i32 %t108.2, ptr %t110.2, align 4 + %t101.3 = getelementptr i32, ptr %t43, i64 4 + store i32 %t108.2, ptr %t101.3, align 4 + %t106.3 = getelementptr i32, ptr %t48, i64 4 + %t107.3 = load i32, ptr %t106.3, align 4 + %t108.3 = add i32 %t107.3, %t108.2 + %t110.3 = getelementptr i32, ptr %t46, i64 4 + store i32 %t108.3, ptr %t110.3, align 4 + %t101.4 = getelementptr i32, ptr %t43, i64 5 + store i32 %t108.3, ptr 
%t101.4, align 4 + %t106.4 = getelementptr i32, ptr %t48, i64 5 + %t107.4 = load i32, ptr %t106.4, align 4 + %t108.4 = add i32 %t107.4, %t108.3 + %t110.4 = getelementptr i32, ptr %t46, i64 5 + store i32 %t108.4, ptr %t110.4, align 4 + %t101.5 = getelementptr i32, ptr %t43, i64 6 + store i32 %t108.4, ptr %t101.5, align 4 + %t106.5 = getelementptr i32, ptr %t48, i64 6 + %t107.5 = load i32, ptr %t106.5, align 4 + %t108.5 = add i32 %t107.5, %t108.4 + %t110.5 = getelementptr i32, ptr %t46, i64 6 + store i32 %t108.5, ptr %t110.5, align 4 + %t101.6 = getelementptr i32, ptr %t43, i64 7 + store i32 %t108.5, ptr %t101.6, align 4 + %t106.6 = getelementptr i32, ptr %t48, i64 7 + %t107.6 = load i32, ptr %t106.6, align 4 + %t108.6 = add i32 %t107.6, %t108.5 + %t110.6 = getelementptr i32, ptr %t46, i64 7 + store i32 %t108.6, ptr %t110.6, align 4 + %t101.7 = getelementptr i32, ptr %t43, i64 8 + store i32 %t108.6, ptr %t101.7, align 4 + %t106.7 = getelementptr i32, ptr %t48, i64 8 + %t107.7 = load i32, ptr %t106.7, align 4 + %t108.7 = add i32 %t107.7, %t108.6 + %t110.7 = getelementptr i32, ptr %t46, i64 8 + store i32 %t108.7, ptr %t110.7, align 4 + %t101.8 = getelementptr i32, ptr %t43, i64 9 + store i32 %t108.7, ptr %t101.8, align 4 + %t106.8 = getelementptr i32, ptr %t48, i64 9 + %t107.8 = load i32, ptr %t106.8, align 4 + %t108.8 = add i32 %t107.8, %t108.7 + %t110.8 = getelementptr i32, ptr %t46, i64 9 + store i32 %t108.8, ptr %t110.8, align 4 + %t101.9 = getelementptr i32, ptr %t43, i64 10 + store i32 %t108.8, ptr %t101.9, align 4 + %t106.9 = getelementptr i32, ptr %t48, i64 10 + %t107.9 = load i32, ptr %t106.9, align 4 + %t108.9 = add i32 %t107.9, %t108.8 + %t110.9 = getelementptr i32, ptr %t46, i64 10 + store i32 %t108.9, ptr %t110.9, align 4 + %t101.10 = getelementptr i32, ptr %t43, i64 11 + store i32 %t108.9, ptr %t101.10, align 4 + %t106.10 = getelementptr i32, ptr %t48, i64 11 + %t107.10 = load i32, ptr %t106.10, align 4 + %t108.10 = add i32 %t107.10, %t108.9 + %t110.10 = 
getelementptr i32, ptr %t46, i64 11 + store i32 %t108.10, ptr %t110.10, align 4 + %t101.11 = getelementptr i32, ptr %t43, i64 12 + store i32 %t108.10, ptr %t101.11, align 4 + %t106.11 = getelementptr i32, ptr %t48, i64 12 + %t107.11 = load i32, ptr %t106.11, align 4 + %t108.11 = add i32 %t107.11, %t108.10 + %t110.11 = getelementptr i32, ptr %t46, i64 12 + store i32 %t108.11, ptr %t110.11, align 4 + %t101.12 = getelementptr i32, ptr %t43, i64 13 + store i32 %t108.11, ptr %t101.12, align 4 + %t106.12 = getelementptr i32, ptr %t48, i64 13 + %t107.12 = load i32, ptr %t106.12, align 4 + %t108.12 = add i32 %t107.12, %t108.11 + %t110.12 = getelementptr i32, ptr %t46, i64 13 + store i32 %t108.12, ptr %t110.12, align 4 + %t101.13 = getelementptr i32, ptr %t43, i64 14 + store i32 %t108.12, ptr %t101.13, align 4 + %t106.13 = getelementptr i32, ptr %t48, i64 14 + %t107.13 = load i32, ptr %t106.13, align 4 + %t108.13 = add i32 %t107.13, %t108.12 + %t110.13 = getelementptr i32, ptr %t46, i64 14 + store i32 %t108.13, ptr %t110.13, align 4 + %t101.14 = getelementptr i32, ptr %t43, i64 15 + store i32 %t108.13, ptr %t101.14, align 4 + %t106.14 = getelementptr i32, ptr %t48, i64 15 + %t107.14 = load i32, ptr %t106.14, align 4 + %t108.14 = add i32 %t107.14, %t108.13 + %t110.14 = getelementptr i32, ptr %t46, i64 15 + store i32 %t108.14, ptr %t110.14, align 4 + %t333.i29 = icmp sgt i32 %bitround, 0 + br label %while.cond.t118.preheader + +while.cond.t118.preheader: ; preds = %while.exit.t64, %while.exit.t120 + %storemerge17104 = phi i32 [ 0, %while.exit.t64 ], [ %t180, %while.exit.t120 ] + %3 = zext nneg i32 %storemerge17104 to i64 + %t122 = getelementptr i32, ptr %t43, i64 %3 + %t125 = getelementptr i32, ptr %t46, i64 %3 + %t123100 = load i32, ptr %t122, align 4 + %t126101 = load i32, ptr %t125, align 4 + %t127102 = icmp slt i32 %t123100, %t126101 + br i1 %t127102, label %while.body.t119, label %while.exit.t120 + +while.body.t191.peel.next: ; preds = %while.exit.t120 + store i32 %l, 
ptr %t43, align 4 + %t187 = load i32, ptr %t48, align 4 + %t188 = add i32 %t187, %l + store i32 %t188, ptr %t46, align 4 + %t215 = add i32 %bitround, -1 + %t218.peel.pre = load i32, ptr %t43, align 4 + tail call void @radixSort(i32 %t215, ptr %a, i32 %t218.peel.pre, i32 %t188) + br label %merge.t196 + +while.body.t119: ; preds = %while.cond.t118.preheader, %while.exit.t136 + %t123103 = phi i32 [ %t176, %while.exit.t136 ], [ %t123100, %while.cond.t118.preheader ] + %4 = sext i32 %t123103 to i64 + %t132 = getelementptr i32, ptr %a, i64 %4 + %t133 = load i32, ptr %t132, align 4 + br label %while.cond.t134 + +while.exit.t120: ; preds = %while.exit.t136, %while.cond.t118.preheader + %t180 = add nuw nsw i32 %storemerge17104, 1 + %t117 = icmp ult i32 %storemerge17104, 15 + br i1 %t117, label %while.cond.t118.preheader, label %while.body.t191.peel.next + +while.cond.t134: ; preds = %getNumPos.exit78, %while.body.t119 + %t15099 = phi i32 [ %t150129, %getNumPos.exit78 ], [ %t133, %while.body.t119 ] + br i1 %t333.i29, label %while.body.t29.i32, label %getNumPos.exit38.thread + +while.body.t29.i32: ; preds = %while.cond.t134, %while.body.t29.i32 + %t27_i.05.i33 = phi i32 [ %t37.i36, %while.body.t29.i32 ], [ 0, %while.cond.t134 ] + %t24.04.i34 = phi i32 [ %t35.i35, %while.body.t29.i32 ], [ %t15099, %while.cond.t134 ] + %t35.i35 = sdiv i32 %t24.04.i34, 16 + %t37.i36 = add nuw nsw i32 %t27_i.05.i33, 1 + %t33.i37 = icmp slt i32 %t37.i36, %bitround + br i1 %t33.i37, label %while.body.t29.i32, label %getNumPos.exit38 + +getNumPos.exit38: ; preds = %while.body.t29.i32 + %t39.i31 = srem i32 %t35.i35, 16 + %t141.not = icmp eq i32 %t39.i31, %storemerge17104 + br i1 %t141.not, label %while.exit.t136, label %while.body.t29.i42 + +getNumPos.exit38.thread: ; preds = %while.cond.t134 + %t39.i3186 = srem i32 %t15099, 16 + %t141.not87 = icmp eq i32 %t39.i3186, %storemerge17104 + br i1 %t141.not87, label %while.exit.t136, label %getNumPos.exit48.thread + +getNumPos.exit48.thread: ; preds = 
%getNumPos.exit38.thread + %5 = sext i32 %t39.i3186 to i64 + %t147125 = getelementptr i32, ptr %t43, i64 %5 + %t148126 = load i32, ptr %t147125, align 4 + %6 = sext i32 %t148126 to i64 + %t149127 = getelementptr i32, ptr %a, i64 %6 + %t150128 = load i32, ptr %t149127, align 4 + br label %getNumPos.exit68.thread + +while.body.t29.i42: ; preds = %getNumPos.exit38, %while.body.t29.i42 + %t27_i.05.i43 = phi i32 [ %t37.i46, %while.body.t29.i42 ], [ 0, %getNumPos.exit38 ] + %t24.04.i44 = phi i32 [ %t35.i45, %while.body.t29.i42 ], [ %t15099, %getNumPos.exit38 ] + %t35.i45 = sdiv i32 %t24.04.i44, 16 + %t37.i46 = add nuw nsw i32 %t27_i.05.i43, 1 + %t33.i47 = icmp slt i32 %t37.i46, %bitround + br i1 %t33.i47, label %while.body.t29.i42, label %getNumPos.exit48 + +getNumPos.exit48: ; preds = %while.body.t29.i42 + %.pre117 = srem i32 %t35.i45, 16 + %7 = sext i32 %.pre117 to i64 + %t147 = getelementptr i32, ptr %t43, i64 %7 + %t148 = load i32, ptr %t147, align 4 + %8 = sext i32 %t148 to i64 + %t149 = getelementptr i32, ptr %a, i64 %8 + %t150 = load i32, ptr %t149, align 4 + br i1 %t333.i29, label %while.body.t29.i52, label %getNumPos.exit68.thread + +while.body.t29.i52: ; preds = %getNumPos.exit48, %while.body.t29.i52 + %t27_i.05.i53 = phi i32 [ %t37.i56, %while.body.t29.i52 ], [ 0, %getNumPos.exit48 ] + %t24.04.i54 = phi i32 [ %t35.i55, %while.body.t29.i52 ], [ %t15099, %getNumPos.exit48 ] + %t35.i55 = sdiv i32 %t24.04.i54, 16 + %t37.i56 = add nuw nsw i32 %t27_i.05.i53, 1 + %t33.i57 = icmp slt i32 %t37.i56, %bitround + br i1 %t33.i57, label %while.body.t29.i52, label %while.body.t29.i62.preheader + +while.body.t29.i62.preheader: ; preds = %while.body.t29.i52 + %t39.i51 = srem i32 %t35.i55, 16 + %9 = sext i32 %t39.i51 to i64 + %t155 = getelementptr i32, ptr %t43, i64 %9 + %t156 = load i32, ptr %t155, align 4 + %10 = sext i32 %t156 to i64 + %t157 = getelementptr i32, ptr %a, i64 %10 + store i32 %t15099, ptr %t157, align 4 + br label %while.body.t29.i62 + +getNumPos.exit68.thread: 
; preds = %getNumPos.exit48, %getNumPos.exit48.thread + %t150130 = phi i32 [ %t150128, %getNumPos.exit48.thread ], [ %t150, %getNumPos.exit48 ] + %t39.i51.c = srem i32 %t15099, 16 + %11 = sext i32 %t39.i51.c to i64 + %t155.c = getelementptr i32, ptr %t43, i64 %11 + %t156.c = load i32, ptr %t155.c, align 4 + %12 = sext i32 %t156.c to i64 + %t157.c = getelementptr i32, ptr %a, i64 %12 + store i32 %t15099, ptr %t157.c, align 4 + %t16191 = getelementptr i32, ptr %t43, i64 %11 + %t16292 = load i32, ptr %t16191, align 4 + br label %getNumPos.exit78 + +while.body.t29.i62: ; preds = %while.body.t29.i62.preheader, %while.body.t29.i62 + %t27_i.05.i63 = phi i32 [ %t37.i66, %while.body.t29.i62 ], [ 0, %while.body.t29.i62.preheader ] + %t24.04.i64 = phi i32 [ %t35.i65, %while.body.t29.i62 ], [ %t15099, %while.body.t29.i62.preheader ] + %t35.i65 = sdiv i32 %t24.04.i64, 16 + %t37.i66 = add nuw nsw i32 %t27_i.05.i63, 1 + %t33.i67 = icmp slt i32 %t37.i66, %bitround + br i1 %t33.i67, label %while.body.t29.i62, label %getNumPos.exit68 + +getNumPos.exit68: ; preds = %while.body.t29.i62 + %t39.i61 = srem i32 %t35.i65, 16 + %13 = sext i32 %t39.i61 to i64 + %t161 = getelementptr i32, ptr %t43, i64 %13 + %t162 = load i32, ptr %t161, align 4 + br label %while.body.t29.i72 + +while.body.t29.i72: ; preds = %getNumPos.exit68, %while.body.t29.i72 + %t27_i.05.i73 = phi i32 [ %t37.i76, %while.body.t29.i72 ], [ 0, %getNumPos.exit68 ] + %t24.04.i74 = phi i32 [ %t35.i75, %while.body.t29.i72 ], [ %t15099, %getNumPos.exit68 ] + %t35.i75 = sdiv i32 %t24.04.i74, 16 + %t37.i76 = add nuw nsw i32 %t27_i.05.i73, 1 + %t33.i77 = icmp slt i32 %t37.i76, %bitround + br i1 %t33.i77, label %while.body.t29.i72, label %getNumPos.exit78.loopexit + +getNumPos.exit78.loopexit: ; preds = %while.body.t29.i72 + %.pre118 = srem i32 %t35.i75, 16 + %.pre119 = sext i32 %.pre118 to i64 + br label %getNumPos.exit78 + +getNumPos.exit78: ; preds = %getNumPos.exit78.loopexit, %getNumPos.exit68.thread + %t150129 = phi i32 [ %t150, 
%getNumPos.exit78.loopexit ], [ %t150130, %getNumPos.exit68.thread ] + %.pre-phi120 = phi i64 [ %.pre119, %getNumPos.exit78.loopexit ], [ %11, %getNumPos.exit68.thread ] + %t16394.in = phi i32 [ %t162, %getNumPos.exit78.loopexit ], [ %t16292, %getNumPos.exit68.thread ] + %t16394 = add i32 %t16394.in, 1 + %t167 = getelementptr i32, ptr %t43, i64 %.pre-phi120 + store i32 %t16394, ptr %t167, align 4 + br label %while.cond.t134 + +while.exit.t136: ; preds = %getNumPos.exit38.thread, %getNumPos.exit38 + %t171 = load i32, ptr %t122, align 4 + %14 = sext i32 %t171 to i64 + %t172 = getelementptr i32, ptr %a, i64 %14 + store i32 %t15099, ptr %t172, align 4 + %t175 = load i32, ptr %t122, align 4 + %t176 = add i32 %t175, 1 + store i32 %t176, ptr %t122, align 4 + %t126 = load i32, ptr %t125, align 4 + %t127 = icmp slt i32 %t176, %t126 + br i1 %t127, label %while.body.t119, label %while.exit.t120 + +merge.t196: ; preds = %while.body.t191.peel.next, %merge.t196 + %storemerge18107 = phi i32 [ 1, %while.body.t191.peel.next ], [ %t224, %merge.t196 ] + %15 = zext nneg i32 %storemerge18107 to i64 + %gep106 = getelementptr i32, ptr %invariant.gep, i64 %15 + %t202 = load i32, ptr %gep106, align 4 + %t204 = getelementptr i32, ptr %t43, i64 %15 + store i32 %t202, ptr %t204, align 4 + %t209 = getelementptr i32, ptr %t48, i64 %15 + %t210 = load i32, ptr %t209, align 4 + %t211 = add i32 %t210, %t202 + %t213 = getelementptr i32, ptr %t46, i64 %15 + store i32 %t211, ptr %t213, align 4 + %t218.pre = load i32, ptr %t204, align 4 + tail call void @radixSort(i32 %t215, ptr %a, i32 %t218.pre, i32 %t211) + %t224 = add nuw nsw i32 %storemerge18107, 1 + %t194 = icmp ult i32 %storemerge18107, 15 + br i1 %t194, label %merge.t196, label %cleanup.t44, !llvm.loop !0 +} + +define noundef i32 @main() local_unnamed_addr { +entry: + %t231 = tail call i32 @getarray(ptr nonnull @a) + tail call void @starttime() + tail call void @radixSort(i32 8, ptr nonnull @a, i32 0, i32 %t231) + %ans.promoted = load i32, ptr 
@ans, align 4 + %t2427 = icmp sgt i32 %t231, 0 + br i1 %t2427, label %while.body.t238, label %while.exit.t239 + +while.body.t238: ; preds = %entry, %while.body.t238 + %t236_i.09 = phi i32 [ %t254, %while.body.t238 ], [ 0, %entry ] + %t25268 = phi i32 [ %t252, %while.body.t238 ], [ %ans.promoted, %entry ] + %0 = zext nneg i32 %t236_i.09 to i64 + %t246 = getelementptr [30000010 x i32], ptr @a, i64 0, i64 %0 + %t247 = load i32, ptr %t246, align 4 + %t249 = add nuw i32 %t236_i.09, 2 + %t250 = srem i32 %t247, %t249 + %t251 = mul i32 %t250, %t236_i.09 + %t252 = add i32 %t251, %t25268 + %t254 = add nuw nsw i32 %t236_i.09, 1 + %t242 = icmp slt i32 %t254, %t231 + br i1 %t242, label %while.body.t238, label %while.cond.t237.while.exit.t239_crit_edge + +while.cond.t237.while.exit.t239_crit_edge: ; preds = %while.body.t238 + store i32 %t252, ptr @ans, align 4 + br label %while.exit.t239 + +while.exit.t239: ; preds = %while.cond.t237.while.exit.t239_crit_edge, %entry + %t257 = phi i32 [ %t252, %while.cond.t237.while.exit.t239_crit_edge ], [ %ans.promoted, %entry ] + %t258 = icmp slt i32 %t257, 0 + br i1 %t258, label %then.t255, label %merge.t256 + +then.t255: ; preds = %while.exit.t239 + %t260 = sub i32 0, %t257 + store i32 %t260, ptr @ans, align 4 + br label %merge.t256 + +merge.t256: ; preds = %then.t255, %while.exit.t239 + tail call void @stoptime() + %t262 = load i32, ptr @ans, align 4 + tail call void @putint(i32 %t262) + tail call void @putch(i32 10) + ret i32 0 +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.smax.i32(i32, i32) #2 + +attributes #0 = { nofree norecurse nosync nounwind memory(argmem: read) } +attributes #1 = { nofree norecurse nosync nounwind memory(none) } +attributes #2 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } + +!0 = distinct !{!0, !1} +!1 = !{!"llvm.loop.peeled.count", i32 1} diff --git a/run.sh b/scripts/run.sh similarity index 85% rename from run.sh 
rename to scripts/run.sh index cbd493f..e016839 100755 --- a/run.sh +++ b/scripts/run.sh @@ -35,7 +35,7 @@ total=0 passed=0 failed=0 -echo "开始测试 SysY 解析..." +echo "开始测试 ir out 解析..." echo "输出将保存到 $RESULT_FILE" echo "------------------------" @@ -59,10 +59,12 @@ for file in "${TEST_FILES[@]}"; do echo "========== $file ==========" >> "$RESULT_FILE" if [ $VERBOSE -eq 1 ]; then - "$COMPILER" --emit-parse-tree "$file" 2>&1 | tee -a "$RESULT_FILE" + # "$COMPILER" --emit-parse-tree "$file" 2>&1 | tee -a "$RESULT_FILE" + "$COMPILER" --emit-ir "$file" 2>&1 | tee -a "$RESULT_FILE" result=${PIPESTATUS[0]} else - "$COMPILER" --emit-parse-tree "$file" >> "$RESULT_FILE" 2>&1 + # "$COMPILER" --emit-parse-tree "$file" >> "$RESULT_FILE" 2>&1 + "$COMPILER" --emit-ir "$file" >> "$RESULT_FILE" 2>&1 result=$? fi diff --git a/scripts/test_compiler.sh b/scripts/test_compiler.sh new file mode 100755 index 0000000..ac98f68 --- /dev/null +++ b/scripts/test_compiler.sh @@ -0,0 +1,233 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +COMPILER="$ROOT_DIR/build/bin/compiler" +TMP_DIR="$ROOT_DIR/build/test_compiler" +RESULT_BASE_DIR="$ROOT_DIR/test/test_result" +TEST_DIRS=("$ROOT_DIR/test/test_case/functional" "$ROOT_DIR/test/test_case/performance") +CC_BIN="${CC:-cc}" +RUNTIME_SRC="$ROOT_DIR/sylib/sylib.c" +RUNTIME_OBJ="$TMP_DIR/sylib.o" +LLC_BIN="${LLC:-llc}" +CLANG_BIN="${CLANG:-clang}" + +if [[ ! -x "$COMPILER" ]]; then + echo "未找到编译器: $COMPILER" + echo "请先构建编译器,例如: mkdir -p build && cd build && cmake .. && make -j" + exit 1 +fi + +mkdir -p "$TMP_DIR" + +if ! command -v "$LLC_BIN" >/dev/null 2>&1; then + echo "未找到 llc: $LLC_BIN" + echo "请安装 LLVM,或通过 LLC 环境变量指定 llc 路径" + exit 1 +fi + +if ! 
command -v "$CLANG_BIN" >/dev/null 2>&1; then + echo "未找到 clang: $CLANG_BIN" + echo "请安装 Clang,或通过 CLANG 环境变量指定 clang 路径" + exit 1 +fi + +# 编译运行库(供链接生成的可执行文件) +runtime_ready=0 +if [[ -f "$RUNTIME_SRC" ]]; then + if "$CC_BIN" -c "$RUNTIME_SRC" -o "$RUNTIME_OBJ" >/dev/null 2>&1; then + runtime_ready=1 + else + echo "[WARN] 运行库编译失败,生成的可执行文件将不链接 sylib: $RUNTIME_SRC" + fi +else + echo "[WARN] 未找到运行库源码: $RUNTIME_SRC" +fi + +ir_total=0 +ir_pass=0 +result_total=0 +result_pass=0 + +ir_failures=() +result_failures=() + +function normalize_file() { + sed 's/\r$//' "$1" +} + +for test_dir in "${TEST_DIRS[@]}"; do + if [[ ! -d "$test_dir" ]]; then + echo "跳过不存在的测试目录: $test_dir" + continue + fi + + shopt -s nullglob + for input in "$test_dir"/*.sy; do + ir_total=$((ir_total+1)) + base=$(basename "$input") + stem=${base%.sy} + case "$(basename "$test_dir")" in + functional) + out_dir="$RESULT_BASE_DIR/function/ir" + ;; + performance) + out_dir="$RESULT_BASE_DIR/performance/ir" + ;; + *) + out_dir="$RESULT_BASE_DIR/$(basename "$test_dir")" + ;; + esac + mkdir -p "$out_dir" + ll_file="$out_dir/$stem.ll" + stdout_file="$out_dir/$stem.stdout" + expected_file="$test_dir/$stem.out" + stdin_file="$test_dir/$stem.in" + + echo "[TEST] $input" + + # 编译并捕获所有输出 + compiler_status=0 + compiler_output="" + compiler_output=$("$COMPILER" --emit-ir "$input" 2>&1) || compiler_status=$? + + # 临时文件存储原始输出 + raw_ll="$out_dir/$stem.raw.ll" + printf '%s\n' "$compiler_output" > "$raw_ll" + + # 检查编译是否成功 + if [[ $compiler_status -ne 0 ]]; then + echo " [IR] 编译失败: 返回码 $compiler_status" + ir_failures+=("$input: compiler failed ($compiler_status)") + # 失败:保留原始输出(包含所有调试信息) + cp "$raw_ll" "$ll_file" + rm -f "$raw_ll" + continue + fi + + # 从混杂输出中提取 IR: + # - 顶层实体:define/declare/@global + # - 基本块标签 + # - 缩进的指令行 + # - 函数结束花括号 + grep -E '^(define |declare |@|[[:space:]]|})|^[A-Za-z_.$%][A-Za-z0-9_.$%]*:$' "$raw_ll" > "$ll_file" + + # 检查是否生成了有效函数定义 + if ! 
grep -qE '^define ' "$ll_file"; then + echo " [IR] 失败: 未生成有效函数定义" + ir_failures+=("$input: invalid IR output") + # 失败:保留原始输出 + cp "$raw_ll" "$ll_file" + rm -f "$raw_ll" + continue + fi + + # 可选:删除多余的空行 + sed -i '/^$/N;/\n$/D' "$ll_file" + + rm -f "$raw_ll" + + ir_pass=$((ir_pass+1)) + echo " [IR] 生成成功 (IR已保存到: $ll_file)" + + # 运行测试 + # 运行测试部分 + if [[ -f "$expected_file" ]]; then + result_total=$((result_total+1)) + + # 运行生成的可执行文件(优先链接运行库) + run_status=0 + obj_file="$out_dir/$stem.o" + exe_file="$out_dir/$stem" + if ! "$LLC_BIN" -filetype=obj "$ll_file" -o "$obj_file" > "$stdout_file" 2>&1; then + echo " [RUN] llc 失败" + result_failures+=("$input: llc failed") + continue + fi + + if [[ $runtime_ready -eq 1 ]]; then + if ! "$CLANG_BIN" "$obj_file" "$RUNTIME_OBJ" -o "$exe_file" >> "$stdout_file" 2>&1; then + echo " [RUN] clang 链接失败" + result_failures+=("$input: clang link failed") + continue + fi + else + if ! "$CLANG_BIN" "$obj_file" -o "$exe_file" >> "$stdout_file" 2>&1; then + echo " [RUN] clang 链接失败" + result_failures+=("$input: clang link failed") + continue + fi + fi + + if [[ -f "$stdin_file" ]]; then + "$exe_file" < "$stdin_file" > "$stdout_file" 2>&1 || run_status=$? + else + "$exe_file" > "$stdout_file" 2>&1 || run_status=$? 
+ fi + + # 读取预期文件内容 + expected_content=$(normalize_file "$expected_file") + + # 判断预期文件是只包含退出码,还是包含输出+退出码 + if [[ "$expected_content" =~ ^[0-9]+$ ]]; then + # 只包含退出码 + expected=$expected_content + if [[ "$run_status" -eq "$expected" ]]; then + result_pass=$((result_pass+1)) + echo " [RUN] 返回值匹配: $run_status" + rm -f "$stdout_file" + else + echo " [RUN] 返回值不匹配: got $run_status, expected $expected" + result_failures+=("$input: exit code mismatch (got $run_status, expected $expected)") + fi + else + # 包含输出和退出码(最后一行是退出码) + expected_output=$(head -n -1 <<< "$expected_content") + expected_exit=$(tail -n 1 <<< "$expected_content") + actual_output=$(cat "$stdout_file") + + if [[ "$run_status" -eq "$expected_exit" ]] && [[ "$actual_output" == "$expected_output" ]]; then + result_pass=$((result_pass+1)) + echo " [RUN] 成功: 退出码和输出都匹配" + rm -f "$stdout_file" + else + echo " [RUN] 不匹配: 退出码 got $run_status, expected $expected_exit" + if [[ "$actual_output" != "$expected_output" ]]; then + echo " 输出不匹配" + fi + result_failures+=("$input: mismatch") + fi + fi + else + echo " [RUN] 未找到预期返回值文件 $expected_file,跳过结果验证" + fi + done + shopt -u nullglob +done + +# 输出统计 +cat < #include +#include namespace ir { @@ -15,10 +16,112 @@ ConstantInt* Context::GetConstInt(int v) { return inserted->second.get(); } +ConstantFloat* Context::GetConstFloat(float v) { + uint32_t key; + std::memcpy(&key, &v, sizeof(float)); + + auto it = const_floats_.find(key); + if (it != const_floats_.end()) { + return it->second.get(); + } + + auto float_ty = Type::GetFloatType(); + auto constant = std::make_unique(float_ty, v); + auto* ptr = constant.get(); + const_floats_[key] = std::move(constant); + return ptr; +} + +ConstantArray* Context::GetConstArray(std::shared_ptr ty, + std::vector elements) { + // 验证数组常量 + size_t expected_size = ty->GetElementCount(); + if (elements.size() != expected_size) { + // 如果元素数量不匹配,可能需要补零或报错 + // 这里根据需求处理 + if (elements.size() < expected_size) { + // 补零 + auto elem_type = 
ty->GetElementType(); + while (elements.size() < expected_size) { + if (elem_type->IsInt32()) { + elements.push_back(GetConstInt(0)); + } else if (elem_type->IsFloat()) { + elements.push_back(GetConstFloat(0.0f)); + } + } + } else { + throw std::runtime_error("Array constant size mismatch"); + } + } + + // 构建缓存键 + struct ArrayKey { + std::shared_ptr type; + std::vector elements; + + bool operator==(const ArrayKey& other) const { + if (type != other.type) return false; + if (elements.size() != other.elements.size()) return false; + for (size_t i = 0; i < elements.size(); ++i) { + if (elements[i] != other.elements[i]) return false; + } + return true; + } + }; + + struct ArrayKeyHash { + size_t operator()(const ArrayKey& key) const { + size_t hash = std::hash{}(key.type.get()); + for (auto* elem : key.elements) { + hash ^= std::hash{}(elem) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + } + return hash; + } + }; + + // 使用静态缓存(需要作为成员变量) + static std::unordered_map, ArrayKeyHash> cache; + + ArrayKey key{ty, elements}; + auto it = cache.find(key); + if (it != cache.end()) { + return it->second.get(); + } + + auto constant = std::make_unique(ty, std::move(elements)); + auto* ptr = constant.get(); + cache[std::move(key)] = std::move(constant); + return ptr; +} + +ConstantZero* Context::GetZeroConstant(std::shared_ptr ty) { + auto it = zero_constants_.find(ty.get()); + if (it != zero_constants_.end()) { + return it->second.get(); + } + + auto constant = std::make_unique(ty); + auto* ptr = constant.get(); + zero_constants_[ty.get()] = std::move(constant); + return ptr; +} + +ConstantAggregateZero* Context::GetAggregateZero(std::shared_ptr ty) { + auto it = aggregate_zeros_.find(ty.get()); + if (it != aggregate_zeros_.end()) { + return it->second.get(); + } + + auto constant = std::make_unique(ty); + auto* ptr = constant.get(); + aggregate_zeros_[ty.get()] = std::move(constant); + return ptr; +} + std::string Context::NextTemp() { std::ostringstream oss; - oss << "%" << 
++temp_index_; + oss << "%t" << ++temp_index_; return oss.str(); } -} // namespace ir +} // namespace ir \ No newline at end of file diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index cf14d48..3652f77 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -5,10 +5,21 @@ namespace ir { -Function::Function(std::string name, std::shared_ptr ret_type) - : Value(std::move(ret_type), std::move(name)) { +Function::Function(std::string name, std::shared_ptr func_type) + : Value(std::move(func_type), std::move(name)) { entry_ = CreateBlock("entry"); } +// 向函数添加参数的实现。 +Argument* Function::AddArgument(std::unique_ptr arg) {// 独占所有权,自动释放内存 + if (!arg) return nullptr; // 1. 检查参数是否为空 + auto* ptr = arg.get(); // 2. 获取原始指针(用于返回) + arguments_.push_back(std::move(arg)); // 3. 将参数所有权转移到函数的参数列表中,arg已经是空指针了,不能再使用arg了 + return ptr; // 4. 返回参数指针,方便调用者使用 +} +// 获取函数参数列表的实现。 +const std::vector>& Function::GetArguments() const { + return arguments_; +} BasicBlock* Function::CreateBlock(const std::string& name) { auto block = std::make_unique(name); diff --git a/src/ir/GlobalValue.cpp b/src/ir/GlobalValue.cpp index 7c2abe1..24686b1 100644 --- a/src/ir/GlobalValue.cpp +++ b/src/ir/GlobalValue.cpp @@ -1,11 +1,195 @@ -// GlobalValue 占位实现: -// - 具体的全局初始化器、打印和链接语义需要自行补全 - +// ir/GlobalValue.cpp #include "ir/IR.h" +#include + namespace ir { +namespace { + +ConstantValue* GetScalarZeroConstant(const Type& type) { + if (type.IsInt32()) { + static ConstantInt* zero_i32 = new ConstantInt(Type::GetInt32Type(), 0); + return zero_i32; + } + if (type.IsFloat()) { + static ConstantFloat* zero_f32 = new ConstantFloat(Type::GetFloatType(), 0.0f); + return zero_f32; + } + if (type.IsInt1()) { + static ConstantInt* zero_i1 = new ConstantInt(Type::GetInt1Type(), 0); + return zero_i1; + } + return nullptr; +} + +} // namespace + GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) : User(std::move(ty), std::move(name)) {} -} // namespace ir +void 
GlobalValue::SetInitializer(ConstantValue* init) { + if (!init) { + throw std::runtime_error("GlobalValue::SetInitializer: init is null"); + } + + // 获取实际的值类型(用于类型检查) + std::shared_ptr value_type = GetValueType(); + + // 类型检查 + bool type_match = CheckTypeCompatibility(value_type, init); + + if (!type_match) { + throw std::runtime_error("GlobalValue::SetInitializer: type mismatch"); + } + + initializer_.clear(); + initializer_.push_back(init); +} + +void GlobalValue::SetInitializer(const std::vector& init) { + if (init.empty()) { + initializer_.clear(); + return; + } + + // 获取实际的值类型 + std::shared_ptr value_type = GetValueType(); + + // 类型检查 + if (value_type->IsArray()) { + auto* array_ty = static_cast(value_type.get()); + size_t array_size = array_ty->GetElementCount(); + + if (init.size() > array_size) { + throw std::runtime_error("GlobalValue::SetInitializer: too many initializers"); + } + + // 检查每个初始化值的类型 + auto* elem_type = array_ty->GetElementType().get(); + for (size_t i = 0; i < init.size(); ++i) { + auto* elem = init[i]; + if (!elem) { + throw std::runtime_error("GlobalValue::SetInitializer: null initializer at index " + std::to_string(i)); + } + + bool elem_match = false; + if (elem_type->IsInt32() && elem->GetType()->IsInt32()) { + elem_match = true; + } else if (elem_type->IsFloat() && elem->GetType()->IsFloat()) { + elem_match = true; + } else if (elem_type->IsInt1() && elem->GetType()->IsInt1()) { + elem_match = true; + } + + if (!elem_match) { + throw std::runtime_error("GlobalValue::SetInitializer: element type mismatch at index " + std::to_string(i)); + } + } + } + else if (value_type->IsInt32() || value_type->IsFloat() || value_type->IsInt1()) { + if (init.size() != 1) { + throw std::runtime_error("GlobalValue::SetInitializer: scalar requires exactly one initializer"); + } + + if (!init[0]) { + throw std::runtime_error("GlobalValue::SetInitializer: null initializer"); + } + + if ((value_type->IsInt32() && !init[0]->GetType()->IsInt32()) || + 
(value_type->IsFloat() && !init[0]->GetType()->IsFloat()) || + (value_type->IsInt1() && !init[0]->GetType()->IsInt1())) { + throw std::runtime_error("GlobalValue::SetInitializer: type mismatch"); + } + } + else { + throw std::runtime_error("GlobalValue::SetInitializer: unsupported type"); + } + + initializer_ = init; +} + +// 辅助方法:获取实际的值类型(处理指针包装) +std::shared_ptr GlobalValue::GetValueType() const { + if (GetType()->IsPtrInt32()) { + return Type::GetInt32Type(); + } else if (GetType()->IsPtrFloat()) { + return Type::GetFloatType(); + } else if (GetType()->IsPtrInt1()) { + return Type::GetInt1Type(); + } + return GetType(); +} + +// 辅助方法:检查类型兼容性 +bool GlobalValue::CheckTypeCompatibility(std::shared_ptr value_type, + ConstantValue* init) const { + // 检查标量类型 + if (value_type->IsInt32() && init->GetType()->IsInt32()) { + return true; + } else if (value_type->IsFloat() && init->GetType()->IsFloat()) { + return true; + } else if (value_type->IsInt1() && init->GetType()->IsInt1()) { + return true; + } + // 检查数组类型:允许用单个标量初始化整个数组 + else if (value_type->IsArray()) { + auto* array_ty = static_cast(value_type.get()); + auto* elem_type = array_ty->GetElementType().get(); + + if (elem_type->IsInt32() && init->GetType()->IsInt32()) { + return true; + } else if (elem_type->IsFloat() && init->GetType()->IsFloat()) { + return true; + } else if (elem_type->IsInt1() && init->GetType()->IsInt1()) { + return true; + } + // 也可以允许 ConstantArray 作为初始化器 + else if (init->GetType()->IsArray()) { + auto* init_array = static_cast(init); + return init_array->IsValid(); + } + } + // 检查指针类型(用于数组参数) + else if (value_type->IsPtrInt32() && init->GetType()->IsInt32()) { + return true; + } else if (value_type->IsPtrFloat() && init->GetType()->IsFloat()) { + return true; + } + + return false; +} + +// 添加获取数组元素的便捷方法 +ConstantValue* GlobalValue::GetArrayElement(size_t index) const { + if (!GetType()->IsArray()) { + return nullptr; + } + + auto* array_ty = dynamic_cast(GetType().get()); + if (!array_ty) { 
+ return nullptr; + } + if (index >= static_cast(array_ty->GetElementCount())) { + return nullptr; + } + if (index >= initializer_.size()) { + return GetScalarZeroConstant(*array_ty->GetElementType()); + } + return initializer_[index]; +} + +// 添加获取数组元素数量的方法 +size_t GlobalValue::GetArraySize() const { + if (!IsArrayConstant()) { + return 0; + } + return initializer_.size(); +} + +// 添加判断是否为数组常量的方法 +bool GlobalValue::IsArrayConstant() const { + return GetType()->IsArray() && !initializer_.empty(); +} + +} // namespace ir \ No newline at end of file diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 90f03c4..9b7c545 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -5,7 +5,7 @@ #include "ir/IR.h" #include - +#include #include "utils/Log.h" namespace ir { @@ -21,6 +21,21 @@ ConstantInt* IRBuilder::CreateConstInt(int v) { return ctx_.GetConstInt(v); } +// IRBuilder 方法实现 +ConstantFloat* IRBuilder::CreateConstFloat(float v) { + return ctx_.GetConstFloat(v); +} + +ConstantArray* IRBuilder::CreateConstArray(std::shared_ptr ty, + std::vector elements) { + return ctx_.GetConstArray(ty, std::move(elements)); +} + +ConstantZero* IRBuilder::CreateZeroConstant(std::shared_ptr ty) { + return ctx_.GetZeroConstant(ty); +} + + BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name) { if (!insert_block_) { @@ -34,7 +49,69 @@ BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs, throw std::runtime_error( FormatError("ir", "IRBuilder::CreateBinary 缺少 rhs")); } - return insert_block_->Append(op, lhs->GetType(), lhs, rhs, name); + + // 检查操作码是否为有效的二元操作符 + switch (op) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: + case Opcode::And: + case Opcode::Or: + // 添加浮点操作码 + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + // 有效的二元操作符 + break; + case Opcode::Not: + // Not是一元操作符,不应该在BinaryInst中 + throw 
std::runtime_error(FormatError("ir", "Not是一元操作符,应使用其他指令")); + default: + throw std::runtime_error(FormatError("ir", "BinaryInst 不支持的操作码")); + } + + // 确定结果类型 + std::shared_ptr result_type; + + // 检查操作数类型是否相同 + if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind()) { + throw std::runtime_error( + FormatError("ir", "CreateBinary 操作数类型不匹配")); + } + + // 检查是否为浮点操作 + bool is_float_op = (op == Opcode::FAdd || op == Opcode::FSub || + op == Opcode::FMul || op == Opcode::FDiv); + + if (is_float_op) { + // 浮点操作要求操作数是浮点类型 + if (!lhs->GetType()->IsFloat()) { + throw std::runtime_error( + FormatError("ir", "浮点运算要求操作数为浮点类型")); + } + result_type = lhs->GetType(); + } else { + bool is_logical = (op == Opcode::And || op == Opcode::Or); + + if (is_logical) { + // 逻辑运算的结果是 int32(布尔值) + result_type = Type::GetInt32Type(); + } else { + // 算术运算的结果类型与操作数相同 + result_type = lhs->GetType(); + } + + // 检查操作数类型是否支持 + if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsFloat()) { + throw std::runtime_error( + FormatError("ir", "CreateBinary 只支持 int32 和 float 类型")); + } + } + + return insert_block_->Append(op, result_type, lhs, rhs, name); } BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, @@ -42,6 +119,17 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, return CreateBinary(Opcode::Add, lhs, rhs, name); } +AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!ty) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateAlloca 缺少类型")); + } + return insert_block_->Append(ty, name); +} + AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); @@ -49,6 +137,13 @@ AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { return insert_block_->Append(Type::GetPtrInt32Type(), name); } +AllocaInst* IRBuilder::CreateAllocaFloat(const 
std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Type::GetPtrFloatType(), name); +} + LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); @@ -57,9 +152,32 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { throw std::runtime_error( FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); } - return insert_block_->Append(Type::GetInt32Type(), ptr, name); + auto ptr_ty = ptr->GetType(); + std::shared_ptr elem_ty; + + if (ptr_ty->IsPtrInt32()) { + elem_ty = Type::GetInt32Type(); + } else if (ptr_ty->IsPtrFloat()) { + elem_ty = Type::GetFloatType(); + } else if (ptr_ty->IsPtrInt1()) { + elem_ty = Type::GetInt1Type(); + } else if (ptr_ty->IsArray()) { + // 数组类型的指针,元素类型是数组元素类型 + auto* array_ty = dynamic_cast(ptr_ty.get()); + if (array_ty) { + elem_ty = array_ty->GetElementType(); + } else { + throw std::runtime_error(FormatError("ir", "不支持的指针类型")); + } + } else { + // 尝试其他指针类型 + throw std::runtime_error(FormatError("ir", "不支持的指针类型")); + } + + return insert_block_->Append(elem_ty, ptr, name); } + StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); @@ -72,6 +190,35 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { throw std::runtime_error( FormatError("ir", "IRBuilder::CreateStore 缺少 ptr")); } + + // 检查类型兼容性 + auto ptr_ty = ptr->GetType(); + auto val_ty = val->GetType(); + + if (ptr_ty->IsPtrInt32()) { + if (!val_ty->IsInt32()) { + throw std::runtime_error(FormatError("ir", "存储类型不匹配:期望 int32")); + } + } else if (ptr_ty->IsPtrFloat()) { + if (!val_ty->IsFloat()) { + throw std::runtime_error( + FormatError("ir", "存储类型不匹配:期望 float, 实际 kind=" + + std::to_string(static_cast(val_ty->GetKind())))); + } + } else if (ptr_ty->IsArray()) { + // 数组存储支持两种形式: + 
// 1. 标量元素写入(通常配合 GEP 后落到元素指针,不会走到这里) + // 2. 聚合数组整体写入,例如 `store [16 x i32] zeroinitializer, [16 x i32]* %arr` + if (!val_ty->IsArray()) { + throw std::runtime_error( + FormatError("ir", "数组地址仅支持聚合数组整体存储")); + } + if (val_ty->GetKind() != ptr_ty->GetKind()) { + throw std::runtime_error( + FormatError("ir", "聚合数组存储类型不匹配")); + } + } + return insert_block_->Append(Type::GetVoidType(), val, ptr); } @@ -79,11 +226,445 @@ ReturnInst* IRBuilder::CreateRet(Value* v) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - if (!v) { + return insert_block_->Append(Type::GetVoidType(), v); +} + +BranchInst* IRBuilder::CreateBr(BasicBlock* target) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!target) { throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateRet 缺少返回值")); + FormatError("ir", "IRBuilder::CreateBr 缺少 target")); } - return insert_block_->Append(Type::GetVoidType(), v); + return insert_block_->Append(Type::GetVoidType(), target); +} + +BranchInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* true_target, + BasicBlock* false_target) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!cond) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateCondBr 缺少 cond")); + } + if (!true_target) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateCondBr 缺少 true_target")); + } + if (!false_target) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateCondBr 缺少 false_target")); + } + return insert_block_->Append(Type::GetVoidType(), cond, true_target, false_target); +} + +// 创建整数相等比较 +IcmpInst* IRBuilder::CreateICmpEQ(Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs || !rhs) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateICmpEQ 缺少操作数")); + } + // 
检查类型必须一致 + if (lhs->GetType() != rhs->GetType()) { + throw std::runtime_error( + FormatError("ir", "比较操作数类型不匹配")); + } + return insert_block_->Append(IcmpInst::Predicate::EQ, lhs, rhs, + Type::GetInt1Type(), name); +} + +// 创建整数不等比较 +IcmpInst* IRBuilder::CreateICmpNE(Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs || !rhs) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateICmpNE 缺少操作数")); + } + if (lhs->GetType() != rhs->GetType()) { + throw std::runtime_error( + FormatError("ir", "比较操作数类型不匹配")); + } + return insert_block_->Append(IcmpInst::Predicate::NE, lhs, rhs, + Type::GetInt1Type(), name); +} + +// 创建整数小于比较 +IcmpInst* IRBuilder::CreateICmpLT(Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs || !rhs) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateICmpLT 缺少操作数")); + } + if (lhs->GetType() != rhs->GetType()) { + throw std::runtime_error( + FormatError("ir", "比较操作数类型不匹配")); + } + return insert_block_->Append(IcmpInst::Predicate::LT, lhs, rhs, + Type::GetInt1Type(), name); +} + +// 创建整数小于等于比较 +IcmpInst* IRBuilder::CreateICmpLE(Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs || !rhs) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateICmpLE 缺少操作数")); + } + if (lhs->GetType() != rhs->GetType()) { + throw std::runtime_error( + FormatError("ir", "比较操作数类型不匹配")); + } + return insert_block_->Append(IcmpInst::Predicate::LE, lhs, rhs, + Type::GetInt1Type(), name); +} + +// 创建整数大于比较 +IcmpInst* IRBuilder::CreateICmpGT(Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs || !rhs) { + throw 
std::runtime_error( + FormatError("ir", "IRBuilder::CreateICmpGT 缺少操作数")); + } + if (lhs->GetType() != rhs->GetType()) { + throw std::runtime_error( + FormatError("ir", "比较操作数类型不匹配")); + } + return insert_block_->Append(IcmpInst::Predicate::GT, lhs, rhs, + Type::GetInt1Type(), name); +} + +// 创建整数大于等于比较 +IcmpInst* IRBuilder::CreateICmpGE(Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs || !rhs) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateICmpGE 缺少操作数")); + } + if (lhs->GetType() != rhs->GetType()) { + throw std::runtime_error( + FormatError("ir", "比较操作数类型不匹配")); + } + return insert_block_->Append(IcmpInst::Predicate::GE, lhs, rhs, + Type::GetInt1Type(), name); +} + +// 创建零扩展指令 +ZExtInst* IRBuilder::CreateZExt(Value* value, std::shared_ptr target_ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!value) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateZExt 缺少 value")); + } + if (!target_ty) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateZExt 缺少 target_ty")); + } + + auto src_ty = value->GetType(); + // 类型检查:源类型应该是较小的整数类型 + if (!src_ty->IsInt1() && !src_ty->IsInt32()) { + throw std::runtime_error( + FormatError("ir", "ZExt 源类型必须是整数类型")); + } + // 目标类型应该是较大的整数类型 + if (!target_ty->IsInt32()) { + throw std::runtime_error( + FormatError("ir", "ZExt 目标类型必须是整数类型")); + } + + const std::string inst_name = name.empty() ? 
ctx_.NextTemp() : name; + return insert_block_->Append(value, target_ty, inst_name); +} + +// 创建截断指令 +TruncInst* IRBuilder::CreateTrunc(Value* value, std::shared_ptr target_ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!value) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateTrunc 缺少 value")); + } + if (!target_ty) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateTrunc 缺少 target_ty")); + } + + auto src_ty = value->GetType(); + // 类型检查:源类型应该是较大的整数类型 + if (!src_ty->IsInt32()) { + throw std::runtime_error( + FormatError("ir", "Trunc 源类型必须是整数类型")); + } + // 目标类型应该是较小的整数类型 + if (!target_ty->IsInt1() && !target_ty->IsInt32()) { + throw std::runtime_error( + FormatError("ir", "Trunc 目标类型必须是整数类型")); + } + + const std::string inst_name = name.empty() ? ctx_.NextTemp() : name; + return insert_block_->Append(value, target_ty, inst_name); +} + +// 便捷方法:i1 转 i32 +ZExtInst* IRBuilder::CreateZExtI1ToI32(Value* value, const std::string& name) { + return CreateZExt(value, Type::GetInt32Type(), name); +} + +// 便捷方法:i32 转 i1 +TruncInst* IRBuilder::CreateTruncI32ToI1(Value* value, const std::string& name) { + return CreateTrunc(value, Type::GetInt1Type(), name); +} + +BinaryInst* IRBuilder::CreateDiv(Value* lhs, Value* rhs, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateDiv 缺少 lhs")); + } + if (!rhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateDiv 缺少 rhs")); + } + return insert_block_->Append(Opcode::Div, lhs->GetType(), lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateMod(Value* lhs, Value* rhs, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs) { + throw std::runtime_error(FormatError("ir", 
"IRBuilder::CreateMod 缺少 lhs")); + } + if (!rhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateMod 缺少 rhs")); + } + return insert_block_->Append(Opcode::Mod, lhs->GetType(), lhs, rhs, name); +} + + +BinaryInst* IRBuilder::CreateAnd(Value* lhs, Value* rhs, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateAnd 缺少 lhs")); + } + if (!rhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateAnd 缺少 rhs")); + } + auto result_ty = lhs->GetType()->IsInt1() ? Type::GetInt1Type() + : Type::GetInt32Type(); + return insert_block_->Append(Opcode::And, result_ty, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateOr(Value* lhs, Value* rhs, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateOr 缺少 lhs")); + } + if (!rhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateOr 缺少 rhs")); + } + auto result_ty = lhs->GetType()->IsInt1() ? 
Type::GetInt1Type() + : Type::GetInt32Type(); + return insert_block_->Append(Opcode::Or, result_ty, lhs, rhs, name); +} + +IcmpInst* IRBuilder::CreateNot(Value* val, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!val) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateNot 缺少 operand")); + } + if (val->GetType()->IsInt1()) { + auto* ext = CreateZExtI1ToI32(val, ""); + auto* zero = CreateConstInt(0); + return CreateICmpEQ(ext, zero, name); + } + auto* zero = CreateConstInt(0); + return CreateICmpEQ(val, zero, name); +} + +GEPInst* IRBuilder::CreateGEP(Value* base, + const std::vector& indices, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!base) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateGEP 缺少 base")); + } + + // 检查所有索引 + for (size_t i = 0; i < indices.size(); ++i) { + if (!indices[i]) { + throw std::runtime_error( + FormatError("ir", "IRBuilder::CreateGEP 索引 " + std::to_string(i) + " 为空")); + } + } + + // 结果类型推断: + // - 对 i32*/float* 基址,结果仍分别为 i32*/float* + // - 对数组基址,按多索引向下剥离元素类型;若到达标量则返回对应标量指针 + // (本项目没有“指向数组的指针类型”,未完全剥离时退回数组类型) + std::shared_ptr result_ty = base->GetType(); + if (base->GetType()->IsPtrInt32()) { + result_ty = Type::GetPtrInt32Type(); + } else if (base->GetType()->IsPtrFloat()) { + result_ty = Type::GetPtrFloatType(); + } else if (base->GetType()->IsArray()) { + std::shared_ptr cur = base->GetType(); + for (size_t i = 1; i < indices.size(); ++i) { + auto* at = dynamic_cast(cur.get()); + if (!at) break; + cur = at->GetElementType(); + } + + if (cur->IsInt32()) result_ty = Type::GetPtrInt32Type(); + else if (cur->IsFloat()) result_ty = Type::GetPtrFloatType(); + else result_ty = cur; + } + + return insert_block_->Append(result_ty, base, indices, name); +} + + +BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs, const std::string& name) { + if 
(!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateMul 缺少 lhs")); + } + if (!rhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateMul 缺少 rhs")); + } + return CreateBinary(Opcode::Mul, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateSub 缺少 lhs")); + } + if (!rhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateSub 缺少 rhs")); + } + return CreateBinary(Opcode::Sub, lhs, rhs, name); +} + +// 注意:当前 CreateCall 仅支持直接调用 Function,且不支持变长参数列表等复杂特性。 +// 创建函数调用指令的实现,被调用的函数,参数列表,返回值临时变量名 +CallInst* IRBuilder::CreateCall(Function* callee, + const std::vector& args, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!callee) { //被调用的函数不能为空 + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall 缺少 callee")); + } + auto func_ty = std::static_pointer_cast(callee->GetType()); + auto ret_ty = func_ty->GetReturnType(); + return insert_block_->Append(ret_ty, callee, args, name); +} + + +BinaryInst* IRBuilder::CreateFAdd(Value* lhs, Value* rhs, const std::string& name) { + return insert_block_->Append(Opcode::FAdd, lhs->GetType(), lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFSub(Value* lhs, Value* rhs, const std::string& name) { + return insert_block_->Append(Opcode::FSub, lhs->GetType(), lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFMul(Value* lhs, Value* rhs, const std::string& name) { + return insert_block_->Append(Opcode::FMul, lhs->GetType(), lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFDiv(Value* lhs, Value* rhs, const std::string& name) { + return insert_block_->Append(Opcode::FDiv, 
lhs->GetType(), lhs, rhs, name); +} + +// 浮点比较 +FcmpInst* IRBuilder::CreateFCmpOEQ(Value* lhs, Value* rhs, const std::string& name) { + return insert_block_->Append( + FcmpInst::Predicate::OEQ, lhs, rhs, Type::GetInt1Type(), name); +} + +FcmpInst* IRBuilder::CreateFCmpONE(Value* lhs, Value* rhs, const std::string& name) { + return insert_block_->Append( + FcmpInst::Predicate::ONE, lhs, rhs, Type::GetInt1Type(), name); +} + +FcmpInst* IRBuilder::CreateFCmpOLT(Value* lhs, Value* rhs, const std::string& name) { + return insert_block_->Append( + FcmpInst::Predicate::OLT, lhs, rhs, Type::GetInt1Type(), name); +} + +FcmpInst* IRBuilder::CreateFCmpOLE(Value* lhs, Value* rhs, const std::string& name) { + return insert_block_->Append( + FcmpInst::Predicate::OLE, lhs, rhs, Type::GetInt1Type(), name); +} + +FcmpInst* IRBuilder::CreateFCmpOGT(Value* lhs, Value* rhs, const std::string& name) { + return insert_block_->Append( + FcmpInst::Predicate::OGT, lhs, rhs, Type::GetInt1Type(), name); +} + +FcmpInst* IRBuilder::CreateFCmpOGE(Value* lhs, Value* rhs, const std::string& name) { + return insert_block_->Append( + FcmpInst::Predicate::OGE, lhs, rhs, Type::GetInt1Type(), name); +} + +// 类型转换 +SIToFPInst* IRBuilder::CreateSIToFP(Value* value, std::shared_ptr target_ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + const std::string inst_name = name.empty() ? ctx_.NextTemp() : name; + return insert_block_->Append(value, target_ty, inst_name); +} + +FPToSIInst* IRBuilder::CreateFPToSI(Value* value, std::shared_ptr target_ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + const std::string inst_name = name.empty() ? 
ctx_.NextTemp() : name; + return insert_block_->Append(value, target_ty, inst_name); } } // namespace ir diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 30efbb6..a3a6278 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -4,6 +4,9 @@ #include "ir/IR.h" +#include +#include +#include #include #include #include @@ -12,14 +15,134 @@ namespace ir { -static const char* TypeToString(const Type& ty) { +static std::string TypeToString(const Type& ty); + +static std::string ArrayTypeToStringFrom(const Type& base_ty, + const std::vector& dims, + size_t begin) { + std::string s = TypeToString(base_ty); + for (size_t i = dims.size(); i-- > begin;) { + s = "[" + std::to_string(dims[i]) + " x " + s + "]"; + } + return s; +} + +static bool IsZeroConstant(const ConstantValue* value) { + if (!value) { + return true; + } + if (auto* ci = dynamic_cast(value)) { + return ci->GetValue() == 0; + } + if (auto* cf = dynamic_cast(value)) { + return cf->GetValue() == 0.0f; + } + if (dynamic_cast(value) || + dynamic_cast(value)) { + return true; + } + if (auto* arr = dynamic_cast(value)) { + for (auto* elem : arr->GetElements()) { + if (!IsZeroConstant(elem)) { + return false; + } + } + return true; + } + return false; +} + +static size_t AggregateSpan(const std::vector& dims, size_t level) { + size_t span = 1; + for (size_t i = level; i < dims.size(); ++i) { + span *= static_cast(dims[i]); + } + return span; +} + +static bool IsZeroRange(const std::vector& init, + size_t begin, + size_t count) { + for (size_t i = 0; i < count; ++i) { + const size_t index = begin + i; + if (index >= init.size()) { + continue; + } + if (!IsZeroConstant(init[index])) { + return false; + } + } + return true; +} + +static void PrintFlatArrayBody(std::ostream& os, + const Type& base_ty, + const std::vector& dims, + size_t level, + const std::vector& init, + size_t& flat_index) { + const size_t span = AggregateSpan(dims, level); + if (IsZeroRange(init, flat_index, span)) { + os << 
"zeroinitializer"; + flat_index += span; + return; + } + + os << "["; + for (int i = 0; i < dims[level]; ++i) { + if (i > 0) os << ", "; + + if (level + 1 < dims.size()) { + os << ArrayTypeToStringFrom(base_ty, dims, level + 1) << " "; + PrintFlatArrayBody(os, base_ty, dims, level + 1, init, flat_index); + continue; + } + + os << TypeToString(base_ty) << " "; + if (flat_index < init.size() && init[flat_index]) { + if (auto* ci = dynamic_cast(init[flat_index])) { + os << ci->GetValue(); + } else if (auto* cf = dynamic_cast(init[flat_index])) { + os << cf->GetValue(); + } else if (IsZeroConstant(init[flat_index])) { + os << "0"; + } else { + os << "0"; + } + } else { + os << "0"; + } + ++flat_index; + } + os << "]"; +} + +static std::string TypeToString(const Type& ty) { switch (ty.GetKind()) { - case Type::Kind::Void: - return "void"; - case Type::Kind::Int32: - return "i32"; - case Type::Kind::PtrInt32: - return "i32*"; + case Type::Kind::Void: return "void"; + case Type::Kind::Int32: return "i32"; + case Type::Kind::Float: return "float"; + case Type::Kind::PtrInt32: return "i32*"; + case Type::Kind::PtrFloat: return "float*"; + case Type::Kind::Label: return "label"; + case Type::Kind::Function: return "function"; + case Type::Kind::Int1: return "i1"; + case Type::Kind::PtrInt1: return "i1*"; + case Type::Kind::Array: { + // 打印数组类型为 LLVM 风格,如 [4 x [2 x i32]] + auto* at = dynamic_cast(&ty); + if (!at) return "array"; + // 递归构建类型字符串 + std::string elem = TypeToString(*at->GetElementType()); + const auto& dims = at->GetDimensions(); + // 从外到内构建 + std::string s = elem; + for (auto it = dims.rbegin(); it != dims.rend(); ++it) { + s = "[" + std::to_string(*it) + " x " + s + "]"; + } + return s; + } + default: return "unknown"; } throw std::runtime_error(FormatError("ir", "未知类型")); } @@ -40,21 +163,182 @@ static const char* OpcodeToString(Opcode op) { return "store"; case Opcode::Ret: return "ret"; + case Opcode::Call: + return "call"; + case Opcode::Br: + return "br"; + 
case Opcode::CondBr: + return "condbr"; + case Opcode::Icmp: + return "icmp"; + case Opcode::Div: + return "sdiv"; + case Opcode::Mod: + return "srem"; + case Opcode::ZExt: + return "zext"; + case Opcode::Trunc: + return "trunc"; + case Opcode::And: + return "and"; + case Opcode::Or: + return "or"; + case Opcode::Not: + return "not"; + case Opcode::GEP: + return "getelementptr"; + case Opcode::FAdd: return "fadd"; + case Opcode::FSub: return "fsub"; + case Opcode::FMul: return "fmul"; + case Opcode::FDiv: return "fdiv"; + case Opcode::FCmp: return "fcmp"; + case Opcode::SIToFP: return "sitofp"; + case Opcode::FPToSI: return "fptosi"; + case Opcode::FPExt: return "fpext"; + case Opcode::FPTrunc: return "fptrunc"; } return "?"; } +// 将 float 值转为 LLVM IR 接受的 64-bit 十六进制浮点格式 +static std::string FloatToLLVMHex(float f) { + double d = static_cast(f); + uint64_t bits; + memcpy(&bits, &d, sizeof(bits)); + char buf[20]; + snprintf(buf, sizeof(buf), "0x%016llX", (unsigned long long)bits); + return buf; +} + static std::string ValueToString(const Value* v) { + if (!v) { + return ""; + } + if (dynamic_cast(v) || + dynamic_cast(v)) { + return "zeroinitializer"; + } if (auto* ci = dynamic_cast(v)) { return std::to_string(ci->GetValue()); } - return v ? v->GetName() : ""; + if (auto* cf = dynamic_cast(v)) { + return FloatToLLVMHex(cf->GetValue()); + } + const auto& name = v->GetName(); + if (name.empty()) { + return ""; + } + if (name[0] == '%' || name[0] == '@') { + return name; + } + if (dynamic_cast(v)) { + return "@" + name; + } + return "%" + name; +} + +static std::string MemoryTypeToString(const Type& ty) { + std::string text = TypeToString(ty); + if (ty.IsArray()) { + text += "*"; + } + return text; } void IRPrinter::Print(const Module& module, std::ostream& os) { + for (const auto& global : module.GetGlobals()) { + if (!global) continue; + os << "@" << global->GetName() << " = " + << (global->IsConstant() ? 
"constant " : "global "); + + if (global->GetType()->IsPtrInt32()) { + os << "i32 "; + if (global->HasInitializer()) { + auto* ci = dynamic_cast(global->GetInitializer().front()); + os << (ci ? ci->GetValue() : 0); + } else { + os << "0"; + } + os << "\n"; + continue; + } + + if (global->GetType()->IsPtrFloat()) { + os << "float "; + if (global->HasInitializer()) { + auto* cf = dynamic_cast(global->GetInitializer().front()); + os << (cf ? ValueToString(cf) : FloatToLLVMHex(0.0f)); + } else { + os << FloatToLLVMHex(0.0f); + } + os << "\n"; + continue; + } + + if (global->GetType()->IsArray()) { + auto* at = dynamic_cast(global->GetType().get()); + os << TypeToString(*global->GetType()) << " "; + if (!at || !global->HasInitializer() || + IsZeroRange(global->GetInitializer(), 0, AggregateSpan(at->GetDimensions(), 0))) { + os << "zeroinitializer\n"; + continue; + } + + size_t flat_index = 0; + PrintFlatArrayBody(os, + *at->GetElementType(), + at->GetDimensions(), + 0, + global->GetInitializer(), + flat_index); + os << "\n"; + continue; + } + + os << TypeToString(*global->GetType()) << " zeroinitializer\n"; + } + + auto print_func_params = [&](const Function* func, + const FunctionType* func_ty) { + bool first = true; + if (!func->GetArguments().empty()) { + for (const auto& arg : func->GetArguments()) { + if (!first) os << ", "; + first = false; + os << TypeToString(*arg->GetType()) << " %" << arg->GetName(); + } + return; + } + + for (const auto& pty : func_ty->GetParamTypes()) { + if (!first) os << ", "; + first = false; + os << TypeToString(*pty); + } + }; + + auto is_declaration_only = [](const Function* func) { + const auto& blocks = func->GetBlocks(); + if (blocks.size() != 1) return false; + const auto& only = blocks.front(); + if (!only) return false; + return only->GetInstructions().empty(); + }; + for (const auto& func : module.GetFunctions()) { - os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() - << "() {\n"; + auto* func_ty = 
static_cast(func->GetType().get()); + if (is_declaration_only(func.get())) { + os << "declare " << TypeToString(*func_ty->GetReturnType()) << " @" + << func->GetName() << "("; + print_func_params(func.get(), func_ty); + os << ")\n"; + continue; + } + + os << "define " << TypeToString(*func_ty->GetReturnType()) << " @" + << func->GetName() << "("; + print_func_params(func.get(), func_ty); + os << ") {\n"; for (const auto& bb : func->GetBlocks()) { if (!bb) { continue; @@ -65,7 +349,17 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { switch (inst->GetOpcode()) { case Opcode::Add: case Opcode::Sub: - case Opcode::Mul: { + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: + case Opcode::And: + case Opcode::Not: + case Opcode::Or: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + { auto* bin = static_cast(inst); os << " " << bin->GetName() << " = " << OpcodeToString(bin->GetOpcode()) << " " @@ -76,25 +370,189 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { } case Opcode::Alloca: { auto* alloca = static_cast(inst); - os << " " << alloca->GetName() << " = alloca i32\n"; + std::string elem_ty_str; + if (alloca->GetType()->IsPtrInt32()) { + elem_ty_str = "i32"; + } else if (alloca->GetType()->IsPtrFloat()) { + elem_ty_str = "float"; + } else { + elem_ty_str = TypeToString(*alloca->GetType()); + } + os << " " << alloca->GetName() << " = alloca " << elem_ty_str << "\n"; break; } case Opcode::Load: { auto* load = static_cast(inst); - os << " " << load->GetName() << " = load i32, i32* " + os << " " << load->GetName() << " = load " + << TypeToString(*load->GetType()) << ", " + << MemoryTypeToString(*load->GetPtr()->GetType()) << " " << ValueToString(load->GetPtr()) << "\n"; break; } case Opcode::Store: { auto* store = static_cast(inst); - os << " store i32 " << ValueToString(store->GetValue()) - << ", i32* " << ValueToString(store->GetPtr()) << "\n"; + os << " store " << 
TypeToString(*store->GetValue()->GetType()) << " " + << ValueToString(store->GetValue()) + << ", " << MemoryTypeToString(*store->GetPtr()->GetType()) << " " + << ValueToString(store->GetPtr()) << "\n"; break; } case Opcode::Ret: { auto* ret = static_cast(inst); - os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " - << ValueToString(ret->GetValue()) << "\n"; + if (!ret->GetValue()) { + os << " ret void\n"; + } else { + os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " + << ValueToString(ret->GetValue()) << "\n"; + } + break; + } + // CallInst类在 include/ir/IR.h 中定义 + case Opcode::Call: { + auto* call = static_cast(inst); + os << " "; + if (!call->GetType()->IsVoid()) { + os << call->GetName() << " = "; + } + os << "call " << TypeToString(*call->GetType()) << " @" + << call->GetCallee()->GetName() << "("; + bool first = true; + for (auto* arg : call->GetArgs()) { + if (!first) os << ", "; + first = false; + os << TypeToString(*arg->GetType()) << " " << ValueToString(arg); + } + os << ")\n"; + break; + } + + // 在 IRPrinter.cpp 的 switch 语句中添加 + case Opcode::Br: + case Opcode::CondBr: { + auto* br = static_cast(inst); + if (!br->IsConditional()) { + os << " br label %" << br->GetTarget()->GetName() << "\n"; + } else { + os << " br i1 " << ValueToString(br->GetCondition()) + << ", label %" << br->GetTrueTarget()->GetName() + << ", label %" << br->GetFalseTarget()->GetName() << "\n"; + } + break; + } + + case Opcode::Icmp: { + auto* icmp = static_cast(inst); + os << " " << icmp->GetName() << " = icmp "; + switch (icmp->GetPredicate()) { + case IcmpInst::Predicate::EQ: os << "eq"; break; + case IcmpInst::Predicate::NE: os << "ne"; break; + case IcmpInst::Predicate::LT: os << "slt"; break; + case IcmpInst::Predicate::LE: os << "sle"; break; + case IcmpInst::Predicate::GT: os << "sgt"; break; + case IcmpInst::Predicate::GE: os << "sge"; break; + } + os << " " << TypeToString(*icmp->GetLhs()->GetType()) + << " " << ValueToString(icmp->GetLhs()) + 
<< ", " << ValueToString(icmp->GetRhs()) << "\n"; + break; + } + + case Opcode::ZExt: { + auto* zext = static_cast(inst); + os << " " << zext->GetName() << " = zext " + << TypeToString(*zext->GetSourceType()) << " " + << ValueToString(zext->GetValue()) << " to " + << TypeToString(*zext->GetTargetType()) << "\n"; + break; + } + + case Opcode::Trunc: { + auto* trunc = static_cast(inst); + os << " " << trunc->GetName() << " = trunc " + << TypeToString(*trunc->GetSourceType()) << " " + << ValueToString(trunc->GetValue()) << " to " + << TypeToString(*trunc->GetTargetType()) << "\n"; + break; + } + case Opcode::GEP:{ + // 打印为类似 LLVM 的 getelementptr 形式: + // getelementptr , , i32 , i32 , ... + os << " " << inst->GetName() << " = getelementptr "; + // 基地址类型使用第一个操作数的类型 + Value* base = inst->GetOperand(0); + // GEP 的第一个类型参数应是基址指向的元素类型(pointee)。 + std::string elem_ty; + if (base->GetType()->IsPtrInt32()) elem_ty = "i32"; + else if (base->GetType()->IsPtrFloat()) elem_ty = "float"; + else if (base->GetType()->IsArray()) elem_ty = TypeToString(*base->GetType()); + else elem_ty = TypeToString(*inst->GetType()); + + std::string base_ty = TypeToString(*base->GetType()); + if (base->GetType()->IsArray()) { + base_ty += "*"; + } + + os << elem_ty << ", " << base_ty << " " << ValueToString(base); + + // 后续操作数为索引,按照 i32 打印 + // 特殊处理:如果 base 是标量指针(i32*/float*)且第一个索引是常量 0 + // 且后续还有索引,则丢弃第一个 0(对 T* 来说多余且会导致无效 IR)。 + size_t start_idx = 1; + if ((base->GetType()->IsPtrInt32() || base->GetType()->IsPtrFloat()) && + inst->GetNumOperands() >= 3) { + // 检查第一个索引是否为常量 0 + auto* first_idx = inst->GetOperand(1); + if (auto* ci = dynamic_cast(first_idx)) { + if (ci->GetValue() == 0) { + start_idx = 2; // 跳过第一个 0 + } + } + } + for (size_t i = start_idx; i < inst->GetNumOperands(); ++i) { + os << ", i32 " << ValueToString(inst->GetOperand(i)); + } + os << "\n"; + break; + } + case Opcode::FCmp: { + auto* fcmp = static_cast(inst); + os << " " << fcmp->GetName() << " = fcmp "; + switch 
(fcmp->GetPredicate()) { + case FcmpInst::Predicate::OEQ: os << "oeq"; break; + case FcmpInst::Predicate::ONE: os << "one"; break; + case FcmpInst::Predicate::OLT: os << "olt"; break; + case FcmpInst::Predicate::OLE: os << "ole"; break; + case FcmpInst::Predicate::OGT: os << "ogt"; break; + case FcmpInst::Predicate::OGE: os << "oge"; break; + default: os << "oeq"; break; + } + os << " " << TypeToString(*fcmp->GetLhs()->GetType()) + << " " << ValueToString(fcmp->GetLhs()) + << ", " << ValueToString(fcmp->GetRhs()) << "\n"; + break; + } + + case Opcode::SIToFP: { + auto* sitofp = static_cast(inst); + os << " " << sitofp->GetName() << " = sitofp " + << TypeToString(*sitofp->GetValue()->GetType()) << " " + << ValueToString(sitofp->GetValue()) << " to " + << TypeToString(*sitofp->GetType()) << "\n"; + break; + } + + case Opcode::FPToSI: { + auto* fptosi = static_cast(inst); + os << " " << fptosi->GetName() << " = fptosi " + << TypeToString(*fptosi->GetValue()->GetType()) << " " + << ValueToString(fptosi->GetValue()) << " to " + << TypeToString(*fptosi->GetType()) << "\n"; + break; + } + + default: { + // 处理未知操作码 + os << " ; 未知指令: " << OpcodeToString(inst->GetOpcode()) << "\n"; break; } } @@ -104,4 +562,28 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { } } +void IRPrinter::PrintConstant(const ConstantValue* constant, std::ostream& os) { + if (auto* const_int = dynamic_cast(constant)) { + os << const_int->GetValue(); + } + else if (auto* const_float = dynamic_cast(constant)) { + os << const_float->GetValue(); + } + else if (auto* const_array = dynamic_cast(constant)) { + os << "["; + auto& elements = const_array->GetElements(); + for (size_t i = 0; i < elements.size(); ++i) { + if (i > 0) os << ", "; + PrintConstant(elements[i], os); + } + os << "]"; + } + else if (dynamic_cast(constant)) { + os << "zero"; + } + else if (dynamic_cast(constant)) { + os << "zeroinitializer"; + } +} + } // namespace ir diff --git a/src/ir/Instruction.cpp 
b/src/ir/Instruction.cpp index 7928716..d0f280a 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -52,7 +52,10 @@ Instruction::Instruction(Opcode op, std::shared_ptr ty, std::string name) Opcode Instruction::GetOpcode() const { return opcode_; } -bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; } +bool Instruction::IsTerminator() const { + return opcode_ == Opcode::Ret || opcode_ == Opcode::Br || + opcode_ == Opcode::CondBr; +} BasicBlock* Instruction::GetParent() const { return parent_; } @@ -61,22 +64,71 @@ void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name) : Instruction(op, std::move(ty), std::move(name)) { - if (op != Opcode::Add) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add")); - } - if (!lhs || !rhs) { - throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); + // 检查操作码是否为有效的二元操作符 + switch (op) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: + case Opcode::And: + case Opcode::Or: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + // 有效的二元操作符 + break; + case Opcode::Not: + // Not是一元操作符,不应该在BinaryInst中 + throw std::runtime_error(FormatError("ir", "Not是一元操作符,应使用其他指令")); + default: + throw std::runtime_error(FormatError("ir", "BinaryInst 不支持的操作码")); + } + // 当前 BinaryInst 仅支持 Add/Sub/Mul,且操作数和结果必须都是 i32。 + if (op != Opcode::Add && op != Opcode::Sub && op != Opcode::Mul) { } + if (!type_ || !lhs->GetType() || !rhs->GetType()) { throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息")); } - if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() || - type_->GetKind() != lhs->GetType()->GetKind()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); + + // 对于比较操作,结果类型是i1,但我们的类型系统可能还没有i1 + // 暂时简化:所有操作都返回i32,比较操作返回0或1 + // 检查操作数类型是否匹配 + if (lhs->GetType()->GetKind() != 
rhs->GetType()->GetKind()) { + throw std::runtime_error(FormatError("ir", "BinaryInst 操作数类型不匹配")); } - if (!type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32")); + + bool is_logical = (op == Opcode::And || op == Opcode::Or); + + // 检查操作数类型是否支持 + if (is_logical) { + if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsInt1()) { + throw std::runtime_error( + FormatError("ir", "逻辑运算仅支持 i32/i1")); + } + } else { + if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsFloat()) { + throw std::runtime_error( + FormatError("ir", "BinaryInst 只支持 int32 和 float 类型")); + } } + + if (is_logical) { + // 逻辑运算结果类型应与操作数一致(i1 或 i32)。 + if (type_->GetKind() != lhs->GetType()->GetKind()) { + throw std::runtime_error( + FormatError("ir", "逻辑运算结果类型与操作数类型不匹配")); + } + } else { + // 算术运算的结果类型应与操作数类型相同 + if (type_->GetKind() != lhs->GetType()->GetKind()) { + throw std::runtime_error( + FormatError("ir", "BinaryInst 结果类型与操作数类型不匹配")); + } + } + AddOperand(lhs); AddOperand(rhs); } @@ -87,21 +139,27 @@ Value* BinaryInst::GetRhs() const { return GetOperand(1); } ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) : Instruction(Opcode::Ret, std::move(void_ty), "") { - if (!val) { - throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值")); - } if (!type_ || !type_->IsVoid()) { throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void")); } - AddOperand(val); + if (val) { + AddOperand(val); + } } -Value* ReturnInst::GetValue() const { return GetOperand(0); } +Value* ReturnInst::GetValue() const { + if (GetNumOperands() == 0) { + return nullptr; + } + return GetOperand(0); +} AllocaInst::AllocaInst(std::shared_ptr ptr_ty, std::string name) : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { - if (!type_ || !type_->IsPtrInt32()) { - throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*")); + if (!type_ || + (!type_->IsPtrInt32() && !type_->IsPtrFloat() && !type_->IsArray())) { + throw std::runtime_error( + 
FormatError("ir", "AllocaInst 仅支持 i32* / float* / array")); } } @@ -110,12 +168,15 @@ LoadInst::LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name) if (!ptr) { throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr")); } - if (!type_ || !type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32")); + if (!type_ || (!type_->IsInt32() && !type_->IsFloat() && !type_->IsInt1())) { + throw std::runtime_error( + FormatError("ir", "LoadInst 仅支持加载 i32/float/i1")); } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { + if (!ptr->GetType() || + (!ptr->GetType()->IsPtrInt32() && !ptr->GetType()->IsPtrFloat() && + !ptr->GetType()->IsArray() && !ptr->GetType()->IsPtrInt1())) { throw std::runtime_error( - FormatError("ir", "LoadInst 当前只支持从 i32* 加载")); + FormatError("ir", "LoadInst 仅支持从指针或数组地址加载")); } AddOperand(ptr); } @@ -133,13 +194,25 @@ StoreInst::StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr) if (!type_ || !type_->IsVoid()) { throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void")); } - if (!val->GetType() || !val->GetType()->IsInt32()) { - throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32")); - } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { + if (!val->GetType() || + (!val->GetType()->IsInt32() && !val->GetType()->IsFloat() && + !val->GetType()->IsInt1() && !val->GetType()->IsArray())) { throw std::runtime_error( - FormatError("ir", "StoreInst 当前只支持写入 i32*")); + FormatError("ir", "StoreInst 仅支持存储 i32/float/i1/array")); } + if (!ptr->GetType() || + (!ptr->GetType()->IsPtrInt32() && !ptr->GetType()->IsPtrFloat() && + !ptr->GetType()->IsArray() && !ptr->GetType()->IsPtrInt1())) { + throw std::runtime_error(FormatError("ir", "StoreInst 仅支持写入指针或数组地址")); + } + if (ptr->GetType()->IsArray()) { + if (!val->GetType()->IsArray() || + val->GetType()->GetKind() != ptr->GetType()->GetKind()) { + throw std::runtime_error( + FormatError("ir", "StoreInst 聚合存储要求 value/ptr 具有相同数组类型")); 
+ } + } + AddOperand(val); AddOperand(ptr); } @@ -148,4 +221,61 @@ Value* StoreInst::GetValue() const { return GetOperand(0); } Value* StoreInst::GetPtr() const { return GetOperand(1); } + +Function* CallInst::GetCallee() const { return callee_; } + +const std::vector& CallInst::GetArgs() const { return args_; } + +GEPInst::GEPInst(std::shared_ptr ptr_ty, + Value* base, + const std::vector& indices, + const std::string& name) + : Instruction(Opcode::GEP, ptr_ty, name) { + // 添加base作为第一个操作数 + AddOperand(base); + + // 添加所有索引作为后续操作数 + for (auto* index : indices) { + AddOperand(index); + } +} + +Value* GEPInst::GetBase() const { + // 第一个操作数是base + return GetOperand(0); +} + +const std::vector& GEPInst::GetIndices() const { + // 需要返回索引列表,但Instruction只存储操作数 + // 这是一个设计问题:要么修改架构,要么提供辅助方法 + + // 简化实现:返回空vector(或创建临时vector) + static std::vector indices; + indices.clear(); + + // 索引从操作数1开始 + for (size_t i = 1; i < GetNumOperands(); ++i) { + indices.push_back(GetOperand(i)); + } + + return indices; +} + +CallInst::CallInst(std::shared_ptr ret_ty, Function* callee, + const std::vector& args, const std::string& name) + : Instruction(Opcode::Call, std::move(ret_ty), name), // name 是 const&,这里会复制 + callee_(callee), args_(args) { + if (!callee_) { + throw std::runtime_error(FormatError("ir", "CallInst 缺少被调用函数")); + } + for (auto* arg : args_) { + if (!arg) { + throw std::runtime_error(FormatError("ir", "CallInst 参数不能为 null")); + } + AddOperand(arg); + } +} + + } // namespace ir + diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index 928efdc..79e41a5 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -9,13 +9,46 @@ Context& Module::GetContext() { return context_; } const Context& Module::GetContext() const { return context_; } Function* Module::CreateFunction(const std::string& name, - std::shared_ptr ret_type) { - functions_.push_back(std::make_unique(name, std::move(ret_type))); + std::shared_ptr func_type) { + functions_.push_back(std::make_unique(name, 
std::move(func_type))); return functions_.back().get(); } +Function* Module::FindFunction(const std::string& name) const { + for (const auto& func : functions_) { + if (func->GetName() == name) { + return func.get(); + } + } + return nullptr; +} + const std::vector>& Module::GetFunctions() const { return functions_; } +GlobalValue* Module::CreateGlobal(const std::string& name, + std::shared_ptr ty) { + // 对于标量类型,自动转换为指针类型 + std::shared_ptr global_ty; + + if (ty->IsInt32()) { + global_ty = Type::GetPtrInt32Type(); // i32 -> i32* + } else if (ty->IsFloat()) { + global_ty = Type::GetPtrFloatType(); // float -> float* + } else if (ty->IsInt1()) { + global_ty = Type::GetPtrInt1Type(); // i1 -> i1* + } else { + // 数组等类型保持不变 + global_ty = ty; + } + + globals_.push_back(std::make_unique(global_ty, name)); + return globals_.back().get(); +} + +const std::vector>& Module::GetGlobals() const { + return globals_; +} + } // namespace ir diff --git a/src/ir/Type.cpp b/src/ir/Type.cpp index 3e1684d..792ff54 100644 --- a/src/ir/Type.cpp +++ b/src/ir/Type.cpp @@ -1,31 +1,219 @@ // 当前仅支持 void、i32 和 i32*。 #include "ir/IR.h" +#include namespace ir { Type::Type(Kind k) : kind_(k) {} +size_t Type::Size() const { + switch (kind_) { + case Kind::Void: return 0; + case Kind::Int32: return 4; + case Kind::Float: return 4; // 单精度浮点 4 字节 + case Kind::PtrInt32: return 8; // 假设 64 位指针 + case Kind::PtrFloat: return 8; + case Kind::Label: return 8; // 标签地址大小(指针大小) + default: return 0; // 派生类应重写 + } +} + +size_t Type::Alignment() const { + switch (kind_) { + case Kind::Int32: return 4; + case Kind::Float: return 4; + case Kind::PtrInt32: return 8; + case Kind::PtrFloat: return 8; + case Kind::Label: return 8; // 与指针相同 + default: return 1; + } +} + +bool Type::IsComplete() const { + return kind_ != Kind::Void; +} const std::shared_ptr& Type::GetVoidType() { - static const std::shared_ptr type = std::make_shared(Kind::Void); + static const std::shared_ptr type = std::shared_ptr(new 
Type(Kind::Void)); return type; } const std::shared_ptr& Type::GetInt32Type() { - static const std::shared_ptr type = std::make_shared(Kind::Int32); + static const std::shared_ptr type = std::shared_ptr(new Type(Kind::Int32)); return type; } +const std::shared_ptr& Type::GetFloatType() { + static const std::shared_ptr type(new Type(Kind::Float)); + return type; +} + const std::shared_ptr& Type::GetPtrInt32Type() { - static const std::shared_ptr type = std::make_shared(Kind::PtrInt32); + static const std::shared_ptr type = std::shared_ptr(new Type(Kind::PtrInt32)); return type; } -Type::Kind Type::GetKind() const { return kind_; } +const std::shared_ptr& Type::GetPtrFloatType() { + static const std::shared_ptr type(new Type(Kind::PtrFloat)); + return type; +} + +const std::shared_ptr& Type::GetLabelType() { + static const std::shared_ptr type(new Type(Kind::Label)); + return type; +} + +// Int1 类型表示布尔值,通常用于比较指令的结果 +const std::shared_ptr& Type::GetInt1Type() { + static const std::shared_ptr type = std::shared_ptr(new Type(Kind::Int1)); + return type; + } +// PtrInt1 类型表示指向 Int1 的指针,主要用于条件跳转等场景 +const std::shared_ptr& Type::GetPtrInt1Type() { + static const std::shared_ptr type = std::shared_ptr(new Type(Kind::PtrInt1)); + return type; + } -bool Type::IsVoid() const { return kind_ == Kind::Void; } +// ---------- 数组类型缓存 ---------- +// 使用自定义键类型保证唯一性:元素类型指针 + 维度向量 +struct ArrayKey { + const Type* elem_type; + std::vector dims; -bool Type::IsInt32() const { return kind_ == Kind::Int32; } + bool operator==(const ArrayKey& other) const { + return elem_type == other.elem_type && dims == other.dims; + } +}; -bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } +struct ArrayKeyHash { + std::size_t operator()(const ArrayKey& key) const { + std::size_t h = std::hash{}(key.elem_type); + for (int d : key.dims) { + h ^= std::hash{}(d) + 0x9e3779b9 + (h << 6) + (h >> 2); + } + return h; + } +}; + +static std::unordered_map, ArrayKeyHash>& GetArrayCache() { + static 
std::unordered_map, ArrayKeyHash> cache; + return cache; +} + +std::shared_ptr Type::GetArrayType(std::shared_ptr elem, + std::vector dims) { + // 检查维度合法性 + for (int d : dims) { + if (d <= 0) { + // SysY 数组维度必须为正整数常量表达式,这里假设已检查 + return nullptr; + } + } + + ArrayKey key{elem.get(), dims}; + auto& cache = GetArrayCache(); + auto it = cache.find(key); + if (it != cache.end()) { + auto ptr = it->second.lock(); + if (ptr) return ptr; + } + + auto arr = std::shared_ptr(new ArrayType(std::move(elem), std::move(dims))); + cache[key] = arr; + return arr; +} + +// ---------- 函数类型缓存 ---------- +struct FunctionKey { + const Type* return_type; + std::vector param_types; + + bool operator==(const FunctionKey& other) const { + return return_type == other.return_type && param_types == other.param_types; + } +}; + +struct FunctionKeyHash { + std::size_t operator()(const FunctionKey& key) const { + std::size_t h = std::hash{}(key.return_type); + for (const Type* t : key.param_types) { + h ^= std::hash{}(t) + 0x9e3779b9 + (h << 6) + (h >> 2); + } + return h; + } +}; + +static std::unordered_map, FunctionKeyHash>& GetFunctionCache() { + static std::unordered_map, FunctionKeyHash> cache; + return cache; +} + +std::shared_ptr Type::GetFunctionType(std::shared_ptr ret, + std::vector> params) { + // 提取裸指针用于键(保证唯一性,因为 shared_ptr 指向同一对象) + std::vector param_ptrs; + param_ptrs.reserve(params.size()); + for (const auto& p : params) { + param_ptrs.push_back(p.get()); + } + + FunctionKey key{ret.get(), std::move(param_ptrs)}; + auto& cache = GetFunctionCache(); + auto it = cache.find(key); + if (it != cache.end()) { + auto ptr = it->second.lock(); + if (ptr) return ptr; + } + + auto func = std::shared_ptr(new FunctionType(std::move(ret), std::move(params))); + cache[key] = func; + return func; +} + +// ---------- ArrayType 实现 ---------- +ArrayType::ArrayType(std::shared_ptr elem, std::vector dims) + : Type(Kind::Array), elem_type_(std::move(elem)), dims_(std::move(dims)) { + // 数组元素类型必须是完整类型 + 
assert(elem_type_ && elem_type_->IsComplete()); +} + +size_t ArrayType::GetElementCount() const { + size_t count = 1; + for (int d : dims_) count *= d; + return count; +} + +size_t ArrayType::Size() const { + return GetElementCount() * elem_type_->Size(); +} + +size_t ArrayType::Alignment() const { + // 数组对齐等于其元素对齐 + return elem_type_->Alignment(); +} + +bool ArrayType::IsComplete() const { + // 维度已确定且元素类型完整,则数组完整 + return !dims_.empty() && elem_type_->IsComplete(); +} + +// ---------- FunctionType 实现 ---------- +FunctionType::FunctionType(std::shared_ptr ret, + std::vector> params) + : Type(Kind::Function), return_type_(std::move(ret)), param_types_(std::move(params)) {} + +size_t FunctionType::Size() const { + // 函数类型没有运行时大小,通常用于类型检查,返回 0 + return 0; +} + +size_t FunctionType::Alignment() const { + // 不对齐 + return 1; +} + +bool FunctionType::IsComplete() const { + // 函数类型总是完整的(只要返回类型完整,但 SysY 中 void 也视为完整) + return true; +} } // namespace ir diff --git a/src/ir/Value.cpp b/src/ir/Value.cpp index 2e9f4c1..56dd2e6 100644 --- a/src/ir/Value.cpp +++ b/src/ir/Value.cpp @@ -76,8 +76,65 @@ void Value::ReplaceAllUsesWith(Value* new_value) { ConstantValue::ConstantValue(std::shared_ptr ty, std::string name) : Value(std::move(ty), std::move(name)) {} +// 建一个 Argument 对象,用给定的类型和名称初始化它,并继承 Value 的所有属性和方法。 +Argument::Argument(std::shared_ptr ty, std::string name) + : Value(std::move(ty), std::move(name)) {} ConstantInt::ConstantInt(std::shared_ptr ty, int v) : ConstantValue(std::move(ty), ""), value_(v) {} +// ConstantFloat 实现 +ConstantFloat::ConstantFloat(std::shared_ptr ty, float v) + : ConstantValue(ty, ""), value_(v) { + if (!ty->IsFloat()) { + throw std::runtime_error("ConstantFloat requires Float type"); + } +} + +// ConstantArray 实现 +ConstantArray::ConstantArray(std::shared_ptr ty, + std::vector elements) + : ConstantValue(ty, ""), elements_(std::move(elements)) { + if (!IsValid()) { + throw std::runtime_error("Invalid constant array initialization"); + } +} + +bool 
ConstantArray::IsValid() const { + auto* array_ty = dynamic_cast(GetType().get()); + if (!array_ty) return false; + + // 检查元素数量是否匹配 + if (elements_.size() != array_ty->GetElementCount()) return false; + + // 检查每个元素的类型是否匹配数组元素类型 + auto& elem_ty = array_ty->GetElementType(); + for (auto* elem : elements_) { + if (elem->GetType() != elem_ty) return false; + } + + return true; +} + +// ConstantZero 实现 +ConstantZero::ConstantZero(std::shared_ptr ty) + : ConstantValue(ty, "zero") { + // 零常量可以用于任何类型 +} + +std::unique_ptr ConstantZero::GetZero(std::shared_ptr ty) { + return std::make_unique(ty); +} + +// ConstantAggregateZero 实现 +ConstantAggregateZero::ConstantAggregateZero(std::shared_ptr ty) + : ConstantValue(ty, "zero") { + if (!ty->IsArray()) { + throw std::runtime_error("ConstantAggregateZero requires aggregate type"); + } +} + +std::unique_ptr ConstantAggregateZero::GetZero(std::shared_ptr ty) { + return std::make_unique(ty); +} } // namespace ir diff --git a/src/irgen/CMakeLists.txt b/src/irgen/CMakeLists.txt index d440bde..04a3195 100644 --- a/src/irgen/CMakeLists.txt +++ b/src/irgen/CMakeLists.txt @@ -10,4 +10,4 @@ target_link_libraries(irgen PUBLIC build_options ${ANTLR4_RUNTIME_TARGET} ir -) +) \ No newline at end of file diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 0eb62ae..30234bc 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -1,5 +1,7 @@ +// IRGenDecl.cpp #include "irgen/IRGen.h" +#include #include #include "SysYParser.h" @@ -8,100 +10,758 @@ namespace { -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("irgen", "非法左值")); +constexpr size_t kLocalArrayHeapThresholdBytes = 1024 * 1024; + +size_t GetArrayElementByteWidth(const ir::ArrayType& array_ty) { + auto* elem_ty = array_ty.GetElementType().get(); + if (elem_ty->IsInt1()) { + return 1; } - return lvalue.ID()->getText(); + return 4; } -} // namespace +size_t GetArrayStorageBytes(const 
ir::ArrayType& array_ty) { + return static_cast(array_ty.GetElementCount()) * + GetArrayElementByteWidth(array_ty); +} -std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少语句块")); +bool IsZeroIRValue(const ir::Value* value) { + if (!value) { + return true; } - for (auto* item : ctx->blockItem()) { - if (item) { - if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { - // 当前语法要求 return 为块内最后一条语句;命中后可停止生成。 - break; - } - } + if (auto* ci = dynamic_cast(value)) { + return ci->GetValue() == 0; } - return {}; + if (auto* cf = dynamic_cast(value)) { + return cf->GetValue() == 0.0f; + } + if (dynamic_cast(value) || + dynamic_cast(value)) { + return true; + } + return false; } -IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( - SysYParser::BlockItemContext& item) { - return std::any_cast(item.accept(this)); +bool IsZeroConstantValue(const ir::ConstantValue* value) { + return IsZeroIRValue(value); } -std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少块内项")); +std::vector TrimTrailingZeroConstants( + std::vector values) { + while (!values.empty() && IsZeroConstantValue(values.back())) { + values.pop_back(); } - if (ctx->decl()) { - ctx->decl()->accept(this); - return BlockFlow::Continue; + return values; +} + +std::vector BuildArrayIndices(ir::IRBuilder& builder, + const std::vector& dims, + size_t flat_idx) { + std::vector indices; + indices.reserve(dims.size() + 1); + indices.push_back(builder.CreateConstInt(0)); + + size_t rem = flat_idx; + for (size_t i = 0; i < dims.size(); ++i) { + size_t stride = 1; + for (size_t j = i + 1; j < dims.size(); ++j) { + stride *= static_cast(dims[j]); + } + const int idx = static_cast(rem / stride); + rem %= stride; + indices.push_back(builder.CreateConstInt(idx)); } - if (ctx->stmt()) { - return ctx->stmt()->accept(this); + return indices; +} + +std::vector 
BuildZeroIndices(ir::IRBuilder& builder, + size_t dims_count) { + std::vector indices; + indices.reserve(dims_count + 1); + for (size_t i = 0; i <= dims_count; ++i) { + indices.push_back(builder.CreateConstInt(0)); } - throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); + return indices; } -// 变量声明的 IR 生成目前也是最小实现: -// - 先检查声明的基础类型,当前仅支持局部 int; -// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。 -// -// 和更完整的版本相比,这里还没有: -// - 一个 Decl 中多个变量定义的顺序处理; -// - const、数组、全局变量等不同声明形态; -// - 更丰富的类型系统。 +std::string MakeStaticArrayName(const ir::Function& func, + const std::string& var_name, + std::string suffix) { + for (char& ch : suffix) { + if (ch == '%') { + ch = '_'; + } + } + return "__static_array." + func.GetName() + "." + var_name + "." + suffix; +} + + +} // namespace + +// visitDecl: 处理声明 std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { + DEBUG_MSG("[DEBUG] visitDecl: 开始处理声明"); if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量声明")); } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); + + // 处理 varDecl + if (auto* varDecl = ctx->varDecl()) { + DEBUG_MSG("[DEBUG] visitDecl: 处理变量声明"); + for (auto* varDef : varDecl->varDef()) { + varDef->accept(this); + } } - auto* var_def = ctx->varDef(); - if (!var_def) { - throw std::runtime_error(FormatError("irgen", "非法变量声明")); + + // 处理 constDecl + if (ctx->constDecl()) { + DEBUG_MSG("[DEBUG] visitDecl: 处理常量声明"); + auto* constDecl = ctx->constDecl(); + for (auto* constDef : constDecl->constDef()) { + constDef->accept(this); + } } - var_def->accept(this); + + DEBUG_MSG("[DEBUG] visitDecl: 声明处理完成"); return {}; } - -// 当前仍是教学用的最小版本,因此这里只支持: -// - 局部 int 变量; -// - 标量初始化; -// - 一个 VarDef 对应一个槽位。 -std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { +// visitConstDecl: 处理常量声明 +std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) { + DEBUG_MSG("[DEBUG] visitConstDecl: 开始处理常量声明"); if (!ctx) { - throw 
std::runtime_error(FormatError("irgen", "缺少变量定义")); + throw std::runtime_error(FormatError("irgen", "非法常量声明")); + } + + for (auto* constDef : ctx->constDef()) { + if (constDef) { + constDef->accept(this); + } + } + + DEBUG_MSG("[DEBUG] visitConstDecl: 常量声明处理完成"); + return {}; +} + +// visitConstDef: 处理常量定义 - 从符号表获取常量值 +std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { + DEBUG_MSG("[DEBUG] visitConstDef: 开始处理常量定义"); + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("irgen", "非法常量定义")); } - if (!ctx->lValue()) { - throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); + + std::string const_name = ctx->Ident()->getText(); + + // 从符号表获取常量符号 + const Symbol* sym = symbol_table_.lookupByConstDef(ctx); + if (!sym || sym->kind != SymbolKind::Constant) { + throw std::runtime_error(FormatError("irgen", "常量符号未找到: " + const_name)); } - GetLValueName(*ctx->lValue()); + + DEBUG_MSG("[DEBUG] visitConstDef: 从符号表获取常量 " << const_name + << ", is_array_const: " << sym->IsArrayConstant()); + + // 根据符号表中的常量值创建 IR 常量 + if (sym->IsArrayConstant()) { + auto* array_ty = dynamic_cast(sym->type.get()); + if (!array_ty) { + throw std::runtime_error(FormatError("irgen", "数组类型转换失败")); + } + + const bool is_global_scope = (func_ == nullptr) || (sym->scope_level == 0); + const bool use_heap_storage = !is_global_scope && + (current_function_is_recursive_ || + GetArrayStorageBytes(*array_ty) > kLocalArrayHeapThresholdBytes); + + // 先把符号表中的扁平初始化值转换为 IR 常量值 + std::vector init_consts; + init_consts.reserve(sym->GetArraySize()); + for (size_t i = 0; i < sym->GetArraySize(); ++i) { + auto elem = sym->GetArrayElement(i); + if (array_ty->GetElementType()->IsInt32()) { + init_consts.push_back(builder_.CreateConstInt(elem.i32)); + } else if (array_ty->GetElementType()->IsFloat()) { + init_consts.push_back(builder_.CreateConstFloat(elem.f32)); + } + } + + init_consts = TrimTrailingZeroConstants(std::move(init_consts)); + + if (is_global_scope) { + // 全局 const 
数组:全局存储,并标记为 constant + ir::GlobalValue* global_array = module_.CreateGlobal(const_name, sym->type); + if (!init_consts.empty()) { + global_array->SetInitializer(init_consts); + } + global_array->SetConstant(true); + const_global_map_[const_name] = global_array; + const_storage_map_[ctx] = global_array; + } else { + ir::Value* array_slot = nullptr; + if (use_heap_storage) { + const bool is_float_array = array_ty->GetElementType()->IsFloat(); + const std::string alloc_name = is_float_array ? "sysy_alloc_f32" : "sysy_alloc_i32"; + const std::string free_name = is_float_array ? "sysy_free_f32" : "sysy_free_i32"; + ir::Function* alloc_func = module_.FindFunction(alloc_name); + if (!alloc_func) alloc_func = CreateRuntimeFunctionDecl(alloc_name); + ir::Function* free_func = module_.FindFunction(free_name); + if (!free_func) free_func = CreateRuntimeFunctionDecl(free_name); + array_slot = builder_.CreateCall( + alloc_func, + {builder_.CreateConstInt(static_cast(array_ty->GetElementCount()))}, + module_.GetContext().NextTemp()); + heap_local_array_names_.insert(const_name); + RegisterCleanup(free_func, array_slot); + } else { + array_slot = CreateEntryAlloca(sym->type, + module_.GetContext().NextTemp() + "_" + const_name); + } + + const auto& dims = array_ty->GetDimensions(); + const size_t total_size = array_ty->GetElementCount(); + + if (init_consts.empty()) { + if (use_heap_storage) { + const std::string zero_name = array_ty->GetElementType()->IsFloat() + ? 
"sysy_zero_f32" + : "sysy_zero_i32"; + ir::Function* zero_func = module_.FindFunction(zero_name); + if (!zero_func) zero_func = CreateRuntimeFunctionDecl(zero_name); + builder_.CreateCall(zero_func, + {array_slot, builder_.CreateConstInt(static_cast(total_size))}, + module_.GetContext().NextTemp()); + } else { + builder_.CreateStore(module_.GetContext().GetAggregateZero(sym->type), array_slot); + } + } + + for (size_t i = 0; i < total_size; ++i) { + ir::Value* init = nullptr; + if (i < init_consts.size()) { + init = init_consts[i]; + } else if (array_ty->GetElementType()->IsFloat()) { + init = builder_.CreateConstFloat(0.0f); + } else { + init = builder_.CreateConstInt(0); + } + + ir::Value* elem_ptr = nullptr; + if (use_heap_storage) { + if (IsZeroIRValue(init)) { + continue; + } + elem_ptr = builder_.CreateGEP( + array_slot, {builder_.CreateConstInt(static_cast(i))}, + module_.GetContext().NextTemp()); + } else { + elem_ptr = builder_.CreateGEP( + array_slot, BuildArrayIndices(builder_, dims, i), + module_.GetContext().NextTemp()); + } + builder_.CreateStore(init, elem_ptr); + } + + local_var_map_[const_name] = array_slot; + const_storage_map_[ctx] = array_slot; + } + + } else if (sym->IsScalarConstant()) { + // 标量常量:存储常量值 + ir::ConstantValue* const_value = nullptr; + if (sym->type->IsInt32()) { + const_value = builder_.CreateConstInt(sym->GetIntConstant()); + DEBUG_MSG("[DEBUG] visitConstDef: 整型常量 " << const_name + << " = " << sym->GetIntConstant()); + } else if (sym->type->IsFloat()) { + const_value = builder_.CreateConstFloat(sym->GetFloatConstant()); + DEBUG_MSG("[DEBUG] visitConstDef: 浮点常量 " << const_name + << " = " << sym->GetFloatConstant()); + } + + const_value_map_[const_name] = const_value; + const_storage_map_[ctx] = const_value; + } + + return {}; +} + +// visitVarDef: 处理变量定义 - 从符号表获取类型信息 +std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { + DEBUG_MSG("[DEBUG] visitVarDef: 开始处理变量定义"); + if (!ctx || !ctx->Ident()) { + throw 
std::runtime_error(FormatError("irgen", "非法变量定义")); + } + + std::string varName = ctx->Ident()->getText(); + DEBUG_MSG("[DEBUG] visitVarDef: 变量名称: " << varName); + + // 防止重复分配 if (storage_map_.find(ctx) != storage_map_.end()) { - throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); + throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位: " + varName)); + } + + // 从符号表获取变量信息 + const Symbol* sym = symbol_table_.lookupByVarDef(ctx); + if (!sym) { + throw std::runtime_error(FormatError("irgen", "变量符号未找到: " + varName)); + } + + DEBUG_MSG("[DEBUG] visitVarDef: 变量类型: " + << (sym->type->IsInt32() ? "int" : + sym->type->IsFloat() ? "float" : + sym->type->IsArray() ? "array" : "unknown")); + + // 根据作用域处理 + if (func_ == nullptr) { + DEBUG_MSG("[DEBUG] visitVarDef: 处理全局变量"); + return HandleGlobalVariable(ctx, varName, sym); + } else { + DEBUG_MSG("[DEBUG] visitVarDef: 处理局部变量"); + return HandleLocalVariable(ctx, varName, sym); + } +} + +// HandleGlobalVariable: 处理全局变量 +std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx, + const std::string& varName, + const Symbol* sym) { + DEBUG_MSG("[DEBUG] HandleGlobalVariable: 开始处理全局变量 " << varName); + + if (!sym) { + throw std::runtime_error(FormatError("irgen", "符号表信息缺失: " + varName)); + } + + bool is_array = sym->type->IsArray(); + bool is_float = sym->type->IsFloat(); + if (is_array) { + if (auto* array_ty = dynamic_cast(sym->type.get())) { + is_float = array_ty->GetElementType()->IsFloat(); + } } - auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); - storage_map_[ctx] = slot; + + if (is_array) { + // 从符号表获取数组类型和维度 + auto* array_ty = dynamic_cast(sym->type.get()); + if (!array_ty) { + throw std::runtime_error(FormatError("irgen", "数组类型转换失败")); + } + + const auto& dimensions = array_ty->GetDimensions(); + size_t total_size = array_ty->GetElementCount(); + + DEBUG_MSG("[DEBUG] HandleGlobalVariable: 全局数组 " << varName << " 维度: "); + for (int d : dimensions) DEBUG_MSG(d << " "); + 
DEBUG_MSG(", 总大小: " << total_size); + + // 创建全局数组 + ir::GlobalValue* global_array = module_.CreateGlobal(varName, sym->type); + + // 处理初始化值(使用带维度感知的展平) + std::vector init_consts; + if (auto* initVal = ctx->initVal()) { + DEBUG_MSG("[DEBUG] HandleGlobalVariable: 处理初始化值"); + // 全局变量的初始化必须是常量表达式(语义检查已保证) + std::vector flat_vals = FlattenInitVal( + initVal, dimensions, is_float); + for (auto* val : flat_vals) { + if (is_float) { + if (auto* cf = dynamic_cast(val)) { + init_consts.push_back(cf); + } else if (auto* ci = dynamic_cast(val)) { + init_consts.push_back(builder_.CreateConstFloat(static_cast(ci->GetValue()))); + } else { + init_consts.push_back(builder_.CreateConstFloat(0.0f)); + } + } else { + if (auto* ci = dynamic_cast(val)) { + init_consts.push_back(ci); + } else if (auto* cf = dynamic_cast(val)) { + init_consts.push_back(builder_.CreateConstInt(static_cast(cf->GetValue()))); + } else { + init_consts.push_back(builder_.CreateConstInt(0)); + } + } + } + } - ir::Value* init = nullptr; - if (auto* init_value = ctx->initValue()) { - if (!init_value->exp()) { - throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); + init_consts = TrimTrailingZeroConstants(std::move(init_consts)); + + // 设置初始化器 + if (!init_consts.empty()) { + global_array->SetInitializer(init_consts); } - init = EvalExpr(*init_value->exp()); + + storage_map_[ctx] = global_array; + global_map_[varName] = global_array; + } else { - init = builder_.CreateConstInt(0); + // 全局标量变量 + std::shared_ptr var_type = sym->type; + ir::GlobalValue* global_var = module_.CreateGlobal(varName, var_type); + + // 处理初始化值 + ir::ConstantValue* init_value = nullptr; + if (auto* initVal = ctx->initVal()) { + auto result = initVal->accept(this); + if (result.has_value()) { + try { + ir::Value* val = std::any_cast(result); + if (is_float) { + if (auto* const_float = dynamic_cast(val)) { + init_value = const_float; + } else if (auto* const_int = dynamic_cast(val)) { + init_value = 
builder_.CreateConstFloat(static_cast(const_int->GetValue())); + } + } else { + if (auto* const_int = dynamic_cast(val)) { + init_value = const_int; + } else if (auto* const_float = dynamic_cast(val)) { + init_value = builder_.CreateConstInt(static_cast(const_float->GetValue())); + } + } + } catch (const std::bad_any_cast&) { + // 使用默认值 + } + } + } + + //正确:只在没有初始化值时才设置默认值 + if (!init_value) { + if (is_float) { + init_value = builder_.CreateConstFloat(0.0f); + } else { + init_value = builder_.CreateConstInt(0); + } + } + + global_var->SetInitializer(init_value); + storage_map_[ctx] = global_var; + global_map_[varName] = global_var; + } + + DEBUG_MSG("[DEBUG] HandleGlobalVariable: 全局变量处理完成"); + return {}; +} + +// HandleLocalVariable: 处理局部变量 +std::any IRGenImpl::HandleLocalVariable(SysYParser::VarDefContext* ctx, + const std::string& varName, + const Symbol* sym) { + DEBUG_MSG("[DEBUG] HandleLocalVariable: 开始处理局部变量 " << varName); + + if (!sym) { + throw std::runtime_error(FormatError("irgen", "符号表信息缺失: " + varName)); + } + + bool is_array = sym->type->IsArray(); + bool is_float = sym->type->IsFloat(); + if (is_array) { + if (auto* array_ty = dynamic_cast(sym->type.get())) { + is_float = array_ty->GetElementType()->IsFloat(); + } + } + + if (is_array) { + // 从符号表获取数组信息 + auto* array_ty = dynamic_cast(sym->type.get()); + if (!array_ty) { + throw std::runtime_error(FormatError("irgen", "数组类型转换失败")); + } + + size_t total_size = array_ty->GetElementCount(); + const size_t total_bytes = GetArrayStorageBytes(*array_ty); + const bool use_heap_storage = + current_function_is_recursive_ || total_bytes > kLocalArrayHeapThresholdBytes; + + DEBUG_MSG("[DEBUG] HandleLocalVariable: 局部数组 " << varName + << " 总大小: " << total_size); + + ir::Value* array_slot = nullptr; + if (use_heap_storage) { + const std::string alloc_name = is_float ? "sysy_alloc_f32" : "sysy_alloc_i32"; + const std::string free_name = is_float ? 
"sysy_free_f32" : "sysy_free_i32"; + ir::Function* alloc_func = module_.FindFunction(alloc_name); + if (!alloc_func) { + alloc_func = CreateRuntimeFunctionDecl(alloc_name); + } + ir::Function* free_func = module_.FindFunction(free_name); + if (!free_func) { + free_func = CreateRuntimeFunctionDecl(free_name); + } + array_slot = builder_.CreateCall( + alloc_func, + {builder_.CreateConstInt(static_cast(total_size))}, + module_.GetContext().NextTemp()); + heap_local_array_names_.insert(varName); + RegisterCleanup(free_func, array_slot); + } else { + array_slot = CreateEntryAlloca( + sym->type, module_.GetContext().NextTemp() + "_" + varName); + } + + const auto& dims = array_ty->GetDimensions(); + + storage_map_[ctx] = array_slot; + local_var_map_[varName] = array_slot; + + // 处理初始化 + if (auto* initVal = ctx->initVal()) { + std::vector init_values = FlattenInitVal( + initVal, array_ty->GetDimensions(), is_float); + + bool is_all_zero_init = true; + for (auto* value : init_values) { + if (!IsZeroIRValue(value)) { + is_all_zero_init = false; + break; + } + } + + if (is_all_zero_init && !use_heap_storage) { + builder_.CreateStore(module_.GetContext().GetAggregateZero(sym->type), + array_slot); + DEBUG_MSG("[DEBUG] HandleLocalVariable: aggregate zeroinitializer store for " + << varName); + return {}; + } + + if (use_heap_storage) { + const std::string zero_name = is_float ? "sysy_zero_f32" : "sysy_zero_i32"; + ir::Function* zero_func = module_.FindFunction(zero_name); + if (!zero_func) { + zero_func = CreateRuntimeFunctionDecl(zero_name); + } + builder_.CreateCall(zero_func, + {array_slot, builder_.CreateConstInt(static_cast(total_size))}, + module_.GetContext().NextTemp()); + if (is_all_zero_init) { + return {}; + } + } + + for (size_t i = 0; i < total_size; i++) { + ir::Value* val = (i < init_values.size() && init_values[i]) + ? init_values[i] + : (is_float ? 
static_cast(builder_.CreateConstFloat(0.0f)) + : static_cast(builder_.CreateConstInt(0))); + if (use_heap_storage && IsZeroIRValue(val)) { + continue; + } + if (is_float && val->GetType()->IsInt32()) { + val = builder_.CreateSIToFP(val, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (!is_float && val->GetType()->IsFloat()) { + val = builder_.CreateFPToSI(val, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + ir::Value* elem_ptr = nullptr; + if (use_heap_storage) { + elem_ptr = builder_.CreateGEP( + array_slot, {builder_.CreateConstInt(static_cast(i))}, + module_.GetContext().NextTemp()); + } else { + elem_ptr = builder_.CreateGEP( + array_slot, BuildArrayIndices(builder_, dims, i), + module_.GetContext().NextTemp()); + } + builder_.CreateStore(val, elem_ptr); + } + } + + } else { + // 局部标量变量 + ir::AllocaInst* slot; + if (is_float) { + slot = CreateEntryAllocaFloat(module_.GetContext().NextTemp() + "_" + varName); + } else { + slot = CreateEntryAllocaI32(module_.GetContext().NextTemp() + "_" + varName); + } + + storage_map_[ctx] = slot; + local_var_map_[varName] = slot; + + // 处理初始化 + ir::Value* init = nullptr; + if (auto* initVal = ctx->initVal()) { + auto result = initVal->accept(this); + if (result.has_value()) { + try { + init = std::any_cast(result); + } catch (const std::bad_any_cast&) { + try { + auto init_vec = std::any_cast>(result); + if (!init_vec.empty()) { + init = init_vec[0]; + } + } catch (const std::bad_any_cast&) { + // 使用默认值 + } + } + } + } + + if (!init) { + // SysY 局部变量未显式初始化时为未定义值:不生成默认 store。 + return {}; + } + + // 标量初始化支持 int/float 隐式转换。 + if (is_float && init->GetType()->IsInt32()) { + init = builder_.CreateSIToFP(init, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (!is_float && init->GetType()->IsFloat()) { + init = builder_.CreateFPToSI(init, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + + builder_.CreateStore(init, slot); } - 
builder_.CreateStore(init, slot); + + DEBUG_MSG("[DEBUG] HandleLocalVariable: 局部变量处理完成"); return {}; } + +// visitInitVal: 处理初始化值 +std::any IRGenImpl::visitInitVal(SysYParser::InitValContext* ctx) { + DEBUG_MSG("[DEBUG] visitInitVal: 开始处理初始化值"); + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法初始化值")); + } + + // 如果是单个表达式 + if (ctx->exp()) { + DEBUG_MSG("[DEBUG] visitInitVal: 处理表达式初始化"); + return EvalExpr(*ctx->exp()); + } + // 如果是聚合初始化(花括号列表) + else if (!ctx->initVal().empty()) { + DEBUG_MSG("[DEBUG] visitInitVal: 处理聚合初始化"); + return ProcessNestedInitVals(ctx); + } + + DEBUG_MSG("[DEBUG] visitInitVal: 空初始化列表"); + return std::vector{}; +} + +// ProcessNestedInitVals: 处理嵌套聚合初始化 +std::vector IRGenImpl::ProcessNestedInitVals(SysYParser::InitValContext* ctx) { + DEBUG_MSG("[DEBUG] ProcessNestedInitVals: 开始处理嵌套初始化值"); + std::vector all_values; + + for (auto* init_val : ctx->initVal()) { + auto result = init_val->accept(this); + if (result.has_value()) { + try { + // 尝试获取单个值 + ir::Value* value = std::any_cast(result); + all_values.push_back(value); + DEBUG_MSG("[DEBUG] ProcessNestedInitVals: 获取到单个值"); + } catch (const std::bad_any_cast&) { + try { + // 尝试获取值列表(嵌套情况) + std::vector nested_values = + std::any_cast>(result); + DEBUG_MSG("[DEBUG] ProcessNestedInitVals: 获取到嵌套值列表, 大小: " + << nested_values.size()); + all_values.insert(all_values.end(), + nested_values.begin(), nested_values.end()); + } catch (const std::bad_any_cast&) { + DEBUG_MSG("[ERROR] ProcessNestedInitVals: 不支持的初始化值类型"); + throw std::runtime_error( + FormatError("irgen", "不支持的初始化值类型")); + } + } + } + } + + DEBUG_MSG("[DEBUG] ProcessNestedInitVals: 共获取 " << all_values.size() + << " 个初始化值"); + return all_values; +} + +// FlattenInitVal:按 C 语言花括号对齐规则展平初始化列表 +// dims[0] 是最外层维度,dims.back() 是最内层维度(元素层) +// 总元素数 = prod(dims),结果向量长度恰好为总元素数(不足处补零) +std::vector IRGenImpl::FlattenInitVal( + SysYParser::InitValContext* ctx, + const std::vector& dims, + bool is_float) { + + // 计算总元素数 + size_t total = 1; 
+ for (int d : dims) total *= static_cast(d); + + // 零值工厂 + auto make_zero = [&]() -> ir::Value* { + if (is_float) return builder_.CreateConstFloat(0.0f); + return builder_.CreateConstInt(0); + }; + + // 先全部填零 + std::vector flat(total, nullptr); + + // 计算 depth 层从 begin 开始的子数组跨度 + // depth=0 → 每个子聚合占 dims[1]*dims[2]*... 个元素 + auto subspan = [&](size_t depth) -> size_t { + size_t span = 1; + for (size_t i = depth + 1; i < dims.size(); ++i) + span *= static_cast(dims[i]); + return span; + }; + + // 递归 fill_impl:将 node 的内容按 C 规则写入 flat[begin..end-1],返回填完后的光标 + std::function fill_impl; + fill_impl = [&](SysYParser::InitValContext* node, + size_t depth, + size_t begin, + size_t end) -> size_t { + if (!node || begin >= end) return begin; + + // 单标量初始化项(叶节点) + if (node->exp()) { + ir::Value* v = EvalExpr(*node->exp()); + if (begin < flat.size()) flat[begin] = v; + return std::min(begin + 1, end); + } + + // 聚合初始化(花括号列表) + size_t cursor = begin; + for (auto* child : node->initVal()) { + if (cursor >= end) break; + + if (child->exp()) { + // 标量子项 + ir::Value* v = EvalExpr(*child->exp()); + if (cursor < flat.size()) flat[cursor] = v; + ++cursor; + continue; + } + + // 花括号子列表 + if (depth + 1 < dims.size()) { + // 对齐到下一个子聚合边界 + const size_t span = subspan(depth); + const size_t rel = (cursor - begin) % span; + if (rel != 0) cursor += (span - rel); + if (cursor >= end) break; + + const size_t sub_end = std::min(cursor + span, end); + fill_impl(child, depth + 1, cursor, sub_end); + cursor = sub_end; // 消耗一个子聚合 + } else { + // 最内层遇到额外花括号,按顺序展开 + cursor = fill_impl(child, depth, cursor, end); + } + } + return cursor; + }; + + fill_impl(ctx, 0, 0, total); + + // 把 nullptr(未显式初始化)替换为零值 + for (auto*& v : flat) { + if (!v) v = make_zero(); + } + + return flat; +} \ No newline at end of file diff --git a/src/irgen/IRGenDriver.cpp b/src/irgen/IRGenDriver.cpp index ff94412..ac19a7b 100644 --- a/src/irgen/IRGenDriver.cpp +++ b/src/irgen/IRGenDriver.cpp @@ -6,10 +6,11 @@ #include "ir/IR.h" 
#include "utils/Log.h" +// 修改 GenerateIR 函数 std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, - const SemanticContext& sema) { + const SemaResult& sema_result) { auto module = std::make_unique(); - IRGenImpl gen(*module, sema); + IRGenImpl gen(*module, sema_result.context, sema_result.symbol_table); tree.accept(&gen); return module; } diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index cf4797c..7618f15 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -20,61 +20,1136 @@ // - 数组、指针、下标访问 // - 条件与比较表达式 // - ... + +// 表达式生成 ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { - return std::any_cast(expr.accept(this)); + DEBUG_MSG("[DEBUG IRGEN] EvalExpr: 开始处理表达式 " << expr.getText()); + try { + auto result_any = expr.accept(this); + + if (!result_any.has_value()) { + DEBUG_MSG("[ERROR] EvalExpr: result_any has no value"); + throw std::runtime_error("表达式求值结果为空"); + } + + try { + ir::Value* result = std::any_cast(result_any); + DEBUG_MSG("[DEBUG] EvalExpr: success, result = " << (void*)result); + return result; + } catch (const std::bad_any_cast& e) { + DEBUG_MSG("[ERROR] EvalExpr: bad any_cast - " << e.what()); + DEBUG_MSG(" Type info: " << result_any.type().name()); + throw std::runtime_error(FormatError("irgen", "表达式求值返回了错误的类型")); + } + } catch (const std::exception& e) { + DEBUG_MSG("[ERROR] Exception in EvalExpr: " << e.what()); + throw; + } } +ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) { + DEBUG_MSG("[DEBUG IRGEN] EvalCond: 开始处理条件表达式 " << cond.getText()); + return std::any_cast(cond.accept(this)); +} + +// 基本表达式:数字、变量、括号表达式 +std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitPrimaryExp: 开始处理基本表达式 " << (ctx ? 
ctx->getText() : "")); + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "缺少基本表达式")); + } + + DEBUG_MSG("[DEBUG] visitPrimaryExp"); + + // 处理数字字面量 + if (ctx->DECIMAL_INT()) { + int value = std::stoi(ctx->DECIMAL_INT()->getText()); + ir::Value* const_int = builder_.CreateConstInt(value); + DEBUG_MSG("[DEBUG] visitPrimaryExp: constant int " << value + << " created as " << (void*)const_int); + return static_cast(const_int); + } + + if (ctx->HEX_FLOAT()) { + std::string hex_float_str = ctx->HEX_FLOAT()->getText(); + float value = 0.0f; + try { + value = std::stof(hex_float_str); + } catch (const std::exception& e) { + DEBUG_MSG("[WARNING] 无法解析十六进制浮点数: " << hex_float_str + << ",使用0.0代替"); + value = 0.0f; + } + ir::Value* const_float = builder_.CreateConstFloat(value); + DEBUG_MSG("[DEBUG] visitPrimaryExp: constant hex float " << value + << " created as " << (void*)const_float); + return static_cast(const_float); + } + + if (ctx->DEC_FLOAT()) { + std::string dec_float_str = ctx->DEC_FLOAT()->getText(); + float value = 0.0f; + try { + value = std::stof(dec_float_str); + } catch (const std::exception& e) { + DEBUG_MSG("[WARNING] 无法解析十进制浮点数: " << dec_float_str + << ",使用0.0代替"); + value = 0.0f; + } + ir::Value* const_float = builder_.CreateConstFloat(value); + DEBUG_MSG("[DEBUG] visitPrimaryExp: constant dec float " << value + << " created as " << (void*)const_float); + return static_cast(const_float); + } -std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "非法括号表达式")); + if (ctx->HEX_INT()) { + std::string hex = ctx->HEX_INT()->getText(); + int value = std::stoi(hex, nullptr, 16); + ir::Value* const_int = builder_.CreateConstInt(value); + DEBUG_MSG("[DEBUG] visitPrimaryExp: constant hex int " << value + << " created as " << (void*)const_int); + return static_cast(const_int); + } + + if (ctx->OCTAL_INT()) { + std::string oct = ctx->OCTAL_INT()->getText(); + int value = 
std::stoi(oct, nullptr, 8); + ir::Value* const_int = builder_.CreateConstInt(value); + DEBUG_MSG("[DEBUG] visitPrimaryExp: constant octal int " << value + << " created as " << (void*)const_int); + return static_cast(const_int); + } + + if (ctx->ZERO()) { + ir::Value* const_int = builder_.CreateConstInt(0); + DEBUG_MSG("[DEBUG] visitPrimaryExp: constant zero int created"); + return static_cast(const_int); } - return EvalExpr(*ctx->exp()); + + // 处理变量 + if (ctx->lVal()) { + DEBUG_MSG("[DEBUG] visitPrimaryExp: visiting lVal"); + return ctx->lVal()->accept(this); + } + + // 处理括号表达式 + if (ctx->L_PAREN() && ctx->exp()) { + DEBUG_MSG("[DEBUG] visitPrimaryExp: visiting parenthesized expression"); + return EvalExpr(*ctx->exp()); + } + + DEBUG_MSG("[ERROR] visitPrimaryExp: unsupported primary expression type"); + throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型")); } +// 左值(变量)处理 +std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitLVal: 开始处理左值 " << (ctx ? 
ctx->getText() : "")); + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("irgen", "非法左值")); + } + + std::string varName = ctx->Ident()->getText(); + DEBUG_MSG("[DEBUG] visitLVal: " << varName); + + // 先检查语义分析中常量绑定 + const SysYParser::ConstDefContext* const_decl = sema_.ResolveConstUse(ctx); + const Symbol* sym = nullptr; + if (const_decl) { + sym = symbol_table_.lookupByConstDef(const_decl); + if (!sym) { + sym = symbol_table_.lookupAll(varName); + } + } else { + sym = symbol_table_.lookup(varName); + } + + // 如果是常量,直接返回常量值 + if (sym && sym->kind == SymbolKind::Constant) { + DEBUG_MSG("[DEBUG] visitLVal: 找到常量 " << varName); + + if (sym->IsScalarConstant()) { + if (sym->type->IsInt32()) { + ir::ConstantValue* const_val = builder_.CreateConstInt(sym->GetIntConstant()); + return static_cast(const_val); + } else if (sym->type->IsFloat()) { + ir::ConstantValue* const_val = builder_.CreateConstFloat(sym->GetFloatConstant()); + return static_cast(const_val); + } + } else if (sym->IsArrayConstant()) { + auto it = const_global_map_.find(varName); + if (it != const_global_map_.end()) { + ir::GlobalValue* global_array = it->second; + + // 尝试获取类型信息,用于维度判断与下标线性化 + auto* array_ty = dynamic_cast(sym->type.get()); + if (!array_ty) { + // 无法获取数组类型,退回返回全局对象 + return static_cast(global_array); + } + + size_t ndims = array_ty->GetDimensions().size(); + + // 有下标访问 + if (!ctx->exp().empty()) { + size_t provided = ctx->exp().size(); + + // 完全索引(所有维度都有下标)——直接返回常量元素,不生成 Load + if (provided == ndims) { + std::vector idxs; + idxs.reserve(provided); + for (auto* exp : ctx->exp()) { + ir::Value* v = EvalExpr(*exp); + if (!v || !v->IsConstant()) { + throw std::runtime_error(FormatError("irgen", "常量数组索引必须为常量整数: " + varName)); + } + auto* ci = dynamic_cast(v); + if (!ci) { + throw std::runtime_error(FormatError("irgen", "常量数组索引非整型常量: " + varName)); + } + idxs.push_back(ci->GetValue()); + } + + // 计算线性下标(行主序) + const auto& dims = array_ty->GetDimensions(); + int flat = idxs[0]; 
+ for (size_t i = 1; i < ndims; ++i) { + flat = flat * dims[i] + idxs[i]; + } -std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); + ir::ConstantValue* elem = global_array->GetArrayElement(static_cast(flat)); + return static_cast(elem); + } + + // 部分索引:返回指针(不做 Load),由上层按需处理 + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + for (auto* exp : ctx->exp()) { + indices.push_back(EvalExpr(*exp)); + } + return static_cast( + builder_.CreateGEP(global_array, indices, module_.GetContext().NextTemp())); + } else { + // 无下标,直接返回全局常量对象 + return static_cast(global_array); + } + } + } + } + + // 不是常量,按正常变量处理 + auto* decl = sema_.ResolveVarUse(ctx); + ir::Value* ptr = nullptr; + + if (decl) { + auto it = storage_map_.find(decl); + if (it != storage_map_.end()) { + ptr = it->second; + } } - return static_cast( - builder_.CreateConstInt(std::stoi(ctx->number()->getText()))); -} -// 变量使用的处理流程: -// 1. 先通过语义分析结果把变量使用绑定回声明; -// 2. 再通过 storage_map_ 找到该声明对应的栈槽位; -// 3. 
最后生成 load,把内存中的值读出来。 -// -// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 -std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) { - if (!ctx || !ctx->var() || !ctx->var()->ID()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); + if (!ptr) { + auto it2 = param_map_.find(varName); + if (it2 != param_map_.end()) { + ptr = it2->second; + } } - auto* decl = sema_.ResolveVarUse(ctx->var()); - if (!decl) { - throw std::runtime_error( - FormatError("irgen", - "变量使用缺少语义绑定: " + ctx->var()->ID()->getText())); + + if (!ptr) { + auto it3 = global_map_.find(varName); + if (it3 != global_map_.end()) { + ptr = it3->second; + } } - auto it = storage_map_.find(decl); - if (it == storage_map_.end()) { + + if (!ptr) { + auto it4 = local_var_map_.find(varName); + if (it4 != local_var_map_.end()) { + ptr = it4->second; + } + } + + if (!ptr) { throw std::runtime_error( - FormatError("irgen", - "变量声明缺少存储槽位: " + ctx->var()->ID()->getText())); + FormatError("irgen", "变量声明缺少存储槽位: " + varName)); + } + + // 检查是否有数组下标 + bool is_array_access = !ctx->exp().empty(); + if (is_array_access) { + // 收集下标表达式(不含前导0) + std::vector idx_vals; + for (auto* exp : ctx->exp()) { + ir::Value* index = EvalExpr(*exp); + idx_vals.push_back(index); + } + + const Symbol* var_sym = sym; + if (!var_sym) { + var_sym = symbol_table_.lookup(varName); + } + if (!var_sym && decl) { + var_sym = symbol_table_.lookupByVarDef(decl); + } + if (!var_sym) { + var_sym = symbol_table_.lookupAll(varName); + } + + std::vector dims; + if (var_sym) { + if (var_sym->is_array_param && !var_sym->array_dims.empty()) { + dims = var_sym->array_dims; + } else if (var_sym->type && var_sym->type->IsArray()) { + auto* at = dynamic_cast(var_sym->type.get()); + if (at) dims = at->GetDimensions(); + } + } + + if (dims.empty() && ptr->GetType()->IsArray()) { + if (auto* at = dynamic_cast(ptr->GetType().get())) { + dims = at->GetDimensions(); + } + } + + // 兜底:从语法树声明提取维度,避免作用域关闭后符号查询不完整。 + if (dims.empty() && const_decl) { + 
auto* mutable_const_decl = const_cast(const_decl); + for (auto* cexp : mutable_const_decl->constExp()) { + dims.push_back(symbol_table_.EvaluateConstExp(cexp)); + } + } + if (dims.empty() && decl) { + for (auto* cexp : decl->constExp()) { + dims.push_back(symbol_table_.EvaluateConstExp(cexp)); + } + } + + const bool is_partial_array_access = + !dims.empty() && idx_vals.size() < dims.size(); + + // 如果 base 是标量指针(例如局部扁平数组或数组参数), + // 需要把多维下标折合为单一线性下标,然后用一个索引进行 GEP。 + if (ptr->GetType()->IsPtrInt32() || ptr->GetType()->IsPtrFloat()) { + // 如果没有维度信息,仍尝试用运行时算术合并下标(按后维乘积) + // flat = idx0 * (prod dims[1..]) + idx1 * (prod dims[2..]) + ... + ir::Value* flat = nullptr; + for (size_t i = 0; i < idx_vals.size(); ++i) { + ir::Value* term = idx_vals[i]; + if (!term) continue; + + // 计算乘数(后续维度乘积) + int mult = 1; + if (!dims.empty() && i + 1 < dims.size()) { + for (size_t j = i + 1; j < dims.size(); ++j) { + // 数组参数首维可能是 0(表示省略),不参与乘数。 + if (dims[j] > 0) mult *= dims[j]; + } + } + + if (mult != 1) { + auto* mval = builder_.CreateConstInt(mult); + term = builder_.CreateMul(term, mval, module_.GetContext().NextTemp()); + } + + if (!flat) flat = term; + else flat = builder_.CreateAdd(flat, term, module_.GetContext().NextTemp()); + } + + if (!flat) flat = builder_.CreateConstInt(0); + + // 使用单一索引创建 GEP + std::vector gep_indices = { flat }; + ir::Value* elem_ptr = builder_.CreateGEP(ptr, gep_indices, module_.GetContext().NextTemp()); + if (is_partial_array_access) { + return elem_ptr; + } + return static_cast(builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp())); + } + + std::vector indices; + // 标量指针(T*)使用单索引;数组对象使用前导0进入首层。 + if (ptr->GetType()->IsPtrInt32() || ptr->GetType()->IsPtrFloat()) { + for (auto* v : idx_vals) indices.push_back(v); + } else { + indices.push_back(builder_.CreateConstInt(0)); + for (auto* v : idx_vals) indices.push_back(v); + } + + ir::Value* elem_ptr = builder_.CreateGEP(ptr, indices, module_.GetContext().NextTemp()); + if (is_partial_array_access) 
{ + return elem_ptr; + } + return static_cast(builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp())); + } else { + if ((sym && sym->is_array_param) || + pointer_param_names_.find(varName) != pointer_param_names_.end() || + heap_local_array_names_.find(varName) != heap_local_array_names_.end()) { + return ptr; + } + if (ptr->GetType()->IsArray()) { + return ptr; + } + return static_cast(builder_.CreateLoad(ptr, module_.GetContext().NextTemp())); + } +} + +std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitAddExp: 开始处理加法表达式 " << (ctx ? ctx->getText() : "")); + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法加法表达式")); + } + + // 如果没有 addExp(),说明是单个 mulExp() + if (!ctx->addExp()) { + return ctx->mulExp()->accept(this); + } + + // 正确提取左操作数 + auto left_any = ctx->addExp()->accept(this); + if (!left_any.has_value()) { + throw std::runtime_error(FormatError("irgen", "左操作数求值失败")); + } + ir::Value* left = std::any_cast(left_any); + + // 正确提取右操作数 + auto right_any = ctx->mulExp()->accept(this); + if (!right_any.has_value()) { + throw std::runtime_error(FormatError("irgen", "右操作数求值失败")); + } + ir::Value* right = std::any_cast(right_any); + + DEBUG_MSG("[DEBUG] visitAddExp: left=" << (void*)left + << ", type=" << (left->GetType()->IsFloat() ? "float" : "int") + << ", right=" << (void*)right + << ", type=" << (right->GetType()->IsFloat() ? 
"float" : "int")); + + // 处理类型转换:如果操作数类型不同,需要进行类型转换 + if (left->GetType()->IsFloat() != right->GetType()->IsFloat()) { + if (left->GetType()->IsFloat()) { + // left是float,right是int,需要将right转换为float + right = builder_.CreateSIToFP(right, ir::Type::GetFloatType()); + } else { + // right是float,left是int,需要将left转换为float + left = builder_.CreateSIToFP(left, ir::Type::GetFloatType()); + } + } + + // 根据操作符生成相应的指令 + if (ctx->AddOp()) { + if (left->GetType()->IsFloat()) { + return static_cast( + builder_.CreateFAdd(left, right, module_.GetContext().NextTemp())); + } else { + return static_cast( + builder_.CreateAdd(left, right, module_.GetContext().NextTemp())); + } + } else if (ctx->SubOp()) { + if (left->GetType()->IsFloat()) { + return static_cast( + builder_.CreateFSub(left, right, module_.GetContext().NextTemp())); + } else { + return static_cast( + builder_.CreateSub(left, right, module_.GetContext().NextTemp())); + } + } + + throw std::runtime_error(FormatError("irgen", "未知的加法操作符")); +} + + +std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitMulExp: 开始处理乘法表达式 " << (ctx ? ctx->getText() : "")); + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); + } + + // 如果是基本形式 UnaryExp + if (!ctx->mulExp()) { + return ctx->unaryExp()->accept(this); + } + + // 提取左操作数 + auto left_any = ctx->mulExp()->accept(this); + if (!left_any.has_value()) { + throw std::runtime_error(FormatError("irgen", "左操作数求值失败")); + } + ir::Value* left = std::any_cast(left_any); + + // 提取右操作数 + auto right_any = ctx->unaryExp()->accept(this); + if (!right_any.has_value()) { + throw std::runtime_error(FormatError("irgen", "右操作数求值失败")); + } + ir::Value* right = std::any_cast(right_any); + + DEBUG_MSG("[DEBUG] visitMulExp: left=" << (void*)left + << ", type=" << (left->GetType()->IsFloat() ? "float" : "int") + << ", right=" << (void*)right + << ", type=" << (right->GetType()->IsFloat() ? 
"float" : "int")); + + // 处理类型转换:如果操作数类型不同,需要进行类型转换 + if (left->GetType()->IsFloat() != right->GetType()->IsFloat()) { + if (left->GetType()->IsFloat()) { + // left是float,right是int,需要将right转换为float + right = builder_.CreateSIToFP(right, ir::Type::GetFloatType()); + } else { + // right是float,left是int,需要将left转换为float + left = builder_.CreateSIToFP(left, ir::Type::GetFloatType()); + } } + + // 根据操作符生成指令 + if (ctx->MulOp()) { + if (left->GetType()->IsFloat()) { + return static_cast( + builder_.CreateFMul(left, right, module_.GetContext().NextTemp())); + } else { + return static_cast( + builder_.CreateMul(left, right, module_.GetContext().NextTemp())); + } + } else if (ctx->DivOp()) { + if (left->GetType()->IsFloat()) { + return static_cast( + builder_.CreateFDiv(left, right, module_.GetContext().NextTemp())); + } else { + return static_cast( + builder_.CreateDiv(left, right, module_.GetContext().NextTemp())); + } + } else if (ctx->QuoOp()) { + // 取模运算:浮点数不支持取模,只支持整数 + if (left->GetType()->IsFloat() || right->GetType()->IsFloat()) { + throw std::runtime_error( + FormatError("irgen", "浮点数不支持取模运算")); + } + return static_cast( + builder_.CreateMod(left, right, module_.GetContext().NextTemp())); + } + + throw std::runtime_error(FormatError("irgen", "未知的乘法操作符")); +} + + + +// 逻辑与 +std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitLAndExp: 开始处理逻辑与表达式 " << (ctx ? 
ctx->getText() : "")); + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); + + if (!ctx->lAndExp()) { + return ctx->eqExp()->accept(this); + } + + ir::Value* left = std::any_cast(ctx->lAndExp()->accept(this)); + ir::Value* right = std::any_cast(ctx->eqExp()->accept(this)); + + auto to_bool = [&](ir::Value* v) -> ir::Value* { + if (v->GetType()->IsInt1()) { + return v; + } + if (v->GetType()->IsFloat()) { + return builder_.CreateFCmpONE(v, builder_.CreateConstFloat(0.0f), + module_.GetContext().NextTemp()); + } + return builder_.CreateICmpNE(v, builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); + }; + + auto* left_bool = to_bool(left); + auto* right_bool = to_bool(right); return static_cast( - builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); + builder_.CreateAnd(left_bool, right_bool, module_.GetContext().NextTemp())); } +// 逻辑或 +std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitLOrExp: 开始处理逻辑或表达式 " << (ctx ? 
ctx->getText() : "")); + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); -std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("irgen", "非法加法表达式")); + if (!ctx->lOrExp()) { + return ctx->lAndExp()->accept(this); } - ir::Value* lhs = EvalExpr(*ctx->exp(0)); - ir::Value* rhs = EvalExpr(*ctx->exp(1)); + + ir::Value* left = std::any_cast(ctx->lOrExp()->accept(this)); + ir::Value* right = std::any_cast(ctx->lAndExp()->accept(this)); + + auto to_bool = [&](ir::Value* v) -> ir::Value* { + if (v->GetType()->IsInt1()) { + return v; + } + if (v->GetType()->IsFloat()) { + return builder_.CreateFCmpONE(v, builder_.CreateConstFloat(0.0f), + module_.GetContext().NextTemp()); + } + return builder_.CreateICmpNE(v, builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); + }; + + auto* left_bool = to_bool(left); + auto* right_bool = to_bool(right); return static_cast( - builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, - module_.GetContext().NextTemp())); + builder_.CreateOr(left_bool, right_bool, module_.GetContext().NextTemp())); +} + +std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitExp: 开始处理表达式 " << (ctx ? ctx->getText() : "")); + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法表达式")); + return ctx->addExp()->accept(this); +} + +std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitCond: 开始处理条件 " << (ctx ? ctx->getText() : "")); + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法条件表达式")); + return ctx->lOrExp()->accept(this); +} + +std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 开始处理函数调用 " << (ctx ? 
ctx->getText() : "")); + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("irgen", "非法函数调用")); + } + + std::string funcName = ctx->Ident()->getText(); + DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 调用函数 " << funcName); + + // 查找函数对象 + ir::Function* callee = module_.FindFunction(funcName); + + // 如果没找到,可能是运行时函数还没声明,尝试动态声明 + if (!callee) { + DEBUG_MSG("[DEBUG IRGEN] 函数 " << funcName << " 未找到,尝试动态声明"); + + // 根据函数名动态创建运行时函数声明 + callee = CreateRuntimeFunctionDecl(funcName); + if (!callee) { + throw std::runtime_error(FormatError("irgen", "未找到函数: " + funcName)); + } + } + + // 收集实参 + std::vector args; + if (ctx->funcRParams()) { + auto argList = ctx->funcRParams()->accept(this); + try { + args = std::any_cast>(argList); + DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 收集到 " << args.size() << " 个参数"); + } catch (const std::bad_any_cast& e) { + DEBUG_MSG("[ERROR] visitCallExp: 函数调用参数类型错误: " << e.what()); + } + } + + // 按形参类型修正实参(数组衰减为指针等)。 + if (auto* fty = dynamic_cast(callee->GetType().get())) { + const auto& param_tys = fty->GetParamTypes(); + size_t n = std::min(param_tys.size(), args.size()); + for (size_t i = 0; i < n; ++i) { + if (!args[i] || !param_tys[i]) continue; + + // 数组实参传给指针形参时,执行数组到指针衰减。 + if (args[i]->GetType()->IsArray() && + (param_tys[i]->IsPtrInt32() || param_tys[i]->IsPtrFloat())) { + std::vector idx; + idx.push_back(builder_.CreateConstInt(0)); + idx.push_back(builder_.CreateConstInt(0)); + args[i] = builder_.CreateGEP(args[i], idx, module_.GetContext().NextTemp()); + } + + // 标量实参的隐式类型转换(int <-> float)。 + if (param_tys[i]->IsFloat() && args[i]->GetType()->IsInt32()) { + args[i] = builder_.CreateSIToFP(args[i], ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (param_tys[i]->IsInt32() && args[i]->GetType()->IsFloat()) { + args[i] = builder_.CreateFPToSI(args[i], ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + } + } + + // 生成调用指令 + ir::Value* callResult = builder_.CreateCall(callee, args, 
module_.GetContext().NextTemp()); + + // 如果函数返回 void,返回一个默认值(用于表达式上下文) + if (callResult->GetType()->IsVoid()) { + // void 函数调用不产生值,但我们返回一个 0 常量以保持类型一致性 + return static_cast(builder_.CreateConstInt(0)); + } + + DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 函数调用完成,返回值 " << (void*)callResult); + return static_cast(callResult); +} + +// 动态创建运行时函数声明的辅助函数 +ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName) { + DEBUG_MSG("[DEBUG IRGEN] CreateRuntimeFunctionDecl: 开始创建运行时函数声明 " << funcName); + + // 根据常见运行时函数名创建对应的函数类型 + if (funcName == "getint" || funcName == "getch") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {})); + } + else if (funcName == "putint" || funcName == "putch") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type()})); + } + else if (funcName == "getarray") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetInt32Type(), + {ir::Type::GetPtrInt32Type()})); + } + else if (funcName == "putarray") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type(), ir::Type::GetPtrInt32Type()})); + } + else if (funcName == "puts") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrInt32Type()})); + } + else if (funcName == "starttime" || funcName == "stoptime") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType(ir::Type::GetVoidType(), {})); + } + else if (funcName == "_sysy_starttime" || funcName == "_sysy_stoptime") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type()})); + } + else if (funcName == "getfloat") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType(ir::Type::GetFloatType(), {})); + } + else if (funcName == "putfloat") { + 
return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetFloatType()})); + } + else if (funcName == "getfarray") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetInt32Type(), + {ir::Type::GetPtrFloatType()})); + } + else if (funcName == "putfarray") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type(), ir::Type::GetPtrFloatType()})); + } + else if (funcName == "memset") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetPtrInt32Type(), + {ir::Type::GetPtrInt32Type(), + ir::Type::GetInt32Type(), + ir::Type::GetInt32Type()})); + } + else if (funcName == "sysy_alloc_i32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetPtrInt32Type(), + {ir::Type::GetInt32Type()})); + } + else if (funcName == "sysy_alloc_f32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetPtrFloatType(), + {ir::Type::GetInt32Type()})); + } + else if (funcName == "sysy_free_i32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrInt32Type()})); + } + else if (funcName == "sysy_free_f32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrFloatType()})); + } + else if (funcName == "sysy_zero_i32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()})); + } + else if (funcName == "sysy_zero_f32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrFloatType(), ir::Type::GetInt32Type()})); + } + + // 其他函数不支持动态创建 + return nullptr; +} + +std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { + 
DEBUG_MSG("[DEBUG IRGEN] visitUnaryExp: 开始处理一元表达式 " << (ctx ? ctx->getText() : "")); + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法一元表达式")); + } + + // 基本表达式 + if (ctx->primaryExp()) { + return ctx->primaryExp()->accept(this); + } + + // 函数调用 + if (ctx->Ident()) { + return visitCallExp(ctx); + } + + // 一元运算 + if (ctx->unaryOp() && ctx->unaryExp()) { + auto* operand = std::any_cast(ctx->unaryExp()->accept(this)); + std::string op = ctx->unaryOp()->getText(); + + if (op == "+") { + // +x 等价于 x + return operand; + } else if (op == "-") { + // -x 根据操作数类型选择整数或浮点减法 + if (operand->GetType()->IsFloat()) { + // 浮点取负:0.0 - x + ir::Value* zero_float = builder_.CreateConstFloat(0.0f); + return static_cast( + builder_.CreateFSub(zero_float, operand, module_.GetContext().NextTemp())); + } else { + // 整数取负:0 - x + ir::Value* zero = builder_.CreateConstInt(0); + return static_cast( + builder_.CreateSub(zero, operand, module_.GetContext().NextTemp())); + } + } else if (op == "!") { + // 逻辑非运算 + // 先将值转换为bool(与0比较) + ir::Value* zero; + if (operand->GetType()->IsFloat()) { + zero = builder_.CreateConstFloat(0.0f); + // 浮点逻辑非:x == 0.0 + ir::Value* cmp = builder_.CreateFCmpOEQ(operand, zero, module_.GetContext().NextTemp()); + // 将bool转换为int + return static_cast( + builder_.CreateZExt(cmp, ir::Type::GetInt32Type())); + } else { + zero = builder_.CreateConstInt(0); + return static_cast( + builder_.CreateNot(operand, module_.GetContext().NextTemp())); + } + } + } + + throw std::runtime_error(FormatError("irgen", "暂不支持的一元表达式形式")); +} + +// 实现函数调用 +std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitFuncRParams: 开始处理函数参数 " << (ctx ? 
ctx->getText() : "")); + if (!ctx) return std::vector{}; + std::vector args; + for (auto* exp : ctx->exp()) { + args.push_back(EvalExpr(*exp)); + } + return args; +} + +// visitConstExp - 处理常量表达式 +std::any IRGenImpl::visitConstExp(SysYParser::ConstExpContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitConstExp: 开始处理常量表达式 " << (ctx ? ctx->getText() : "")); + if (!ctx || !ctx->addExp()) { + throw std::runtime_error(FormatError("irgen", "非法常量表达式")); + } + + auto result = ctx->addExp()->accept(this); + + if (!result.has_value()) { + throw std::runtime_error(FormatError("irgen", "常量表达式求值失败")); + } + + try { + return std::any_cast(result); + } catch (const std::bad_any_cast& e) { + throw std::runtime_error(FormatError("irgen", + "常量表达式返回类型错误: " + std::string(e.what()))); + } +} + +// visitConstInitVal - 处理常量初始化值 +std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitConstInitVal: 开始处理常量初始化值 " << (ctx ? ctx->getText() : "")); + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法常量初始化值")); + } + + // 如果是单个常量表达式 + if (ctx->constExp()) { + return ctx->constExp()->accept(this); + } + // 如果是聚合初始化(花括号列表) + else if (!ctx->constInitVal().empty()) { + std::vector all_values; + + for (auto* init_val : ctx->constInitVal()) { + auto result = init_val->accept(this); + if (!result.has_value()) { + throw std::runtime_error(FormatError("irgen", "常量初始化值求值失败")); + } + + try { + // 尝试获取单个常量值 + ir::Value* value = std::any_cast(result); + all_values.push_back(value); + } catch (const std::bad_any_cast&) { + try { + // 尝试获取值列表(嵌套情况) + std::vector nested_values = + std::any_cast>(result); + all_values.insert(all_values.end(), + nested_values.begin(), nested_values.end()); + } catch (const std::bad_any_cast& e) { + throw std::runtime_error(FormatError("irgen", + "不支持的常量初始化值类型: " + std::string(e.what()))); + } + } + } + + return all_values; + } + + // 空初始化列表 + return std::vector{}; +} + +std::any 
IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitRelExp: 开始处理关系表达式 " << (ctx ? ctx->getText() : "")); + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法关系表达式")); + } + + if (ctx->relExp() && ctx->addExp()) { + auto left_any = ctx->relExp()->accept(this); + auto right_any = ctx->addExp()->accept(this); + auto* lhs = std::any_cast(left_any); + auto* rhs = std::any_cast(right_any); + + DEBUG_MSG("[DEBUG] visitRelExp: left=" << (void*)lhs + << ", type=" << (lhs->GetType()->IsFloat() ? "float" : "int") + << ", right=" << (void*)rhs + << ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int")); + + // 处理类型转换:如果操作数类型不同,需要进行类型转换 + if (lhs->GetType()->IsFloat() != rhs->GetType()->IsFloat()) { + if (lhs->GetType()->IsFloat()) { + // lhs是float,rhs是int,需要将rhs转换为float + rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType()); + } else { + // rhs是float,lhs是int,需要将lhs转换为float + lhs = builder_.CreateSIToFP(lhs, ir::Type::GetFloatType()); + } + } + + // 根据操作数和类型选择指令 + if (ctx->LOp()) { + if (lhs->GetType()->IsFloat()) { + return static_cast( + builder_.CreateFCmpOLT(lhs, rhs, module_.GetContext().NextTemp())); + } else { + return static_cast( + builder_.CreateICmpLT(lhs, rhs, module_.GetContext().NextTemp())); + } + } + if (ctx->GOp()) { + if (lhs->GetType()->IsFloat()) { + return static_cast( + builder_.CreateFCmpOGT(lhs, rhs, module_.GetContext().NextTemp())); + } else { + return static_cast( + builder_.CreateICmpGT(lhs, rhs, module_.GetContext().NextTemp())); + } + } + if (ctx->LeOp()) { + if (lhs->GetType()->IsFloat()) { + return static_cast( + builder_.CreateFCmpOLE(lhs, rhs, module_.GetContext().NextTemp())); + } else { + return static_cast( + builder_.CreateICmpLE(lhs, rhs, module_.GetContext().NextTemp())); + } + } + if (ctx->GeOp()) { + if (lhs->GetType()->IsFloat()) { + return static_cast( + builder_.CreateFCmpOGE(lhs, rhs, module_.GetContext().NextTemp())); + } else { + return static_cast( + 
builder_.CreateICmpGE(lhs, rhs, module_.GetContext().NextTemp())); + } + } + throw std::runtime_error(FormatError("irgen", "未知关系运算符")); + } + + if (ctx->addExp()) { + return ctx->addExp()->accept(this); + } + + throw std::runtime_error(FormatError("irgen", "关系表达式暂未实现")); +} + +std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitEqExp: 开始处理相等表达式 " << (ctx ? ctx->getText() : "")); + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法相等表达式")); + } + + if (ctx->eqExp() && ctx->relExp()) { + auto left_any = ctx->eqExp()->accept(this); + auto right_any = ctx->relExp()->accept(this); + auto* lhs = std::any_cast(left_any); + auto* rhs = std::any_cast(right_any); + + DEBUG_MSG("[DEBUG] visitEqExp: left=" << (void*)lhs + << ", type=" << (lhs->GetType()->IsFloat() ? "float" : "int") + << ", right=" << (void*)rhs + << ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int")); + + // 处理类型转换:如果操作数类型不同,需要进行类型转换 + if (lhs->GetType()->IsFloat() != rhs->GetType()->IsFloat()) { + if (lhs->GetType()->IsFloat()) { + // lhs是float,rhs是int,需要将rhs转换为float + rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType()); + } else { + // rhs是float,lhs是int,需要将lhs转换为float + lhs = builder_.CreateSIToFP(lhs, ir::Type::GetFloatType()); + } + } + + // 根据操作符和类型选择指令 + if (ctx->EqOp()) { + if (lhs->GetType()->IsFloat()) { + return static_cast( + builder_.CreateFCmpOEQ(lhs, rhs, module_.GetContext().NextTemp())); + } else { + return static_cast( + builder_.CreateICmpEQ(lhs, rhs, module_.GetContext().NextTemp())); + } + } + if (ctx->NeOp()) { + if (lhs->GetType()->IsFloat()) { + return static_cast( + builder_.CreateFCmpONE(lhs, rhs, module_.GetContext().NextTemp())); + } else { + return static_cast( + builder_.CreateICmpNE(lhs, rhs, module_.GetContext().NextTemp())); + } + } + throw std::runtime_error(FormatError("irgen", "未知相等运算符")); + } + + if (ctx->relExp()) { + return ctx->relExp()->accept(this); + } + + throw 
std::runtime_error(FormatError("irgen", "相等表达式暂未实现")); +} + + +ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] EvalAssign: 开始处理赋值语句 " << (ctx ? ctx->getText() : "")); + DEBUG_MSG("[DEBUG IRGEN] visitCond: " << (ctx ? ctx->getText() : "")); + if (!ctx || !ctx->lVal() || !ctx->exp()) { + throw std::runtime_error(FormatError("irgen", "非法赋值语句")); + } + + // 计算右值 + ir::Value* rhs = EvalExpr(*ctx->exp()); + + auto* lval = ctx->lVal(); + std::string varName = lval->Ident()->getText(); + + // 首先尝试从语义分析获取变量定义 + auto* var_decl = sema_.ResolveVarUse(lval); + + if (var_decl) { + // 是变量赋值 + // 从storage_map_获取存储位置 + auto it = storage_map_.find(var_decl); + if (it == storage_map_.end()) { + throw std::runtime_error( + FormatError("irgen", "变量声明缺少存储槽位: " + varName)); + } + + ir::Value* base_ptr = it->second; + + auto convert_for_store = [&](ir::Value* value, ir::Value* ptr) -> ir::Value* { + if (ptr->GetType()->IsPtrFloat() && value->GetType()->IsInt32()) { + return builder_.CreateSIToFP(value, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } + if (ptr->GetType()->IsPtrInt32() && value->GetType()->IsFloat()) { + return builder_.CreateFPToSI(value, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + return value; + }; + + // 检查是否有数组下标 + auto exp_list = lval->exp(); + if (!exp_list.empty()) { + // 这是数组元素赋值,需要生成GEP指令 + std::vector indices; + + // 标量指针参数(T*)不应添加前导0;数组对象需要前导0。 + if (!(base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat())) { + indices.push_back(builder_.CreateConstInt(0)); + } + + // 添加用户提供的下标 + for (auto* exp : exp_list) { + ir::Value* index = EvalExpr(*exp); + indices.push_back(index); + } + + // 生成GEP指令获取元素地址 + ir::Value* elem_ptr = builder_.CreateGEP( + base_ptr, indices, module_.GetContext().NextTemp()); + + // 生成store指令 + rhs = convert_for_store(rhs, elem_ptr); + builder_.CreateStore(rhs, elem_ptr); + } else { + // 普通标量赋值 + // 调试输出指针类型 + DEBUG_MSG("[DEBUG] base_ptr type: 
" << base_ptr->GetType()); + DEBUG_MSG("[DEBUG] rhs type: " << rhs->GetType()); + + // 如果 base_ptr 不是标量指针类型,可能需要特殊处理 + if (!base_ptr->GetType()->IsPtrInt32() && !base_ptr->GetType()->IsPtrFloat()) { + DEBUG_MSG("[ERROR] base_ptr is not a pointer type!"); + throw std::runtime_error("尝试存储到非指针类型"); + } + rhs = convert_for_store(rhs, base_ptr); + builder_.CreateStore(rhs, base_ptr); + } + } else { + // 尝试获取常量定义 + auto* const_decl = sema_.ResolveConstUse(lval); + if (const_decl) { + // 尝试给常量赋值,这是错误的 + throw std::runtime_error( + FormatError("irgen", "不能给常量赋值: " + varName)); + } else { + throw std::runtime_error( + FormatError("irgen", "变量/常量使用缺少语义绑定: " + varName)); + } + } + + return rhs; } diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 4912d03..df827ad 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -1,6 +1,8 @@ #include "irgen/IRGen.h" +#include #include +#include #include "SysYParser.h" #include "ir/IR.h" @@ -9,7 +11,6 @@ namespace { void VerifyFunctionStructure(const ir::Function& func) { - // 当前 IRGen 仍是单入口、顺序生成;这里在生成结束后补一层块终结校验。 for (const auto& bb : func.GetBlocks()) { if (!bb || !bb->HasTerminator()) { throw std::runtime_error( @@ -19,69 +20,498 @@ void VerifyFunctionStructure(const ir::Function& func) { } } +bool HasDirectSelfCall(antlr4::ParserRuleContext* node, + const std::string& func_name) { + if (!node) { + return false; + } + + if (auto* unary = dynamic_cast(node)) { + if (unary->Ident() && unary->Ident()->getText() == func_name) { + return true; + } + } + + for (auto* child : node->children) { + if (auto* rule = dynamic_cast(child)) { + if (HasDirectSelfCall(rule, func_name)) { + return true; + } + } + } + return false; +} + } // namespace -IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) - : module_(module), - sema_(sema), - func_(nullptr), - builder_(module.GetContext(), nullptr) {} - -// 编译单元的 IR 生成当前只实现了最小功能: -// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容; -// - 当前会读取编译单元中的函数定义,并交给 
visitFuncDef 生成函数 IR; -// -// 当前还没有实现: -// - 多个函数定义的遍历与生成; -// - 全局变量、全局常量的 IR 生成。 +// 实现新的构造函数 +IRGenImpl::IRGenImpl(ir::Module& module, + const SemanticContext& sema, + const SymbolTable& sym_table) + : module_(module), sema_(sema), symbol_table_(sym_table), + builder_(module.GetContext(), nullptr), func_(nullptr) { + AddRuntimeFunctions(); +} + +void IRGenImpl::AddRuntimeFunctions() { + DEBUG_MSG("[DEBUG IRGEN] 添加运行时库函数声明"); + + // 输入函数(返回 int) + module_.CreateFunction("getint", + ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {})); + module_.CreateFunction("getch", + ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {})); + + // getarray(int* a, int n): int + module_.CreateFunction("getarray", + ir::Type::GetFunctionType( + ir::Type::GetInt32Type(), + {ir::Type::GetPtrInt32Type()})); + + // 输出函数(返回 void) + module_.CreateFunction("putint", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type()})); + module_.CreateFunction("putch", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type()})); + + // putarray(int n, int* a): void + module_.CreateFunction("putarray", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type(), ir::Type::GetPtrInt32Type()})); + + // 字符串输出(暂时用 int* 替代 char*) + module_.CreateFunction("puts", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrInt32Type()})); + + // 时间测量函数(SysY 标准库) + module_.CreateFunction("_sysy_starttime", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type()})); + module_.CreateFunction("_sysy_stoptime", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type()})); + + // 简化版本 + module_.CreateFunction("starttime", + ir::Type::GetFunctionType(ir::Type::GetVoidType(), {})); + module_.CreateFunction("stoptime", + ir::Type::GetFunctionType(ir::Type::GetVoidType(), {})); + + // 浮点 I/O + module_.CreateFunction("getfloat", + 
ir::Type::GetFunctionType(ir::Type::GetFloatType(), {})); + module_.CreateFunction("putfloat", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetFloatType()})); + module_.CreateFunction("getfarray", + ir::Type::GetFunctionType( + ir::Type::GetInt32Type(), + {ir::Type::GetPtrFloatType()})); + module_.CreateFunction("putfarray", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type(), ir::Type::GetPtrFloatType()})); + + // 内存操作函数 + module_.CreateFunction("memset", + ir::Type::GetFunctionType( + ir::Type::GetPtrInt32Type(), + {ir::Type::GetPtrInt32Type(), + ir::Type::GetInt32Type(), + ir::Type::GetInt32Type()})); + + module_.CreateFunction("sysy_alloc_i32", + ir::Type::GetFunctionType( + ir::Type::GetPtrInt32Type(), + {ir::Type::GetInt32Type()})); + module_.CreateFunction("sysy_alloc_f32", + ir::Type::GetFunctionType( + ir::Type::GetPtrFloatType(), + {ir::Type::GetInt32Type()})); + module_.CreateFunction("sysy_free_i32", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrInt32Type()})); + module_.CreateFunction("sysy_free_f32", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrFloatType()})); + module_.CreateFunction("sysy_zero_i32", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()})); + module_.CreateFunction("sysy_zero_f32", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrFloatType(), ir::Type::GetInt32Type()})); + + DEBUG_MSG("[DEBUG IRGEN] 运行时库函数声明完成"); +} + +// 修正:没有 mainFuncDef,通过函数名找到 main std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitCompUnit"); + DEBUG_MSG("[DEBUG] IRGen: 符号表地址 = " << &symbol_table_); + DEBUG_MSG("[DEBUG] IRGen: 开始生成 IR"); + + // 尝试查找 main 函数 + const Symbol* main_sym = symbol_table_.lookup("main"); + if (main_sym) { + DEBUG_MSG("[DEBUG] IRGen: 找到 main 函数符号"); + } else { + DEBUG_MSG("[DEBUG] IRGen: 未找到 
main 函数符号"); + } if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少编译单元")); } - auto* func = ctx->funcDef(); - if (!func) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); + + // 处理全局变量声明 + for (auto* decl : ctx->decl()) { + if (decl) { + decl->accept(this); + } + } + + // 处理所有函数定义 + for (auto* funcDef : ctx->funcDef()) { + if (funcDef) { + funcDef->accept(this); + } } - func->accept(this); + return {}; } -// 函数 IR 生成当前实现了: -// 1. 获取函数名; -// 2. 检查函数返回类型; -// 3. 在 Module 中创建 Function; -// 4. 将 builder 插入点设置到入口基本块; -// 5. 继续生成函数体。 -// -// 当前还没有实现: -// - 通用函数返回类型处理; -// - 形参列表遍历与参数类型收集; -// - FunctionType 这样的函数类型对象; -// - Argument/形式参数 IR 对象; -// - 入口块中的参数初始化逻辑。 -// ... - -// 因此这里目前只支持最小的“无参 int 函数”生成。 std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitFuncDef: " << (ctx && ctx->Ident() ? ctx->Ident()->getText() : "")); if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少函数定义")); } - if (!ctx->blockStmt()) { + + if (!ctx->Ident()) { + throw std::runtime_error(FormatError("irgen", "缺少函数名")); + } + + std::string funcName = ctx->Ident()->getText(); + + if (!ctx->block()) { throw std::runtime_error(FormatError("irgen", "函数体为空")); } - if (!ctx->ID()) { - throw std::runtime_error(FormatError("irgen", "缺少函数名")); + + std::shared_ptr ret_type = ir::Type::GetInt32Type(); + if (ctx->funcType()) { + if (ctx->funcType()->Void()) { + ret_type = ir::Type::GetVoidType(); + } else if (ctx->funcType()->Int()) { + ret_type = ir::Type::GetInt32Type(); + } else if (ctx->funcType()->Float()) { + ret_type = ir::Type::GetFloatType(); + } } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数")); + + std::vector> param_types; + if (ctx->funcFParams()) { + for (auto* param : ctx->funcFParams()->funcFParam()) { + if (!param || !param->Ident()) continue; + std::string name = param->Ident()->getText(); + std::shared_ptr param_ty; + + // 检查 bType 是否存在 + 
if (!param->bType()) { + throw std::runtime_error(FormatError("irgen", "函数参数缺少类型: " + name)); + } + + if (param->bType()->Int()) { + param_ty = ir::Type::GetInt32Type(); + } else if (param->bType()->Float()) { + param_ty = ir::Type::GetFloatType(); + } else { + param_ty = ir::Type::GetInt32Type(); // 默认值 + } + + if (!param->L_BRACK().empty()) { + if (param_ty->IsInt32()) { + param_ty = ir::Type::GetPtrInt32Type(); + } else if (param_ty->IsFloat()) { + param_ty = ir::Type::GetPtrFloatType(); + } + } + + param_types.push_back(param_ty); + } } - func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); - builder_.SetInsertPoint(func_->GetEntry()); + // 创建函数类型 + auto func_type = ir::Type::GetFunctionType(ret_type, param_types); + + // 调试输出 + DEBUG_MSG("[DEBUG] visitFuncDef: 创建函数 " << funcName + << ",返回类型: " << (ret_type->IsVoid() ? "void" : ret_type->IsFloat() ? "float" : "int") + << ",参数数量: " << param_types.size()); + + // 创建函数对象 + func_ = module_.CreateFunction(funcName, func_type); + + // 检查函数是否成功创建 + if (!func_) { + DEBUG_MSG("[ERROR] visitFuncDef: 创建函数失败,func_ 为 nullptr!"); + throw std::runtime_error(FormatError("irgen", "创建函数失败: " + funcName)); + } + + DEBUG_MSG("[DEBUG] visitFuncDef: 函数对象地址: " << (void*)func_); + + // 设置插入点 + auto* entry_block = func_->GetEntry(); + if (!entry_block) { + DEBUG_MSG("[ERROR] visitFuncDef: 函数入口基本块为空!"); + throw std::runtime_error(FormatError("irgen", "函数入口基本块为空: " + funcName)); + } + + builder_.SetInsertPoint(entry_block); storage_map_.clear(); + param_map_.clear(); + pointer_param_names_.clear(); + heap_local_array_names_.clear(); + current_function_name_ = funcName; + current_function_is_recursive_ = HasDirectSelfCall(ctx->block(), funcName); + function_cleanup_block_ = nullptr; + function_cleanup_actions_.clear(); + function_return_slot_ = nullptr; + + if (ret_type->IsInt32()) { + function_return_slot_ = CreateEntryAllocaI32(module_.GetContext().NextTemp() + ".retval"); + } else if (ret_type->IsFloat()) { + 
function_return_slot_ = CreateEntryAllocaFloat(module_.GetContext().NextTemp() + ".retval"); + } + + // 函数参数处理 + if (ctx->funcFParams()) { + for (auto* param : ctx->funcFParams()->funcFParam()) { + if (!param || !param->Ident()) continue; + std::string name = param->Ident()->getText(); + std::shared_ptr param_ty; + + // 再次检查 bType + if (!param->bType()) { + throw std::runtime_error(FormatError("irgen", "函数参数缺少类型: " + name)); + } + + if (param->bType()->Int()) { + param_ty = ir::Type::GetInt32Type(); + } else if (param->bType()->Float()) { + param_ty = ir::Type::GetFloatType(); + } else { + param_ty = ir::Type::GetInt32Type(); + } + + if (!param->L_BRACK().empty()) { + if (param_ty->IsInt32()) { + param_ty = ir::Type::GetPtrInt32Type(); + } else if (param_ty->IsFloat()) { + param_ty = ir::Type::GetPtrFloatType(); + } + } + + // 检查函数对象是否有效 + if (!func_) { + DEBUG_MSG("[ERROR] visitFuncDef: func_ 在添加参数时变为 nullptr!"); + throw std::runtime_error(FormatError("irgen", "函数对象无效")); + } + + DEBUG_MSG("[DEBUG] visitFuncDef: 为函数 " << funcName + << " 添加参数 " << name << ",类型: " + << (param_ty->IsInt32() ? "int32" : param_ty->IsFloat() ? "float" : + param_ty->IsPtrInt32() ? "ptr_int32" : param_ty->IsPtrFloat() ? 
"ptr_float" : "other") + ); + + // 创建参数并添加到函数 + auto arg = std::make_unique(param_ty, name); + if (!arg) { + throw std::runtime_error(FormatError("irgen", "创建参数失败: " + name)); + } + + auto* arg_ptr = arg.get(); + auto* added_arg = func_->AddArgument(std::move(arg)); + + if (!added_arg) { + DEBUG_MSG("[ERROR] visitFuncDef: AddArgument 返回 nullptr!"); + throw std::runtime_error(FormatError("irgen", "添加参数失败: " + name)); + } + + // 标量参数:入栈到本地槽位;数组参数(指针)直接作为地址使用。 + if (param_ty->IsPtrInt32() || param_ty->IsPtrFloat()) { + param_map_[name] = added_arg; + pointer_param_names_.insert(name); + } else { + ir::AllocaInst* slot = nullptr; + if (param_ty->IsInt32()) { + slot = CreateEntryAllocaI32(module_.GetContext().NextTemp()); + } else if (param_ty->IsFloat()) { + slot = CreateEntryAllocaFloat(module_.GetContext().NextTemp()); + } else { + throw std::runtime_error(FormatError("irgen", "不支持的参数类型")); + } + + if (!slot) { + throw std::runtime_error(FormatError("irgen", "创建参数存储槽位失败: " + name)); + } + + builder_.CreateStore(added_arg, slot); + param_map_[name] = slot; + pointer_param_names_.erase(name); + } + + DEBUG_MSG("[DEBUG] visitFuncDef: 参数 " << name << " 处理完成"); + } + } + + // 生成函数体 + DEBUG_MSG("[DEBUG] visitFuncDef: 开始生成函数体"); + ctx->block()->accept(this); - ctx->blockStmt()->accept(this); - // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 - VerifyFunctionStructure(*func_); + // 如果当前插入块没有终止指令,添加默认返回 + if (auto* cur = builder_.GetInsertBlock(); cur && !cur->HasTerminator()) { + DEBUG_MSG("[DEBUG] visitFuncDef: 函数体没有终止指令,添加默认返回"); + if (function_cleanup_block_) { + if (ret_type->IsFloat()) { + builder_.CreateStore(builder_.CreateConstFloat(0.0f), function_return_slot_); + } else if (ret_type->IsInt32()) { + builder_.CreateStore(builder_.CreateConstInt(0), function_return_slot_); + } + builder_.CreateBr(function_cleanup_block_); + } else if (ret_type->IsVoid()) { + builder_.CreateRet(nullptr); + } else if (ret_type->IsFloat()) { + builder_.CreateRet(builder_.CreateConstFloat(0.0f)); + 
} else { + builder_.CreateRet(builder_.CreateConstInt(0)); + } + } + + if (function_cleanup_block_ && !function_cleanup_block_->HasTerminator()) { + builder_.SetInsertPoint(function_cleanup_block_); + for (auto it = function_cleanup_actions_.rbegin(); + it != function_cleanup_actions_.rend(); ++it) { + builder_.CreateCall(it->first, {it->second}, module_.GetContext().NextTemp()); + } + + if (ret_type->IsVoid()) { + builder_.CreateRet(nullptr); + } else { + builder_.CreateRet(builder_.CreateLoad(function_return_slot_, module_.GetContext().NextTemp())); + } + } + + // 验证函数结构 + try { + VerifyFunctionStructure(*func_); + } catch (const std::exception& e) { + DEBUG_MSG("[ERROR] visitFuncDef: 验证函数结构失败: " << e.what()); + throw; + } + + DEBUG_MSG("[DEBUG] visitFuncDef: 函数 " << funcName << " 生成完成"); + func_ = nullptr; + current_function_name_.clear(); + current_function_is_recursive_ = false; + function_return_slot_ = nullptr; + function_cleanup_block_ = nullptr; + function_cleanup_actions_.clear(); return {}; } + +ir::BasicBlock* IRGenImpl::EnsureCleanupBlock() { + if (!function_cleanup_block_) { + std::string name = module_.GetContext().NextTemp(); + if (!name.empty() && name[0] == '%') { + name.erase(0, 1); + } + function_cleanup_block_ = func_->CreateBlock("cleanup." 
+ name); + } + return function_cleanup_block_; +} + +void IRGenImpl::RegisterCleanup(ir::Function* free_func, ir::Value* ptr) { + if (!free_func || !ptr) { + return; + } + EnsureCleanupBlock(); + function_cleanup_actions_.push_back({free_func, ptr}); +} + +ir::AllocaInst* IRGenImpl::CreateEntryAlloca(std::shared_ptr ty, + const std::string& name) { + if (!func_ || !func_->GetEntry()) { + throw std::runtime_error(FormatError("irgen", "缺少函数入口块,无法创建入口栈槽位")); + } + return func_->GetEntry()->InsertBeforeTerminator(ty, name); +} + +ir::AllocaInst* IRGenImpl::CreateEntryAllocaI32(const std::string& name) { + return CreateEntryAlloca(ir::Type::GetPtrInt32Type(), name); +} + +ir::AllocaInst* IRGenImpl::CreateEntryAllocaFloat(const std::string& name) { + return CreateEntryAlloca(ir::Type::GetPtrFloatType(), name); +} + + +std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitBlock: " << (ctx ? ctx->getText() : "")); + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "缺少语句块")); + } + + BlockFlow flow = BlockFlow::Continue; + for (auto* item : ctx->blockItem()) { + if (!item) continue; + + flow = VisitBlockItemResult(*item); + if (flow == BlockFlow::Terminated) { + break; + } + + auto* cur = builder_.GetInsertBlock(); + DEBUG_MSG("[DEBUG] current insert block: " + << (cur ? cur->GetName() : "")); + if (cur && cur->HasTerminator()) { + break; + } + } + + return flow; +} + +// 类型安全的包装器 +IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( + SysYParser::BlockItemContext& item) { + auto result = item.accept(this);// 调用 visitBlockItem,返回 std::any 包装的 BlockFlow + return std::any_cast(result); // 解包为 BlockFlow +} +// 用于遍历块内项,返回是否继续访问后续项(如遇到 return/break/continue 则终止访问) +std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitBlockItem: " << (ctx ? 
ctx->getText() : "")); + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "缺少块内项")); + } + // 块内项可以是语句或声明,优先处理语句(如 return/break/continue 可能终止块内执行) + if (ctx->stmt()) { + return ctx->stmt()->accept(this); // 语句访问返回 BlockFlow,指示是否继续访问后续项 + } + // 处理声明(如变量定义),继续访问后续项 + if (ctx->decl()) { + ctx->decl()->accept(this); + return BlockFlow::Continue; // 声明不会终止块内执行,继续访问后续项 + } + throw std::runtime_error(FormatError("irgen", "暂不支持的块内项")); +} diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 751550c..7dad686 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -16,24 +16,512 @@ // - 空语句、块语句嵌套分发之外的更多语句形态 std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] visitStmt: " << (ctx ? ctx->getText() : "")); if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句")); } - if (ctx->returnStmt()) { - return ctx->returnStmt()->accept(this); + + // return 语句 - 通过 Return() 关键字判断 + if (ctx->Return()) { + return HandleReturnStmt(ctx); + } + + // 赋值语句 + if (ctx->lVal() && ctx->Assign() && ctx->exp()) { + return HandleAssignStmt(ctx); + } + + // if 语句 + if (ctx->If()) { + return HandleIfStmt(ctx); + } + + // while 语句 + if (ctx->While()) { + return HandleWhileStmt(ctx); + } + + // break 语句 + if (ctx->Break()) { + return HandleBreakStmt(ctx); + } + // continue 语句 + if (ctx->Continue()) { + return HandleContinueStmt(ctx); + } + // 块语句 + if (ctx->block()) { + return ctx->block()->accept(this); + } + + // 空语句或表达式语句(先计算表达式) + if (ctx->exp()) { + EvalExpr(*ctx->exp()); + return BlockFlow::Continue; } + throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); } +// 修改 HandleReturnStmt 函数 - -std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { +IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] HandleReturnStmt: " << (ctx ? 
ctx->getText() : "")); if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); } - if (!ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); + + // 检查函数是否存在 + if (!func_) { + throw std::runtime_error(FormatError("irgen", "Return语句不在函数中")); + } + + // 获取函数类型中的返回类型 + auto func_type = std::dynamic_pointer_cast(func_->GetType()); + if (!func_type) { + throw std::runtime_error(FormatError("irgen", "函数类型无效")); + } + + auto ret_type = func_type->GetReturnType(); + + if (ret_type->IsVoid()) { + if (ctx->exp()) { + // 表达式被忽略(可计算但不使用) + EvalExpr(*ctx->exp()); + } + if (function_cleanup_block_) { + builder_.CreateBr(function_cleanup_block_); + } else { + // 对于void函数,创建返回指令(不传参数) + builder_.CreateRet(nullptr); + } + } else { + ir::Value* retValue = nullptr; + if (ctx->exp()) { + retValue = EvalExpr(*ctx->exp()); + if (!retValue) { + throw std::runtime_error(FormatError("irgen", "返回值表达式求值失败")); + } + // 类型转换 + if (retValue->GetType() != ret_type) { + if (ret_type->IsInt32() && retValue->GetType()->IsFloat()) { + retValue = builder_.CreateFPToSI(retValue, ir::Type::GetInt32Type()); + } else if (ret_type->IsFloat() && retValue->GetType()->IsInt32()) { + retValue = builder_.CreateSIToFP(retValue, ir::Type::GetFloatType()); + } + } + } else { + // 无表达式,返回默认值 + if (ret_type->IsInt32()) { + retValue = builder_.CreateConstInt(0); + } else if (ret_type->IsFloat()) { + retValue = builder_.CreateConstFloat(0.0f); + } else { + retValue = builder_.CreateConstInt(0); // fallback + } + } + if (function_cleanup_block_) { + builder_.CreateStore(retValue, function_return_slot_); + builder_.CreateBr(function_cleanup_block_); + } else { + builder_.CreateRet(retValue); + } } - ir::Value* v = EvalExpr(*ctx->exp()); - builder_.CreateRet(v); return BlockFlow::Terminated; } + + +// if语句(待实现) +IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] HandleIfStmt: " << (ctx ? 
ctx->getText() : "")); + + auto* cond = ctx->cond(); + auto* thenStmt = ctx->stmt(0); + auto* elseStmt = ctx->stmt(1); + + // 创建基本块(使用唯一名称,避免同名标签) + auto uniq = [&](const std::string& prefix) { + std::string t = module_.GetContext().NextTemp(); + if (!t.empty() && t[0] == '%') t.erase(0, 1); + return prefix + "." + t; + }; + auto* thenBlock = func_->CreateBlock(uniq("then")); + auto* elseBlock = (ctx->Else() && elseStmt) ? func_->CreateBlock(uniq("else")) : nullptr; + auto* mergeBlock = func_->CreateBlock(uniq("merge")); + + DEBUG_MSG("[DEBUG IF] thenBlock: " << thenBlock->GetName()); + if (elseBlock) DEBUG_MSG("[DEBUG IF] elseBlock: " << elseBlock->GetName()); + DEBUG_MSG("[DEBUG IF] mergeBlock: " << mergeBlock->GetName()); + DEBUG_MSG("[DEBUG IF] current insert block before cond: " + << builder_.GetInsertBlock()->GetName()); + + // 生成条件 + auto* condValue = EvalCond(*cond); + if (!condValue->GetType()->IsInt1()) { + if (condValue->GetType()->IsFloat()) { + condValue = builder_.CreateFCmpONE( + condValue, builder_.CreateConstFloat(0.0f), module_.GetContext().NextTemp()); + } else { + condValue = builder_.CreateICmpNE( + condValue, builder_.CreateConstInt(0), module_.GetContext().NextTemp()); + } + } + + // 创建条件跳转 + if (elseBlock) { + DEBUG_MSG("[DEBUG IF] Creating condbr: " << condValue->GetName() + << " -> " << thenBlock->GetName() << ", " << elseBlock->GetName()); + builder_.CreateCondBr(condValue, thenBlock, elseBlock); + } else { + DEBUG_MSG("[DEBUG IF] Creating condbr: " << condValue->GetName() + << " -> " << thenBlock->GetName() << ", " << mergeBlock->GetName()); + builder_.CreateCondBr(condValue, thenBlock, mergeBlock); + } + + // 生成 then 分支 + DEBUG_MSG("[DEBUG IF] Generating then branch in block: " << thenBlock->GetName()); + builder_.SetInsertPoint(thenBlock); + auto thenResult = thenStmt->accept(this); + bool thenTerminated = (std::any_cast(thenResult) == BlockFlow::Terminated); + DEBUG_MSG("[DEBUG IF] then branch terminated: " << thenTerminated); + + if 
(!thenTerminated) { + DEBUG_MSG("[DEBUG IF] Adding br to merge block from then"); + builder_.CreateBr(mergeBlock); + } + DEBUG_MSG("[DEBUG IF] then block has terminator: " << thenBlock->HasTerminator()); + + // 生成 else 分支 + bool elseTerminated = false; + if (elseBlock) { + DEBUG_MSG("[DEBUG IF] Generating else branch in block: " << elseBlock->GetName()); + builder_.SetInsertPoint(elseBlock); + auto elseResult = elseStmt->accept(this); + elseTerminated = (std::any_cast(elseResult) == BlockFlow::Terminated); + DEBUG_MSG("[DEBUG IF] else branch terminated: " << elseTerminated); + + if (!elseTerminated) { + DEBUG_MSG("[DEBUG IF] Adding br to merge block from else"); + builder_.CreateBr(mergeBlock); + } + DEBUG_MSG("[DEBUG IF] else block has terminator: " << elseBlock->HasTerminator()); + } + + // 决定后续插入点 + DEBUG_MSG("[DEBUG IF] thenTerminated=" << thenTerminated + << ", elseTerminated=" << elseTerminated); + + if (elseBlock) { + DEBUG_MSG("[DEBUG IF] Setting insert point to merge block: " + << mergeBlock->GetName()); + builder_.SetInsertPoint(mergeBlock); + } else { + DEBUG_MSG("[DEBUG IF] No else, setting insert point to merge block: " + << mergeBlock->GetName()); + builder_.SetInsertPoint(mergeBlock); + } + + DEBUG_MSG("[DEBUG IF] Final insert block: " + << builder_.GetInsertBlock()->GetName()); + + return BlockFlow::Continue; +} + +// while语句(待实现)IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) { +IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] HandleWhileStmt: " << (ctx ? 
ctx->getText() : "")); + + if (!ctx || !ctx->cond() || !ctx->stmt(0)) { + throw std::runtime_error(FormatError("irgen", "非法 while 语句")); + } + + DEBUG_MSG("[DEBUG WHILE] Current insert block before while: " + << builder_.GetInsertBlock()->GetName()); + + auto uniq = [&](const std::string& prefix) { + std::string t = module_.GetContext().NextTemp(); + if (!t.empty() && t[0] == '%') t.erase(0, 1); + return prefix + "." + t; + }; + auto* condBlock = func_->CreateBlock(uniq("while.cond")); + auto* bodyBlock = func_->CreateBlock(uniq("while.body")); + auto* exitBlock = func_->CreateBlock(uniq("while.exit")); + + DEBUG_MSG("[DEBUG WHILE] condBlock: " << condBlock->GetName()); + DEBUG_MSG("[DEBUG WHILE] bodyBlock: " << bodyBlock->GetName()); + DEBUG_MSG("[DEBUG WHILE] exitBlock: " << exitBlock->GetName()); + + DEBUG_MSG("[DEBUG WHILE] Adding br to condBlock from current block"); + builder_.CreateBr(condBlock); + + loopStack_.push_back({condBlock, bodyBlock, exitBlock}); + DEBUG_MSG("[DEBUG WHILE] loopStack size: " << loopStack_.size()); + + // 条件块 + DEBUG_MSG("[DEBUG WHILE] Generating condition in block: " << condBlock->GetName()); + builder_.SetInsertPoint(condBlock); + auto* condValue = EvalCond(*ctx->cond()); + if (!condValue->GetType()->IsInt1()) { + if (condValue->GetType()->IsFloat()) { + condValue = builder_.CreateFCmpONE( + condValue, builder_.CreateConstFloat(0.0f), module_.GetContext().NextTemp()); + } else { + condValue = builder_.CreateICmpNE( + condValue, builder_.CreateConstInt(0), module_.GetContext().NextTemp()); + } + } + builder_.CreateCondBr(condValue, bodyBlock, exitBlock); + DEBUG_MSG("[DEBUG WHILE] condBlock has terminator: " << condBlock->HasTerminator()); + + // 循环体 + DEBUG_MSG("[DEBUG WHILE] Generating body in block: " << bodyBlock->GetName()); + builder_.SetInsertPoint(bodyBlock); + auto bodyResult = ctx->stmt(0)->accept(this); + bool bodyTerminated = (std::any_cast(bodyResult) == BlockFlow::Terminated); + DEBUG_MSG("[DEBUG WHILE] body 
terminated: " << bodyTerminated); + + if (!bodyTerminated) { + DEBUG_MSG("[DEBUG WHILE] Adding br to condBlock from body"); + builder_.CreateBr(condBlock); + } + DEBUG_MSG("[DEBUG WHILE] bodyBlock has terminator: " << bodyBlock->HasTerminator()); + + loopStack_.pop_back(); + DEBUG_MSG("[DEBUG WHILE] loopStack size after pop: " << loopStack_.size()); + + // 设置插入点为 exitBlock + DEBUG_MSG("[DEBUG WHILE] Setting insert point to exitBlock: " << exitBlock->GetName()); + builder_.SetInsertPoint(exitBlock); + DEBUG_MSG("[DEBUG WHILE] exitBlock has terminator before return: " + << exitBlock->HasTerminator()); + + return BlockFlow::Continue; +} + +// break语句(待实现) +IRGenImpl::BlockFlow IRGenImpl::HandleBreakStmt(SysYParser::StmtContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] HandleBreakStmt: " << (ctx ? ctx->getText() : "")); + + if (loopStack_.empty()) { + throw std::runtime_error(FormatError("irgen", "break 语句不在循环中")); + } + + DEBUG_MSG("[DEBUG BREAK] Current insert block before break: " + << builder_.GetInsertBlock()->GetName()); + DEBUG_MSG("[DEBUG BREAK] Breaking to exitBlock: " + << loopStack_.back().exitBlock->GetName()); + + // 跳转到循环退出块 + builder_.CreateBr(loopStack_.back().exitBlock); + + // break 本身就是终止当前路径,后续 unreachable 代码不需要继续生成。 + return BlockFlow::Terminated; +} + +IRGenImpl::BlockFlow IRGenImpl::HandleContinueStmt(SysYParser::StmtContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] HandleContinueStmt: " << (ctx ? 
ctx->getText() : "")); + + if (loopStack_.empty()) { + throw std::runtime_error(FormatError("irgen", "continue 语句不在循环中")); + } + + DEBUG_MSG("[DEBUG CONTINUE] Current insert block before continue: " + << builder_.GetInsertBlock()->GetName()); + DEBUG_MSG("[DEBUG CONTINUE] Continuing to condBlock: " + << loopStack_.back().condBlock->GetName()); + + // 跳转到循环条件块 + builder_.CreateBr(loopStack_.back().condBlock); + + // continue 本身就是终止当前路径,后续 unreachable 代码不需要继续生成。 + return BlockFlow::Terminated; +} + +// 赋值语句(待实现) +// 赋值语句 +// 赋值语句 +IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) { + DEBUG_MSG("[DEBUG IRGEN] HandleAssignStmt: " << (ctx ? ctx->getText() : "")); + + if (!ctx || !ctx->lVal() || !ctx->exp()) { + throw std::runtime_error(FormatError("irgen", "非法赋值语句")); + } + + // 计算右值 + ir::Value* rhs = EvalExpr(*ctx->exp()); + if (!rhs) { + throw std::runtime_error(FormatError("irgen", "赋值 RHS 计算失败")); + } + + auto* lval = ctx->lVal(); + std::string varName = lval->Ident()->getText(); + DEBUG_MSG("[DEBUG] HandleAssignStmt: assigning to " << varName); + + // 1. 检查是否为常量(不能给常量赋值) + auto* const_decl = sema_.ResolveConstUse(lval); + if (const_decl) { + throw std::runtime_error( + FormatError("irgen", "不能给常量赋值: " + varName)); + } + + // 2. 查找存储位置 + ir::Value* base_ptr = nullptr; + + // 2.1 尝试通过语义分析获取变量定义,并从 storage_map_ 查找 + auto* var_decl = sema_.ResolveVarUse(lval); + if (var_decl) { + auto it = storage_map_.find(var_decl); + if (it != storage_map_.end()) { + base_ptr = it->second; + DEBUG_MSG("[DEBUG] HandleAssignStmt: found in storage_map_ for " << varName + << ", ptr = " << (void*)base_ptr); + } + } + + // 2.2 从参数映射查找(关键!) 
+ if (!base_ptr) { + auto it2 = param_map_.find(varName); + if (it2 != param_map_.end()) { + base_ptr = it2->second; + DEBUG_MSG("[DEBUG] HandleAssignStmt: found in param_map_ for " << varName + << ", ptr = " << (void*)base_ptr); + } + } + + // 2.3 从全局变量映射查找 + if (!base_ptr) { + auto it3 = global_map_.find(varName); + if (it3 != global_map_.end()) { + base_ptr = it3->second; + DEBUG_MSG("[DEBUG] HandleAssignStmt: found in global_map_ for " << varName + << ", ptr = " << (void*)base_ptr); + } + } + + // 2.4 从局部变量映射查找(fallback) + if (!base_ptr) { + auto it4 = local_var_map_.find(varName); + if (it4 != local_var_map_.end()) { + base_ptr = it4->second; + DEBUG_MSG("[DEBUG] HandleAssignStmt: found in local_var_map_ for " << varName + << ", ptr = " << (void*)base_ptr); + } + } + + // 如果还是找不到,才报错 + if (!base_ptr) { + throw std::runtime_error( + FormatError("irgen", "变量声明缺少存储槽位: " + varName)); + } + + // 3. 检查是否有数组下标 + auto exp_list = lval->exp(); + if (!exp_list.empty()) { + // 数组元素赋值 + std::vector idx_vals; + for (auto* exp : exp_list) { + ir::Value* index = EvalExpr(*exp); + idx_vals.push_back(index); + } + + ir::Value* elem_ptr = nullptr; + + // 扁平数组/数组参数(T*)的多维访问:先线性化,再单索引 GEP。 + if ((base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat()) && idx_vals.size() > 1) { + const Symbol* var_sym = symbol_table_.lookup(varName); + if (!var_sym && var_decl) { + var_sym = symbol_table_.lookupByVarDef(var_decl); + } + if (!var_sym) { + var_sym = symbol_table_.lookupAll(varName); + } + + std::vector dims; + if (var_sym) { + if (var_sym->is_array_param && !var_sym->array_dims.empty()) { + dims = var_sym->array_dims; + } else if (var_sym->type && var_sym->type->IsArray()) { + auto* at = dynamic_cast(var_sym->type.get()); + if (at) dims = at->GetDimensions(); + } + } + + if (dims.empty() && var_decl) { + for (auto* cexp : var_decl->constExp()) { + dims.push_back(symbol_table_.EvaluateConstExp(cexp)); + } + } + + ir::Value* flat = nullptr; + for (size_t i = 0; i < 
idx_vals.size(); ++i) { + ir::Value* term = idx_vals[i]; + if (!term) continue; + + int mult = 1; + if (!dims.empty() && i + 1 < dims.size()) { + for (size_t j = i + 1; j < dims.size(); ++j) { + if (dims[j] > 0) mult *= dims[j]; + } + } + + if (mult != 1) { + auto* mval = builder_.CreateConstInt(mult); + term = builder_.CreateMul(term, mval, module_.GetContext().NextTemp()); + } + + if (!flat) flat = term; + else flat = builder_.CreateAdd(flat, term, module_.GetContext().NextTemp()); + } + + if (!flat) flat = builder_.CreateConstInt(0); + + std::vector gep_indices = {flat}; + elem_ptr = builder_.CreateGEP(base_ptr, gep_indices, module_.GetContext().NextTemp()); + } else { + std::vector indices; + if (base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat()) { + for (auto* v : idx_vals) indices.push_back(v); + } else { + indices.push_back(builder_.CreateConstInt(0)); + for (auto* v : idx_vals) indices.push_back(v); + } + elem_ptr = builder_.CreateGEP(base_ptr, indices, module_.GetContext().NextTemp()); + } + + if (elem_ptr->GetType()->IsPtrFloat() && rhs->GetType()->IsInt32()) { + rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (elem_ptr->GetType()->IsPtrInt32() && rhs->GetType()->IsFloat()) { + rhs = builder_.CreateFPToSI(rhs, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + builder_.CreateStore(rhs, elem_ptr); + } else { + // 普通标量赋值 + DEBUG_MSG("[DEBUG] HandleAssignStmt: scalar assignment to " << varName + << ", ptr = " << (void*)base_ptr + << ", rhs = " << (void*)rhs); + // 在 HandleAssignStmt 中,存储前添加类型调试 + if (base_ptr && base_ptr->GetType()) { + + DEBUG_MSG("[DEBUG] Is int32: " << base_ptr->GetType()->IsInt32()); + DEBUG_MSG("[DEBUG] Is float: " << base_ptr->GetType()->IsFloat()); + DEBUG_MSG("[DEBUG] Is ptr int32: " << base_ptr->GetType()->IsPtrInt32()); + DEBUG_MSG("[DEBUG] Is ptr float: " << base_ptr->GetType()->IsPtrFloat()); + DEBUG_MSG("[DEBUG] Is array: " << 
base_ptr->GetType()->IsArray()); + } + if (rhs && rhs->GetType()) { + + DEBUG_MSG("[DEBUG] Value is int32: " << rhs->GetType()->IsInt32()); + } + if (base_ptr->GetType()->IsPtrFloat() && rhs->GetType()->IsInt32()) { + rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (base_ptr->GetType()->IsPtrInt32() && rhs->GetType()->IsFloat()) { + rhs = builder_.CreateFPToSI(rhs, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + builder_.CreateStore(rhs, base_ptr); + } + + return BlockFlow::Continue; +} \ No newline at end of file diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 745374c..c072f6b 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -3,198 +3,1587 @@ #include #include #include +#include +#include #include "SysYBaseVisitor.h" #include "sem/SymbolTable.h" #include "utils/Log.h" +//#define DEBUG_SEMA + +#ifdef DEBUG_SEMA +#include +#define DEBUG_MSG(msg) std::cerr << "[Sema Debug] " << msg << std::endl +#else +#define DEBUG_MSG(msg) +#endif + namespace { -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("sema", "非法左值")); - } - return lvalue.ID()->getText(); +// 获取左值名称的辅助函数 +std::string GetLValueName(SysYParser::LValContext& lval) { + if (!lval.Ident()) { + throw std::runtime_error(FormatError("sema", "非法左值")); + } + return lval.Ident()->getText(); +} + +// 从 BTypeContext 获取类型 +std::shared_ptr GetTypeFromBType(SysYParser::BTypeContext* ctx) { + if (!ctx) return ir::Type::GetInt32Type(); + if (ctx->Int()) return ir::Type::GetInt32Type(); + if (ctx->Float()) return ir::Type::GetFloatType(); + return ir::Type::GetInt32Type(); } +// 语义分析 Visitor class SemaVisitor final : public SysYBaseVisitor { - public: - std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少编译单元")); - } - auto* func = ctx->funcDef(); - if (!func || !func->blockStmt()) 
{ - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - if (!func->ID() || func->ID()->getText() != "main") { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - func->accept(this); - if (!seen_return_) { - throw std::runtime_error( - FormatError("sema", "main 函数必须包含 return 语句")); - } - return {}; - } - - std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { - if (!ctx || !ctx->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); - } - const auto& items = ctx->blockStmt()->blockItem(); - if (items.empty()) { - throw std::runtime_error( - FormatError("sema", "main 函数不能为空,且必须以 return 结束")); - } - ctx->blockStmt()->accept(this); - return {}; - } - - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少语句块")); - } - const auto& items = ctx->blockItem(); - for (size_t i = 0; i < items.size(); ++i) { - auto* item = items[i]; - if (!item) { - continue; - } - if (seen_return_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } - current_item_index_ = i; - total_items_ = items.size(); - item->accept(this); - } - return {}; - } - - std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - if (ctx->decl()) { - ctx->decl()->accept(this); - return {}; - } - if (ctx->stmt()) { - ctx->stmt()->accept(this); - return {}; - } - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - - std::any visitDecl(SysYParser::DeclContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); - } - auto* var_def = 
ctx->varDef(); - if (!var_def || !var_def->lValue()) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - const std::string name = GetLValueName(*var_def->lValue()); - if (table_.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); - } - if (auto* init = var_def->initValue()) { - if (!init->exp()) { - throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化")); - } - init->exp()->accept(this); - } - table_.Add(name, var_def); - return {}; - } - - std::any visitStmt(SysYParser::StmtContext* ctx) override { - if (!ctx || !ctx->returnStmt()) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - ctx->returnStmt()->accept(this); - return {}; - } - - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "return 缺少表达式")); - } - ctx->exp()->accept(this); - seen_return_ = true; - if (current_item_index_ + 1 != total_items_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } - return {}; - } - - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "非法括号表达式")); - } - ctx->exp()->accept(this); - return {}; - } - - std::any visitVarExp(SysYParser::VarExpContext* ctx) override { - if (!ctx || !ctx->var()) { - throw std::runtime_error(FormatError("sema", "非法变量表达式")); - } - ctx->var()->accept(this); - return {}; - } - - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); - } - return {}; - } - - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); - } - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; - } - - 
std::any visitVar(SysYParser::VarContext* ctx) override { - if (!ctx || !ctx->ID()) { - throw std::runtime_error(FormatError("sema", "非法变量引用")); +public: + SemaVisitor() : table_() {} + std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "缺少编译单元")); + } + for (auto* func : ctx->funcDef()) { // 收集所有函数声明(处理互相调用) + CollectFunctionDeclaration(func); + } + for (auto* decl : ctx->decl()) { // 处理所有声明和定义 + if (decl) decl->accept(this); + } + for (auto* func : ctx->funcDef()) { + if (func) func->accept(this); + } + CheckMainFunction(); // 检查 main 函数存在且正确 + return {}; + } + + std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "函数定义缺少标识符")); + } + std::string name = ctx->Ident()->getText(); + std::shared_ptr return_type; // 获取返回类型 + if (ctx->funcType()) { + if (ctx->funcType()->Void()) { + return_type = ir::Type::GetVoidType(); + } else if (ctx->funcType()->Int()) { + return_type = ir::Type::GetInt32Type(); + } else if (ctx->funcType()->Float()) { + return_type = ir::Type::GetFloatType(); + } else { + return_type = ir::Type::GetInt32Type(); + } + } else { + return_type = ir::Type::GetInt32Type(); + } + DEBUG_MSG("[DEBUG] 进入函数: " << name + << " 返回类型: " << (return_type->IsInt32() ? "int" : + return_type->IsFloat() ? 
"float" : "void")); + + // 记录当前函数返回类型(用于 return 检查) + current_func_return_type_ = return_type; + current_func_has_return_ = false; + + table_.enterScope(); + if (ctx->funcFParams()) { // 处理参数 + CollectFunctionParams(ctx->funcFParams()); + } + if (ctx->block()) { // 处理函数体 + ctx->block()->accept(this); + } + DEBUG_MSG("[DEBUG] 函数 " << name + << " has_return: " << current_func_has_return_ + << " return_type_is_void: " << return_type->IsVoid()); + if (!return_type->IsVoid() && !current_func_has_return_) { // 检查非 void 函数是否有 return + throw std::runtime_error(FormatError("sema", "非 void 函数 " + name + " 缺少 return 语句")); + } + table_.exitScope(); + + current_func_return_type_ = nullptr; + current_func_has_return_ = false; + return {}; + } + + std::any visitBlock(SysYParser::BlockContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "缺少语句块")); + } + table_.enterScope(); + for (auto* item : ctx->blockItem()) { // 处理所有 blockItem + if (item) { + item->accept(this); + // 如果已经有 return,可以继续(但 return 必须是最后一条) + // 注意:这里不需要跳出,因为 return 语句本身已经标记了 + } + } + table_.exitScope(); + return {}; + } + + std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + } + if (ctx->decl()) { + ctx->decl()->accept(this); + return {}; + } + if (ctx->stmt()) { + ctx->stmt()->accept(this); + return {}; + } + throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + } + + std::any visitDecl(SysYParser::DeclContext* ctx) override { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "非法变量声明")); + } + if (ctx->constDecl()) { + ctx->constDecl()->accept(this); + } else if (ctx->varDecl()) { + ctx->varDecl()->accept(this); + } + return {}; + } + + // ==================== 变量声明 ==================== + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override { + if (!ctx || !ctx->bType()) { + throw std::runtime_error(FormatError("sema", "非法变量声明")); + } + std::shared_ptr 
base_type = GetTypeFromBType(ctx->bType()); + bool is_global = (table_.currentScopeLevel() == 0); + for (auto* var_def : ctx->varDef()) { + if (var_def) { + CheckVarDef(var_def, base_type, is_global); + } + } + return {}; + } + + void CheckVarDef(SysYParser::VarDefContext* ctx, + std::shared_ptr base_type, + bool is_global) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法变量定义")); + } + std::string name = ctx->Ident()->getText(); + if (table_.lookupCurrent(name)) { // 检查重复定义 + throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + } + // 确定类型(处理数组维度) + std::shared_ptr type = base_type; + std::vector dims; + bool is_array = !ctx->constExp().empty(); + // 调试输出 + DEBUG_MSG("[DEBUG] CheckVarDef: " << name + << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown") + << " is_array: " << is_array + << " dim_count: " << ctx->constExp().size()); + if (is_array) { + // 处理数组维度 + for (auto* dim_exp : ctx->constExp()) { + // ========== 绑定维度表达式 ========== + dim_exp->addExp()->accept(this); // 触发常量绑定(如 N) + + int dim = table_.EvaluateConstExp(dim_exp); + if (dim <= 0) { + throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); + } + dims.push_back(dim); + DEBUG_MSG("[DEBUG] dim[" << dims.size() - 1 << "] = " << dim); + } + // 创建数组类型 + type = ir::Type::GetArrayType(base_type, dims); + DEBUG_MSG("[DEBUG] 创建数组类型完成"); + DEBUG_MSG("[DEBUG] type->IsArray(): " << type->IsArray()); + DEBUG_MSG("[DEBUG] type->GetKind(): " << (int)type->GetKind()); + // 验证数组类型 + if (type->IsArray()) { + auto* arr_type = dynamic_cast(type.get()); + if (arr_type) { + DEBUG_MSG("[DEBUG] ArrayType dimensions: "); + for (int d : arr_type->GetDimensions()) { + DEBUG_MSG(d << " "); + } + DEBUG_MSG("[DEBUG] Element type: " + << (arr_type->GetElementType()->IsInt32() ? "int" : + arr_type->GetElementType()->IsFloat() ? 
"float" : "unknown")); + } + } + } + bool has_init = (ctx->initVal() != nullptr); // 处理初始化 + if (is_global && has_init) { + CheckGlobalInitIsConst(ctx->initVal()); // 全局变量初始化必须是常量表达式 + } + // ========== 绑定初始化表达式 ========== + if (ctx->initVal()) { + BindInitVal(ctx->initVal()); + } + // 创建符号 + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Variable; + sym.type = type; + sym.scope_level = table_.currentScopeLevel(); + sym.is_initialized = has_init; + sym.var_def_ctx = ctx; + if (is_array) { + // 存储维度信息,但 param_types 通常用于函数参数 + // 数组变量的维度信息已经包含在 type 中 + sym.param_types.clear(); // 确保不混淆 + } + table_.addSymbol(sym); // 添加到符号表 + DEBUG_MSG("[DEBUG] 符号添加完成: " << name + << " type_kind: " << (int)sym.type->GetKind() + << " is_array: " << sym.type->IsArray() + ); + } + + void CheckConstDef(SysYParser::ConstDefContext* ctx, + std::shared_ptr base_type) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法常量定义")); + } + std::string name = ctx->Ident()->getText(); + if (table_.lookupCurrent(name)) { + throw std::runtime_error(FormatError("sema", "重复定义常量: " + name)); + } + + // 确定类型 + std::shared_ptr type = base_type; + std::vector dims; + bool is_array = !ctx->constExp().empty(); + DEBUG_MSG("[DEBUG] CheckConstDef: " << name + << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? 
"float" : "unknown") + << " is_array: " << is_array + << " dim_count: " << ctx->constExp().size()); + + if (is_array) { + for (auto* dim_exp : ctx->constExp()) { + int dim = table_.EvaluateConstExp(dim_exp); + if (dim <= 0) { + throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); + } + dims.push_back(dim); + DEBUG_MSG("[DEBUG] dim[" << dims.size() - 1 << "] = " << dim); + } + type = ir::Type::GetArrayType(base_type, dims); + DEBUG_MSG("[DEBUG] 创建数组类型完成,IsArray: " << type->IsArray()); + } + + // ========== 绑定维度表达式 ========== + for (auto* dim_exp : ctx->constExp()) { + dim_exp->addExp()->accept(this); + } + + // 求值初始化器 + std::vector init_values; + if (ctx->constInitVal()) { + // ========== 绑定初始化表达式 ========== + BindConstInitVal(ctx->constInitVal()); + + init_values = table_.EvaluateConstInitVal(ctx->constInitVal(), dims, base_type); + DEBUG_MSG("[DEBUG] 初始化值数量: " << init_values.size()); + } + + // 计算期望的元素数量 + size_t expected_count = 1; + if (is_array) { + expected_count = 1; + for (int d : dims) expected_count *= d; + DEBUG_MSG("[DEBUG] 期望元素数量: " << expected_count); + } + + // 如果初始化值不足,补零 + if (is_array && init_values.size() < expected_count) { + DEBUG_MSG("[DEBUG] 初始化值不足,补零"); + SymbolTable::ConstValue zero; + if (base_type->IsInt32()) { + zero.kind = SymbolTable::ConstValue::INT; + zero.int_val = 0; + } else { + zero.kind = SymbolTable::ConstValue::FLOAT; + zero.float_val = 0.0f; + } + init_values.resize(expected_count, zero); + } + + // 检查初始化值数量 + if (init_values.size() > expected_count) { + throw std::runtime_error(FormatError("sema", "初始化值过多")); + } + + // 创建符号 + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Constant; + DEBUG_MSG("CheckConstDef: before addSymbol, sym.kind = " << (int)sym.kind); + sym.type = type; + sym.scope_level = table_.currentScopeLevel(); + sym.is_initialized = true; + sym.var_def_ctx = nullptr; + sym.const_def_ctx = ctx; + DEBUG_MSG("保存常量定义上下文: " << name << ", ctx: " << ctx); + + // ========== 存储常量值 ========== + if 
(is_array) { + // 存储数组常量(扁平化存储) + sym.is_array_const = true; + sym.array_const_values.clear(); + + for (const auto& val : init_values) { + Symbol::ConstantValue cv; + if (val.kind == SymbolTable::ConstValue::INT) { + cv.i32 = val.int_val; + } else { + cv.f32 = val.float_val; + } + sym.array_const_values.push_back(cv); + } + + DEBUG_MSG("[DEBUG] 存储数组常量,共 " << sym.array_const_values.size() + << " 个元素"); + + } else if (!init_values.empty()) { + // 存储标量常量 + if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::INT) { + sym.is_int_const = true; + sym.const_value.i32 = init_values[0].int_val; + DEBUG_MSG("[DEBUG] 存储整型常量: " << init_values[0].int_val); + } else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) { + sym.is_int_const = false; + sym.const_value.f32 = init_values[0].float_val; + DEBUG_MSG("[DEBUG] 存储浮点常量: " << init_values[0].float_val); + } else if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) { + // 整型常量用浮点数初始化(需要检查是否为整数) + float f = init_values[0].float_val; + int i = static_cast(f); + if (std::abs(f - i) > 1e-6) { + throw std::runtime_error(FormatError("sema", + "整型常量不能用非整数值的浮点数初始化: " + std::to_string(f))); + } + sym.is_int_const = true; + sym.const_value.i32 = i; + DEBUG_MSG("[DEBUG] 浮点转整型常量: " << f << " -> " << i); + } else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::INT) { + // 浮点常量用整型初始化,隐式转换 + sym.is_int_const = false; + sym.const_value.f32 = static_cast(init_values[0].int_val); + DEBUG_MSG("[DEBUG] 整型转浮点常量: " << init_values[0].int_val + << " -> " << static_cast(init_values[0].int_val)); + } + } else { + // 没有初始化值,对于标量常量这是错误的 + if (!is_array) { + throw std::runtime_error(FormatError("sema", "常量必须有初始化值: " + name)); + } + DEBUG_MSG("[DEBUG] 数组常量无初始化器,将全部补零"); + } + + table_.addSymbol(sym); + DEBUG_MSG("CheckConstDef: after addSymbol, sym.kind = " << (int)sym.kind); + auto* stored = table_.lookup(name); + DEBUG_MSG("CheckConstDef: after 
addSymbol, stored const_def_ctx = " << stored->const_def_ctx); + + DEBUG_MSG("[DEBUG] 常量符号添加完成: " << name + << " is_array_const: " << sym.is_array_const + << " element_count: " << sym.array_const_values.size()); +} + + // ==================== 常量声明 ==================== + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override { + if (!ctx || !ctx->bType()) { + throw std::runtime_error(FormatError("sema", "非法常量声明")); + } + std::shared_ptr base_type = GetTypeFromBType(ctx->bType()); + for (auto* const_def : ctx->constDef()) { + if (const_def) { + CheckConstDef(const_def, base_type); + } + } + return {}; + } + + // ==================== 语句语义检查 ==================== + + // 处理所有语句 - 通过运行时类型判断 + std::any visitStmt(SysYParser::StmtContext* ctx) override { + if (!ctx) return {}; + // 调试输出 + DEBUG_MSG("[DEBUG] visitStmt: "); + if (ctx->Return()) DEBUG_MSG("Return "); + if (ctx->If()) DEBUG_MSG("If "); + if (ctx->While()) DEBUG_MSG("While "); + if (ctx->Break()) DEBUG_MSG("Break "); + if (ctx->Continue()) DEBUG_MSG("Continue "); + if (ctx->lVal() && ctx->Assign()) DEBUG_MSG("Assign "); + if (ctx->exp() && ctx->Semi()) DEBUG_MSG("ExpStmt "); + if (ctx->block()) DEBUG_MSG("Block "); + // 判断语句类型 - 注意:Return() 返回的是 TerminalNode* + if (ctx->Return() != nullptr) { + // return 语句 + DEBUG_MSG("[DEBUG] 检测到 return 语句"); + return visitReturnStmtInternal(ctx); + } else if (ctx->lVal() != nullptr && ctx->Assign() != nullptr) { + // 赋值语句 + return visitAssignStmt(ctx); + } else if (ctx->exp() != nullptr && ctx->Semi() != nullptr) { + // 表达式语句(可能有表达式) + return visitExpStmt(ctx); + } else if (ctx->block() != nullptr) { + // 块语句 + return ctx->block()->accept(this); + } else if (ctx->If() != nullptr) { + // if 语句 + return visitIfStmtInternal(ctx); + } else if (ctx->While() != nullptr) { + // while 语句 + return visitWhileStmtInternal(ctx); + } else if (ctx->Break() != nullptr) { + // break 语句 + return visitBreakStmtInternal(ctx); + } else if (ctx->Continue() != nullptr) { + // continue 
语句 + return visitContinueStmtInternal(ctx); + } + return {}; + } + + // return 语句内部实现 + std::any visitReturnStmtInternal(SysYParser::StmtContext* ctx) { + DEBUG_MSG("[DEBUG] visitReturnStmtInternal 被调用"); + std::shared_ptr expected = current_func_return_type_; + if (!expected) { + throw std::runtime_error(FormatError("sema", "return 语句不在函数体内")); + } + if (ctx->exp() != nullptr) { + // 有返回值的 return + DEBUG_MSG("[DEBUG] 有返回值的 return"); + ExprInfo ret_val = CheckExp(ctx->exp()); + if (expected->IsVoid()) { + throw std::runtime_error(FormatError("sema", "void 函数不能返回值")); + } else if (!IsTypeCompatible(ret_val.type, expected)) { + throw std::runtime_error(FormatError("sema", "返回值类型不匹配")); + } + // 标记需要隐式转换 + if (ret_val.type != expected) { + sema_.AddConversion(ctx->exp(), ret_val.type, expected); + } + // 设置 has_return 标志 + current_func_has_return_ = true; + DEBUG_MSG("[DEBUG] 设置 current_func_has_return_ = true"); + } else { + // 无返回值的 return + DEBUG_MSG("[DEBUG] 无返回值的 return"); + if (!expected->IsVoid()) { + throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值")); + } + // 设置 has_return 标志 + current_func_has_return_ = true; + DEBUG_MSG("[DEBUG] 设置 current_func_has_return_ = true"); + } + return {}; + } + + // 左值表达式(变量引用) + std::any visitLVal(SysYParser::LValContext* ctx) override { + DEBUG_MSG("[DEBUG] visitLVal: " << ctx->getText()); + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法变量引用")); + } + std::string name = ctx->Ident()->getText(); + auto* sym = table_.lookup(name); + if (!sym) { + throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); + } + // 检查数组访问 + bool is_array_access = !ctx->exp().empty(); + DEBUG_MSG("[DEBUG] name: " << name + << ", is_array_access: " << is_array_access + << ", subscript_count: " << ctx->exp().size()); + ExprInfo result; + // 判断是否为数组类型或指针类型(数组参数) + bool is_array_or_ptr = false; + if (sym->type) { + is_array_or_ptr = sym->type->IsArray() || sym->type->IsPtrInt32() || 
sym->type->IsPtrFloat(); + DEBUG_MSG("[DEBUG] type_kind: " << (int)sym->type->GetKind() + << ", is_array: " << sym->type->IsArray() + << ", is_ptr: " << (sym->type->IsPtrInt32() || sym->type->IsPtrFloat())); + } + + if (is_array_or_ptr) { + // 获取维度信息 + size_t dim_count = 0; + std::shared_ptr elem_type = sym->type; + if (sym->type->IsArray()) { + if (auto* arr_type = dynamic_cast(sym->type.get())) { + dim_count = arr_type->GetDimensions().size(); + elem_type = arr_type->GetElementType(); + DEBUG_MSG("[DEBUG] 数组维度: " << dim_count); + } + } else if (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) { + dim_count = 1; + if (sym->type->IsPtrInt32()) { + elem_type = ir::Type::GetInt32Type(); + } else if (sym->type->IsPtrFloat()) { + elem_type = ir::Type::GetFloatType(); + } + DEBUG_MSG("[DEBUG] 指针类型, dim_count: 1"); + } + + if (is_array_access) { + DEBUG_MSG("[DEBUG] 有下标访问,期望维度: " << dim_count + << ", 实际下标数: " << ctx->exp().size()); + if (ctx->exp().size() != dim_count) { + throw std::runtime_error(FormatError("sema", "数组下标个数不匹配")); + } + for (auto* idx_exp : ctx->exp()) { + ExprInfo idx = CheckExp(idx_exp); + if (!idx.type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "数组下标必须是 int 类型")); + } + } + result.type = elem_type; + result.is_lvalue = true; + result.is_const = false; + } else { + DEBUG_MSG("[DEBUG] 无下标访问"); + if (sym->type->IsArray()) { + DEBUG_MSG("[DEBUG] 数组名作为地址,转换为指针"); + if (auto* arr_type = dynamic_cast(sym->type.get())) { + if (arr_type->GetElementType()->IsInt32()) { + result.type = ir::Type::GetPtrInt32Type(); + } else if (arr_type->GetElementType()->IsFloat()) { + result.type = ir::Type::GetPtrFloatType(); + } else { + result.type = ir::Type::GetPtrInt32Type(); + } + } else { + result.type = ir::Type::GetPtrInt32Type(); + } + result.is_lvalue = false; + result.is_const = true; + } else { + result.type = sym->type; + result.is_lvalue = true; + result.is_const = (sym->kind == SymbolKind::Constant); + } + } + } else { + if 
(is_array_access) { + throw std::runtime_error(FormatError("sema", "非数组变量不能使用下标: " + name)); + } + result.type = sym->type; + result.is_lvalue = true; + result.is_const = (sym->kind == SymbolKind::Constant); + if (result.is_const && sym->type && !sym->type->IsArray()) { + if (sym->is_int_const) { + result.is_const_int = true; + result.const_int_value = sym->const_value.i32; + } else { + result.const_float_value = sym->const_value.f32; + } + } + } + sema_.SetExprType(ctx, result); + return {}; + } + + // if 语句内部实现 + std::any visitIfStmtInternal(SysYParser::StmtContext* ctx) { + // 检查条件表达式 + if (ctx->cond()) { + ExprInfo cond = CheckCond(ctx->cond()); // CheckCond 已经处理了类型转换 + // 不需要额外检查,因为 CheckCond 已经确保类型正确 + } + // 处理 then 分支 + if (ctx->stmt().size() > 0) { + ctx->stmt()[0]->accept(this); + } + // 处理 else 分支 + if (ctx->stmt().size() > 1) { + ctx->stmt()[1]->accept(this); + } + return {}; + } + + // while 语句内部实现 + std::any visitWhileStmtInternal(SysYParser::StmtContext* ctx) { + if (ctx->cond()) { + ExprInfo cond = CheckCond(ctx->cond()); // CheckCond 已经处理了类型转换 + // 不需要额外检查 + } + loop_stack_.push_back({true, ctx}); + if (ctx->stmt().size() > 0) { + ctx->stmt()[0]->accept(this); + } + loop_stack_.pop_back(); + return {}; + } + + // break 语句内部实现 + std::any visitBreakStmtInternal(SysYParser::StmtContext* ctx) { + if (loop_stack_.empty() || !loop_stack_.back().in_loop) { + throw std::runtime_error(FormatError("sema", "break 语句必须在循环体内使用")); + } + return {}; + } + + // continue 语句内部实现 + std::any visitContinueStmtInternal(SysYParser::StmtContext* ctx) { + if (loop_stack_.empty() || !loop_stack_.back().in_loop) { + throw std::runtime_error(FormatError("sema", "continue 语句必须在循环体内使用")); + } + return {}; } - const std::string name = ctx->ID()->getText(); - auto* decl = table_.Lookup(name); - if (!decl) { - throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); - } - sema_.BindVarUse(ctx, decl); - return {}; - } - - SemanticContext TakeSemanticContext() { return 
std::move(sema_); } - private: - SymbolTable table_; - SemanticContext sema_; - bool seen_return_ = false; - size_t current_item_index_ = 0; - size_t total_items_ = 0; + // 赋值语句内部实现 + std::any visitAssignStmt(SysYParser::StmtContext* ctx) { + if (!ctx->lVal() || !ctx->exp()) { + throw std::runtime_error(FormatError("sema", "非法赋值语句")); + } + ExprInfo lvalue = CheckLValue(ctx->lVal()); // 检查左值 + if (lvalue.is_const) { + throw std::runtime_error(FormatError("sema", "不能给常量赋值")); + } + if (!lvalue.is_lvalue) { + throw std::runtime_error(FormatError("sema", "赋值左边必须是左值")); + } + ExprInfo rvalue = CheckExp(ctx->exp()); // 检查右值 + if (!IsTypeCompatible(rvalue.type, lvalue.type)) { + throw std::runtime_error(FormatError("sema", "赋值类型不匹配")); + } + if (rvalue.type != lvalue.type) { // 标记需要隐式转换 + sema_.AddConversion(ctx->exp(), rvalue.type, lvalue.type); + } + return {}; + } + + // 表达式语句内部实现 + std::any visitExpStmt(SysYParser::StmtContext* ctx) { + if (ctx->exp()) { + CheckExp(ctx->exp()); + } + return {}; + } + + // ==================== 表达式类型推导 ==================== + + // 主表达式 + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override { + DEBUG_MSG("[DEBUG] visitPrimaryExp: " << ctx->getText()); + ExprInfo result; + if (ctx->lVal()) { // 左值表达式 + result = CheckLValue(ctx->lVal()); + result.is_lvalue = true; + } else if (ctx->HEX_FLOAT() || ctx->DEC_FLOAT()) { // 浮点字面量 + result.type = ir::Type::GetFloatType(); + result.is_const = true; + result.is_const_int = false; + std::string text; + if (ctx->HEX_FLOAT()) text = ctx->HEX_FLOAT()->getText(); + else text = ctx->DEC_FLOAT()->getText(); + result.const_float_value = std::stof(text); + } else if (ctx->HEX_INT() || ctx->OCTAL_INT() || ctx->DECIMAL_INT() || ctx->ZERO()) { // 整数字面量 + result.type = ir::Type::GetInt32Type(); + result.is_const = true; + result.is_const_int = true; + std::string text; + if (ctx->HEX_INT()) text = ctx->HEX_INT()->getText(); + else if (ctx->OCTAL_INT()) text = ctx->OCTAL_INT()->getText(); + 
else if (ctx->DECIMAL_INT()) text = ctx->DECIMAL_INT()->getText(); + else text = ctx->ZERO()->getText(); + result.const_int_value = std::stoi(text, nullptr, 0); + } else if (ctx->exp()) { // 括号表达式 + result = CheckExp(ctx->exp()); + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 一元表达式 + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override { + DEBUG_MSG("[DEBUG] visitUnaryExp: " << ctx->getText()); + ExprInfo result; + if (ctx->primaryExp()) { + ctx->primaryExp()->accept(this); + auto* info = sema_.GetExprType(ctx->primaryExp()); + if (info) result = *info; + } else if (ctx->Ident() && ctx->L_PAREN()) { // 函数调用 + DEBUG_MSG("[DEBUG] 函数调用: " << ctx->Ident()->getText()); + result = CheckFuncCall(ctx); + } else if (ctx->unaryOp()) { // 一元运算 + ctx->unaryExp()->accept(this); + auto* operand = sema_.GetExprType(ctx->unaryExp()); + if (!operand) { + throw std::runtime_error(FormatError("sema", "一元操作数类型推导失败")); + result.type = ir::Type::GetInt32Type(); + } else { + std::string op = ctx->unaryOp()->getText(); + if (op == "!") { + // 逻辑非:要求操作数是 int 类型,或者可以转换为 int 的 float + if (operand->type->IsInt32()) { + // 已经是 int,没问题 + } else if (operand->type->IsFloat()) { + // float 可以隐式转换为 int + sema_.AddConversion(ctx->unaryExp(), operand->type, ir::Type::GetInt32Type()); + // 更新操作数类型为 int + operand->type = ir::Type::GetInt32Type(); + operand->is_const_int = true; + if (operand->is_const && !operand->is_const_int) { + // 如果原来是 float 常量,转换为 int 常量 + operand->const_int_value = (int)operand->const_float_value; + operand->is_const_int = true; + } + } else { + throw std::runtime_error(FormatError("sema", "逻辑非操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + result.is_const = operand->is_const; + if (operand->is_const && operand->is_const_int) { + result.is_const_int = true; + result.const_int_value = (operand->const_int_value == 0) ? 
1 : 0; + } + } else { + // 正负号 + if (!operand->type->IsInt32() && !operand->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "正负号操作数必须是算术类型")); + } + result.type = operand->type; + result.is_lvalue = false; + result.is_const = operand->is_const; + if (op == "-" && operand->is_const) { + if (operand->type->IsInt32() && operand->is_const_int) { + result.is_const_int = true; + result.const_int_value = -operand->const_int_value; + } else if (operand->type->IsFloat()) { + result.const_float_value = -operand->const_float_value; + } + } + } + } + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 乘除模表达式 + std::any visitMulExp(SysYParser::MulExpContext* ctx) override { + ExprInfo result; + if (ctx->mulExp()) { + ctx->mulExp()->accept(this); + ctx->unaryExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->mulExp()); + auto* right_info = sema_.GetExprType(ctx->unaryExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "乘除模操作数类型推导失败")); + result.type = ir::Type::GetInt32Type(); + } else { + std::string op; + if (ctx->MulOp()) { + op = "*"; + } else if (ctx->DivOp()) { + op = "/"; + } else if (ctx->QuoOp()) { + op = "%"; + } + result = CheckBinaryOp(left_info, right_info, op, ctx); + } + } else { + ctx->unaryExp()->accept(this); + auto* info = sema_.GetExprType(ctx->unaryExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 加减表达式 + std::any visitAddExp(SysYParser::AddExpContext* ctx) override { + ExprInfo result; + if (ctx->addExp()) { + ctx->addExp()->accept(this); + ctx->mulExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->addExp()); + auto* right_info = sema_.GetExprType(ctx->mulExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "加减操作数类型推导失败")); + result.type = ir::Type::GetInt32Type(); + } else { + std::string op; + if (ctx->AddOp()) { + op = "+"; + } else if 
(ctx->SubOp()) { + op = "-"; + } + result = CheckBinaryOp(left_info, right_info, op, ctx); + } + } else { + ctx->mulExp()->accept(this); + auto* info = sema_.GetExprType(ctx->mulExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 关系表达式 + std::any visitRelExp(SysYParser::RelExpContext* ctx) override { + ExprInfo result; + if (ctx->relExp()) { + ctx->relExp()->accept(this); + ctx->addExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->relExp()); + auto* right_info = sema_.GetExprType(ctx->addExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "关系操作数类型推导失败")); + } else { + if (!left_info->type->IsInt32() && !left_info->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "关系运算操作数必须是算术类型")); + } + std::string op; + if (ctx->LOp()) { + op = "<"; + } else if (ctx->GOp()) { + op = ">"; + } else if (ctx->LeOp()) { + op = "<="; + } else if (ctx->GeOp()) { + op = ">="; + } + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const) { + result.is_const = true; + result.is_const_int = true; + float l = GetFloatValue(*left_info); + float r = GetFloatValue(*right_info); + if (op == "<") result.const_int_value = (l < r) ? 1 : 0; + else if (op == ">") result.const_int_value = (l > r) ? 1 : 0; + else if (op == "<=") result.const_int_value = (l <= r) ? 1 : 0; + else if (op == ">=") result.const_int_value = (l >= r) ? 
1 : 0; + } + } + } else { + ctx->addExp()->accept(this); + auto* info = sema_.GetExprType(ctx->addExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 相等性表达式 + std::any visitEqExp(SysYParser::EqExpContext* ctx) override { + ExprInfo result; + if (ctx->eqExp()) { + ctx->eqExp()->accept(this); + ctx->relExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->eqExp()); + auto* right_info = sema_.GetExprType(ctx->relExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "相等性操作数类型推导失败")); + } else { + std::string op; + if (ctx->EqOp()) { + op = "=="; + } else if (ctx->NeOp()) { + op = "!="; + } + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const) { + result.is_const = true; + result.is_const_int = true; + float l = GetFloatValue(*left_info); + float r = GetFloatValue(*right_info); + if (op == "==") result.const_int_value = (l == r) ? 1 : 0; + else if (op == "!=") result.const_int_value = (l != r) ? 
1 : 0; + } + } + } else { + ctx->relExp()->accept(this); + auto* info = sema_.GetExprType(ctx->relExp()); + if (info) { + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 逻辑与表达式 + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override { + ExprInfo result; + if (ctx->lAndExp()) { + ctx->lAndExp()->accept(this); + ctx->eqExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->lAndExp()); + auto* right_info = sema_.GetExprType(ctx->eqExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "逻辑与操作数类型推导失败")); + } else { + // 处理左操作数 + if (left_info->type->IsInt32()) { + // 已经是 int,没问题 + } else if (left_info->type->IsFloat()) { + // float 可以隐式转换为 int + sema_.AddConversion(ctx->lAndExp(), left_info->type, ir::Type::GetInt32Type()); + left_info->type = ir::Type::GetInt32Type(); + if (left_info->is_const && !left_info->is_const_int) { + left_info->const_int_value = (int)left_info->const_float_value; + left_info->is_const_int = true; + } + } else { + throw std::runtime_error(FormatError("sema", "逻辑与左操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + + // 处理右操作数 + if (right_info->type->IsInt32()) { + // 已经是 int,没问题 + } else if (right_info->type->IsFloat()) { + // float 可以隐式转换为 int + sema_.AddConversion(ctx->eqExp(), right_info->type, ir::Type::GetInt32Type()); + right_info->type = ir::Type::GetInt32Type(); + if (right_info->is_const && !right_info->is_const_int) { + right_info->const_int_value = (int)right_info->const_float_value; + right_info->is_const_int = true; + } + } else { + throw std::runtime_error(FormatError("sema", "逻辑与右操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const && + left_info->is_const_int && right_info->is_const_int) { + result.is_const = true; + result.is_const_int = true; + result.const_int_value = + (left_info->const_int_value && 
right_info->const_int_value) ? 1 : 0; + } + } + } else { + ctx->eqExp()->accept(this); + auto* info = sema_.GetExprType(ctx->eqExp()); + if (info) { + // 对于单个操作数,也需要确保类型是 int(用于条件表达式) + if (info->type->IsFloat()) { + sema_.AddConversion(ctx->eqExp(), info->type, ir::Type::GetInt32Type()); + info->type = ir::Type::GetInt32Type(); + if (info->is_const && !info->is_const_int) { + info->const_int_value = (int)info->const_float_value; + info->is_const_int = true; + } + } else if (!info->type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "逻辑与操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + + // 逻辑或表达式 + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override { + ExprInfo result; + if (ctx->lOrExp()) { + ctx->lOrExp()->accept(this); + ctx->lAndExp()->accept(this); + auto* left_info = sema_.GetExprType(ctx->lOrExp()); + auto* right_info = sema_.GetExprType(ctx->lAndExp()); + if (!left_info || !right_info) { + throw std::runtime_error(FormatError("sema", "逻辑或操作数类型推导失败")); + } else { + // 处理左操作数 + if (left_info->type->IsInt32()) { + // 已经是 int,没问题 + } else if (left_info->type->IsFloat()) { + // float 可以隐式转换为 int + sema_.AddConversion(ctx->lOrExp(), left_info->type, ir::Type::GetInt32Type()); + left_info->type = ir::Type::GetInt32Type(); + if (left_info->is_const && !left_info->is_const_int) { + left_info->const_int_value = (int)left_info->const_float_value; + left_info->is_const_int = true; + } + } else { + throw std::runtime_error(FormatError("sema", "逻辑或左操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + + // 处理右操作数 + if (right_info->type->IsInt32()) { + // 已经是 int,没问题 + } else if (right_info->type->IsFloat()) { + // float 可以隐式转换为 int + sema_.AddConversion(ctx->lAndExp(), right_info->type, ir::Type::GetInt32Type()); + right_info->type = ir::Type::GetInt32Type(); + if (right_info->is_const && !right_info->is_const_int) { + right_info->const_int_value = 
(int)right_info->const_float_value; + right_info->is_const_int = true; + } + } else { + throw std::runtime_error(FormatError("sema", "逻辑或右操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + + result.type = ir::Type::GetInt32Type(); + result.is_lvalue = false; + if (left_info->is_const && right_info->is_const && + left_info->is_const_int && right_info->is_const_int) { + result.is_const = true; + result.is_const_int = true; + result.const_int_value = + (left_info->const_int_value || right_info->const_int_value) ? 1 : 0; + } + } + } else { + ctx->lAndExp()->accept(this); + auto* info = sema_.GetExprType(ctx->lAndExp()); + if (info) { + // 对于单个操作数,也需要确保类型是 int(用于条件表达式) + if (info->type->IsFloat()) { + sema_.AddConversion(ctx->lAndExp(), info->type, ir::Type::GetInt32Type()); + info->type = ir::Type::GetInt32Type(); + if (info->is_const && !info->is_const_int) { + info->const_int_value = (int)info->const_float_value; + info->is_const_int = true; + } + } else if (!info->type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "逻辑或操作数必须是 int 类型或可以转换为 int 的 float 类型")); + } + sema_.SetExprType(ctx, *info); + } + return {}; + } + sema_.SetExprType(ctx, result); + return {}; + } + // 新增:获取符号表 + SymbolTable TakeSymbolTable() { return std::move(table_); } + SemanticContext TakeSemanticContext() { return std::move(sema_); } + + // 新增:同时返回两者 + SemaResult TakeResult() { + DEBUG_MSG("[DEBUG] TakeResult 前: 符号表作用域数量 = " + << table_.getScopeCount()); + + // 可选:打印符号表内容 + // table_.dump(); + + SemaResult result; + result.context = std::move(sema_); + result.symbol_table = std::move(table_); + + DEBUG_MSG("[DEBUG] TakeResult 后: 符号表作用域数量 = " + << result.symbol_table.getScopeCount()); + return result; + } + + +private: + SymbolTable table_; + SemanticContext sema_; + struct LoopContext { + bool in_loop; + antlr4::ParserRuleContext* loop_node; + }; + std::vector loop_stack_; + std::shared_ptr current_func_return_type_ = nullptr; + bool current_func_has_return_ = false; + + // 
==================== 辅助函数 ==================== + ExprInfo CheckExp(SysYParser::ExpContext* ctx) { + if (!ctx || !ctx->addExp()) { + throw std::runtime_error(FormatError("sema", "无效表达式")); + } + DEBUG_MSG("[DEBUG] CheckExp: " << ctx->getText()); + ctx->addExp()->accept(this); + auto* info = sema_.GetExprType(ctx->addExp()); + if (!info) { + throw std::runtime_error(FormatError("sema", "表达式类型推导失败")); + } + ExprInfo result = *info; + sema_.SetExprType(ctx, result); + return result; + } + + // 专门用于检查 AddExp 的辅助函数(用于常量表达式) + ExprInfo CheckAddExp(SysYParser::AddExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("sema", "无效表达式")); + } + ctx->accept(this); + auto* info = sema_.GetExprType(ctx); + if (!info) { + throw std::runtime_error(FormatError("sema", "表达式类型推导失败")); + } + return *info; + } + + ExprInfo CheckCond(SysYParser::CondContext* ctx) { + if (!ctx || !ctx->lOrExp()) { + throw std::runtime_error(FormatError("sema", "无效条件表达式")); + } + ctx->lOrExp()->accept(this); + auto* info = sema_.GetExprType(ctx->lOrExp()); + if (!info) { + throw std::runtime_error(FormatError("sema", "条件表达式类型推导失败")); + } + ExprInfo result = *info; + // 条件表达式的结果必须是 int,如果是 float 则需要转换 + // 注意:lOrExp 已经处理了类型转换,这里只是再检查一次 + if (!result.type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "条件表达式必须是 int 类型")); + } + return result; + } + + ExprInfo CheckLValue(SysYParser::LValContext* ctx) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法左值")); + } + std::string name = ctx->Ident()->getText(); + auto* sym = table_.lookup(name); + if (!sym) { + throw std::runtime_error(FormatError("sema", "未定义的变量: " + name)); + } + DEBUG_MSG("CheckLValue: found sym->name = " << sym->name + << ", sym->kind = " << (int)sym->kind); + + if (sym->kind == SymbolKind::Variable && sym->var_def_ctx) { + sema_.BindVarUse(ctx, sym->var_def_ctx); + DEBUG_MSG("绑定变量: " << name << " -> VarDefContext"); + } + else if (sym->kind == SymbolKind::Constant && 
sym->const_def_ctx) { + sema_.BindConstUse(ctx, sym->const_def_ctx); + DEBUG_MSG("绑定常量: " << name << " -> ConstDefContext"); + } + DEBUG_MSG("CheckLValue 绑定变量: " << name + << ", sym->kind: " << (int)sym->kind + << ", sym->var_def_ctx: " << sym->var_def_ctx + << ", sym->const_def_ctx: " << sym->const_def_ctx); + + bool is_array_access = !ctx->exp().empty(); + bool is_const = (sym->kind == SymbolKind::Constant); + + size_t dim_count = 0; + std::shared_ptr elem_type = sym->type; + std::vector dims; + + // 获取维度信息 + if (sym->type && sym->type->IsArray()) { + if (auto* arr_type = dynamic_cast(sym->type.get())) { + dims = arr_type->GetDimensions(); + dim_count = dims.size(); + // 计算元素类型(递归获取最内层元素类型) + std::shared_ptr t = sym->type; + while (t->IsArray()) { + auto* arr_t = dynamic_cast(t.get()); + t = arr_t->GetElementType(); + } + elem_type = t; + } + } else if (sym->is_array_param) { + // 数组参数,使用保存的维度信息 + dims = sym->array_dims; + dim_count = dims.size(); + // 元素类型是基本类型 + if (sym->type->IsPtrInt32()) { + elem_type = ir::Type::GetInt32Type(); + } else if (sym->type->IsPtrFloat()) { + elem_type = ir::Type::GetFloatType(); + } + DEBUG_MSG("数组参数维度: " << dim_count << " 维, dims: "); + for (int d : dims) DEBUG_MSG(d << " "); + } else if (sym->type && (sym->type->IsPtrInt32() || sym->type->IsPtrFloat())) { + // 普通指针,只能有一个下标 + dim_count = 1; + if (sym->type->IsPtrInt32()) { + elem_type = ir::Type::GetInt32Type(); + } else if (sym->type->IsPtrFloat()) { + elem_type = ir::Type::GetFloatType(); + } + } + + size_t subscript_count = ctx->exp().size(); + + DEBUG_MSG("dim_count: " << dim_count << ", subscript_count: " << subscript_count); + + if (dim_count > 0 || sym->is_array_param || sym->type->IsArray() || + sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) { + if (subscript_count > 0) { + // 有下标访问 + // 对于数组参数,第一维是省略的(0),但实际可以访问 + // 我们需要检查提供的下标个数是否超过实际维度个数 + if (subscript_count > dim_count) { + throw std::runtime_error(FormatError("sema", "数组下标个数过多")); + } + // 检查每个下标表达式 + for 
(auto* idx_exp : ctx->exp()) { + ExprInfo idx = CheckExp(idx_exp); + if (!idx.type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "数组下标必须是 int 类型")); + } + } + + if (subscript_count == dim_count) { + // 完全索引,返回元素类型 + DEBUG_MSG("完全索引,返回元素类型"); + return {elem_type, true, false}; + } else { + // 部分索引,返回子数组的指针类型 + DEBUG_MSG("部分索引,返回指针类型"); + // 计算剩余维度的指针类型 + if (elem_type->IsInt32()) { + return {ir::Type::GetPtrInt32Type(), false, false}; + } else if (elem_type->IsFloat()) { + return {ir::Type::GetPtrFloatType(), false, false}; + } else { + return {ir::Type::GetPtrInt32Type(), false, false}; + } + } + } else { + // 没有下标访问 + if (sym->type && sym->type->IsArray()) { + // 数组名作为地址 + DEBUG_MSG("数组名作为地址"); + if (auto* arr_type = dynamic_cast(sym->type.get())) { + if (arr_type->GetElementType()->IsInt32()) { + return {ir::Type::GetPtrInt32Type(), false, true}; + } else if (arr_type->GetElementType()->IsFloat()) { + return {ir::Type::GetPtrFloatType(), false, true}; + } + } + return {ir::Type::GetPtrInt32Type(), false, true}; + } else if (sym->is_array_param) { + // 数组参数名作为地址 + DEBUG_MSG("数组参数名作为地址"); + if (sym->type->IsPtrInt32()) { + return {ir::Type::GetPtrInt32Type(), false, true}; + } else { + return {ir::Type::GetPtrFloatType(), false, true}; + } + } else { + // 普通变量或指针 + return {sym->type, true, is_const}; + } + } + } else { + if (subscript_count > 0) { + throw std::runtime_error(FormatError("sema", "非数组变量不能使用下标: " + name)); + } + return {sym->type, true, is_const}; + } + } + + ExprInfo CheckFuncCall(SysYParser::UnaryExpContext* ctx) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法函数调用")); + } + std::string func_name = ctx->Ident()->getText(); + DEBUG_MSG("[DEBUG] CheckFuncCall: " << func_name); + auto* func_sym = table_.lookup(func_name); + if (!func_sym || func_sym->kind != SymbolKind::Function) { + throw std::runtime_error(FormatError("sema", "未定义的函数: " + func_name)); + } + std::vector args; + if (ctx->funcRParams()) { 
+ DEBUG_MSG("[DEBUG] 处理函数调用参数:"); + for (auto* exp : ctx->funcRParams()->exp()) { + if (exp) { + args.push_back(CheckExp(exp)); + } + } + } + if (args.size() != func_sym->param_types.size()) { + throw std::runtime_error(FormatError("sema", "参数个数不匹配")); + } + for (size_t i = 0; i < std::min(args.size(), func_sym->param_types.size()); ++i) { + DEBUG_MSG("[DEBUG] 检查参数 " << i << ": 实参类型 " << (int)args[i].type->GetKind() + << " 形参类型 " << (int)func_sym->param_types[i]->GetKind()); + if (!IsTypeCompatible(args[i].type, func_sym->param_types[i])) { + throw std::runtime_error(FormatError("sema", "参数类型不匹配")); + } + if (args[i].type != func_sym->param_types[i] && ctx->funcRParams() && + i < ctx->funcRParams()->exp().size()) { + sema_.AddConversion(ctx->funcRParams()->exp()[i], + args[i].type, func_sym->param_types[i]); + } + } + std::shared_ptr return_type; + if (func_sym->type && func_sym->type->IsFunction()) { + auto* func_type = dynamic_cast(func_sym->type.get()); + if (func_type) { + return_type = func_type->GetReturnType(); + } + } + if (!return_type) { + return_type = ir::Type::GetInt32Type(); + } + ExprInfo result; + result.type = return_type; + result.is_lvalue = false; + result.is_const = false; + return result; + } + + ExprInfo CheckBinaryOp(const ExprInfo* left, const ExprInfo* right, + const std::string& op, antlr4::ParserRuleContext* ctx) { + ExprInfo result; + if (!left->type->IsInt32() && !left->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "左操作数必须是算术类型")); + } + if (!right->type->IsInt32() && !right->type->IsFloat()) { + throw std::runtime_error(FormatError("sema", "右操作数必须是算术类型")); + } + if (op == "%" && (!left->type->IsInt32() || !right->type->IsInt32())) { + throw std::runtime_error(FormatError("sema", "取模运算要求操作数为 int 类型")); + } + if (left->type->IsFloat() || right->type->IsFloat()) { + result.type = ir::Type::GetFloatType(); + } else { + result.type = ir::Type::GetInt32Type(); + } + result.is_lvalue = false; + if (left->is_const && 
right->is_const) { + result.is_const = true; + float l = GetFloatValue(*left); + float r = GetFloatValue(*right); + if (result.type->IsInt32()) { + result.is_const_int = true; + int li = (int)l, ri = (int)r; + if (op == "*") result.const_int_value = li * ri; + else if (op == "/") result.const_int_value = li / ri; + else if (op == "%") result.const_int_value = li % ri; + else if (op == "+") result.const_int_value = li + ri; + else if (op == "-") result.const_int_value = li - ri; + } else { + if (op == "*") result.const_float_value = l * r; + else if (op == "/") result.const_float_value = l / r; + else if (op == "+") result.const_float_value = l + r; + else if (op == "-") result.const_float_value = l - r; + } + } + return result; + } + + float GetFloatValue(const ExprInfo& info) { + if (info.type->IsInt32() && info.is_const_int) { + return (float)info.const_int_value; + } else { + return info.const_float_value; + } + } + + bool IsTypeCompatible(std::shared_ptr src, std::shared_ptr dst) { + if (src == dst) return true; + if (src->IsInt32() && dst->IsFloat()) return true; + if (src->IsFloat() && dst->IsInt32()) return true; + return false; + } + + void CollectFunctionDeclaration(SysYParser::FuncDefContext* ctx) { + if (!ctx || !ctx->Ident()) return; + std::string name = ctx->Ident()->getText(); + if (table_.lookup(name)) return; + std::shared_ptr ret_type; + if (ctx->funcType()) { + if (ctx->funcType()->Void()) { + ret_type = ir::Type::GetVoidType(); + } else if (ctx->funcType()->Int()) { + ret_type = ir::Type::GetInt32Type(); + } else if (ctx->funcType()->Float()) { + ret_type = ir::Type::GetFloatType(); + } + } + if (!ret_type) ret_type = ir::Type::GetInt32Type(); + std::vector> param_types; + if (ctx->funcFParams()) { + for (auto* param : ctx->funcFParams()->funcFParam()) { + if (!param) continue; + std::shared_ptr param_type; + if (param->bType()) { + if (param->bType()->Int()) { + param_type = ir::Type::GetInt32Type(); + } else if (param->bType()->Float()) { + 
param_type = ir::Type::GetFloatType(); + } + } + if (!param_type) param_type = ir::Type::GetInt32Type(); + if (!param->L_BRACK().empty()) { + if (param_type->IsInt32()) { + param_type = ir::Type::GetPtrInt32Type(); + } else if (param_type->IsFloat()) { + param_type = ir::Type::GetPtrFloatType(); + } + } + param_types.push_back(param_type); + } + } + + // 创建函数类型 + std::shared_ptr func_type = ir::Type::GetFunctionType(ret_type, param_types); + + // 创建函数符号 + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Function; + sym.type = func_type; + sym.param_types = param_types; + sym.scope_level = 0; + sym.is_initialized = true; + sym.func_def_ctx = ctx; + + table_.addSymbol(sym); + } + + void CollectFunctionParams(SysYParser::FuncFParamsContext* ctx) { + if (!ctx) return; + for (auto* param : ctx->funcFParam()) { + if (!param || !param->Ident()) continue; + std::string name = param->Ident()->getText(); + if (table_.lookupCurrent(name)) { + throw std::runtime_error(FormatError("sema", "重复定义参数: " + name)); + } + std::shared_ptr param_type; + if (param->bType()) { + if (param->bType()->Int()) { + param_type = ir::Type::GetInt32Type(); + } else if (param->bType()->Float()) { + param_type = ir::Type::GetFloatType(); + } + } + if (!param_type) param_type = ir::Type::GetInt32Type(); + + bool is_array = !param->L_BRACK().empty(); + std::vector dims; + + if (is_array) { + // 第一维是 [],没有表达式,所以维度为0(表示省略) + dims.push_back(0); + + // 后续维度有表达式 + // 注意:exp() 返回的是 ExpContext 列表,对应后面的维度表达式 + for (auto* exp_ctx : param->exp()) { + // 使用常量求值器直接求值 + // 创建一个临时的 ConstExpContext + // 由于 ConstExpContext 只是 addExp 的包装,我们可以直接使用 addExp + auto* addExp = exp_ctx->addExp(); + if (!addExp) { + throw std::runtime_error(FormatError("sema", "无效的数组维度表达式")); + } + + // 求值常量表达式 + int dim = table_.EvaluateConstExpression(exp_ctx); + if (dim <= 0) { + throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); + } + dims.push_back(dim); + } + + // 数组参数退化为指针 + if (param_type->IsInt32()) { + param_type = 
ir::Type::GetPtrInt32Type(); + } else if (param_type->IsFloat()) { + param_type = ir::Type::GetPtrFloatType(); + } + } + + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Parameter; + sym.type = param_type; + sym.scope_level = table_.currentScopeLevel(); + sym.is_initialized = true; + sym.var_def_ctx = nullptr; + sym.is_array_param = is_array; + sym.array_dims = dims; + table_.addSymbol(sym); + + DEBUG_MSG("[DEBUG] 添加参数: " << name << " type_kind: " << (int)param_type->GetKind()); + for (int d : dims) DEBUG_MSG(d << " "); + } + } + + void CheckGlobalInitIsConst(SysYParser::InitValContext* ctx) { + if (!ctx) return; + if (ctx->exp()) { + ExprInfo info = CheckExp(ctx->exp()); + if (!info.is_const) { + throw std::runtime_error(FormatError("sema", "全局变量初始化必须是常量表达式")); + } + } else { + for (auto* init : ctx->initVal()) { + CheckGlobalInitIsConst(init); + } + } + } + + void CheckMainFunction() { + auto* main_sym = table_.lookup("main"); + if (!main_sym || main_sym->kind != SymbolKind::Function) { + throw std::runtime_error(FormatError("sema", "缺少 main 函数")); + } + std::shared_ptr ret_type; + if (main_sym->type && main_sym->type->IsFunction()) { + auto* func_type = dynamic_cast(main_sym->type.get()); + if (func_type) { + ret_type = func_type->GetReturnType(); + } + } + if (!ret_type || !ret_type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "main 函数必须返回 int")); + } + if (!main_sym->param_types.empty()) { + throw std::runtime_error(FormatError("sema", "main 函数不能有参数")); + } + } + + void BindConstInitVal(SysYParser::ConstInitValContext* ctx) { + if (!ctx) return; + if (ctx->constExp()) { + // 遍历表达式树,触发 visitLVal 中的绑定 + ctx->constExp()->addExp()->accept(this); + } else { + for (auto* sub : ctx->constInitVal()) { + BindConstInitVal(sub); + } + } + } + + void BindInitVal(SysYParser::InitValContext* ctx) { + if (!ctx) return; + if (ctx->exp()) { + CheckExp(ctx->exp()); // 触发绑定 + } else { + for (auto* sub : ctx->initVal()) { + BindInitVal(sub); + } + } + } 
}; } // namespace -SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { +// 修改 RunSema 函数,使其返回 SemaResult 结构体,包含符号表和语义上下文 +SemaResult RunSema(SysYParser::CompUnitContext& comp_unit) { SemaVisitor visitor; comp_unit.accept(&visitor); - return visitor.TakeSemanticContext(); + // 直接返回 TakeResult(),利用移动语义 + return visitor.TakeResult(); } diff --git a/src/sem/SymbolTable.cpp b/src/sem/SymbolTable.cpp index ffeea89..a5a9db7 100644 --- a/src/sem/SymbolTable.cpp +++ b/src/sem/SymbolTable.cpp @@ -1,17 +1,789 @@ -// 维护局部变量声明的注册与查找。 - #include "sem/SymbolTable.h" +#include // 用于访问父节点 +#include +#include +#include +#include +#include + +//#define DEBUG_SYMBOL_TABLE + +#ifdef DEBUG_SYMBOL_TABLE +#include +#define DEBUG_MSG(msg) std::cerr << "[SymbolTable Debug] " << msg << std::endl +#else +#define DEBUG_MSG(msg) +#endif + +// ---------- 构造函数 ---------- +SymbolTable::SymbolTable() { + scopes_.emplace_back(); // 初始化全局作用域 + active_scope_stack_.push_back(0); + registerBuiltinFunctions(); // 注册内置库函数 +} + +// ---------- 作用域管理 ---------- +void SymbolTable::enterScope() { + scopes_.emplace_back(); + active_scope_stack_.push_back(scopes_.size() - 1); +} + +void SymbolTable::exitScope() { + if (active_scope_stack_.size() > 1) { + active_scope_stack_.pop_back(); + } + // 不能退出全局作用域 +} + +// ---------- 符号添加与查找 ---------- +bool SymbolTable::addSymbol(const Symbol& sym) { + auto& current_scope = scopes_[active_scope_stack_.back()]; + if (current_scope.find(sym.name) != current_scope.end()) { + return false; // 重复定义 + } + + Symbol stored_sym = sym; + stored_sym.scope_level = currentScopeLevel(); + current_scope[sym.name] = stored_sym; + + // 立即验证存储的符号 + const auto& stored = current_scope[sym.name]; + DEBUG_MSG("SymbolTable::addSymbol: stored " << sym.name + << " with kind=" << (int)stored.kind + << ", const_def_ctx=" << stored.const_def_ctx); + + return true; +} + +Symbol* SymbolTable::lookup(const std::string& name) { + return const_cast(static_cast(this)->lookup(name)); +} + 
+Symbol* SymbolTable::lookupCurrent(const std::string& name) { + return const_cast(static_cast(this)->lookupCurrent(name)); +} + +const Symbol* SymbolTable::lookup(const std::string& name) const { + for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) { + const auto& scope = scopes_[*it]; + auto found = scope.find(name); + if (found != scope.end()) { + DEBUG_MSG("SymbolTable::lookup: found " << name + << " in active scope index " << *it + << ", kind=" << (int)found->second.kind + << ", const_def_ctx=" << found->second.const_def_ctx); + return &found->second; + } + } + return nullptr; +} + +const Symbol* SymbolTable::lookupCurrent(const std::string& name) const { + const auto& current_scope = scopes_[active_scope_stack_.back()]; + auto it = current_scope.find(name); + if (it != current_scope.end()) { + return &it->second; + } + return nullptr; +} + +const Symbol* SymbolTable::lookupAll(const std::string& name) const { + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + auto found = it->find(name); + if (found != it->end()) { + return &found->second; + } + } + return nullptr; +} -void SymbolTable::Add(const std::string& name, - SysYParser::VarDefContext* decl) { - table_[name] = decl; +const Symbol* SymbolTable::lookupByVarDef(const SysYParser::VarDefContext* decl) const { + if (!decl) return nullptr; + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + for (const auto& [name, sym] : *it) { + if (sym.var_def_ctx == decl) { + return &sym; + } + } + } + return nullptr; +} + +const Symbol* SymbolTable::lookupByConstDef(const SysYParser::ConstDefContext* decl) const { + if (!decl) return nullptr; + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + for (const auto& [name, sym] : *it) { + if (sym.const_def_ctx == decl) { + return &sym; + } + } + } + return nullptr; +} + +// ---------- 兼容原接口 ---------- +void SymbolTable::Add(const std::string& name, SysYParser::VarDefContext* decl) { + Symbol sym; + 
sym.name = name; + sym.kind = SymbolKind::Variable; + sym.type = getTypeFromVarDef(decl); + sym.var_def_ctx = decl; + sym.scope_level = currentScopeLevel(); + addSymbol(sym); } bool SymbolTable::Contains(const std::string& name) const { - return table_.find(name) != table_.end(); + for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) { + const auto& scope = scopes_[*it]; + if (scope.find(name) != scope.end()) { + return true; + } + } + return false; } SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { - auto it = table_.find(name); - return it == table_.end() ? nullptr : it->second; + for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) { + const auto& scope = scopes_[*it]; + auto found = scope.find(name); + if (found != scope.end()) { + // 只返回变量定义的上下文(函数等其他符号返回 nullptr) + if (found->second.kind == SymbolKind::Variable) { + return found->second.var_def_ctx; + } + return nullptr; + } + } + return nullptr; +} + +// ---------- 辅助函数:从 VarDefContext 获取外层 VarDeclContext ---------- +static SysYParser::VarDeclContext* getOuterVarDecl(SysYParser::VarDefContext* varDef) { + auto parent = varDef->parent; + while (parent) { + if (auto varDecl = dynamic_cast(parent)) { + return varDecl; + } + parent = parent->parent; + } + return nullptr; +} + +// ---------- 辅助函数:从 VarDefContext 获取外层 ConstDeclContext(常量定义)---------- +static SysYParser::ConstDeclContext* getOuterConstDecl(SysYParser::VarDefContext* varDef) { + auto parent = varDef->parent; + while (parent) { + if (auto constDecl = dynamic_cast(parent)) { + return constDecl; + } + parent = parent->parent; + } + return nullptr; +} + +// 从 VarDefContext 构造类型 +// 原静态函数改为成员函数,并调用成员 EvaluateConstExp +std::shared_ptr SymbolTable::getTypeFromVarDef(SysYParser::VarDefContext* ctx) const { + // 获取基本类型(同原代码,但通过外层 Decl 确定) + std::shared_ptr base_type = nullptr; + auto varDecl = getOuterVarDecl(ctx); + if (varDecl) { + auto bType = varDecl->bType(); 
+ if (bType->Int()) base_type = ir::Type::GetInt32Type(); + else if (bType->Float()) base_type = ir::Type::GetFloatType(); + } else { + auto constDecl = getOuterConstDecl(ctx); + if (constDecl) { + auto bType = constDecl->bType(); + if (bType->Int()) base_type = ir::Type::GetInt32Type(); + else if (bType->Float()) base_type = ir::Type::GetFloatType(); + } + } + if (!base_type) base_type = ir::Type::GetInt32Type(); + + // 解析维度 + std::vector dims; + for (auto* dimExp : ctx->constExp()) { + int dim = EvaluateConstExp(dimExp); // 调用成员函数 + if (dim <= 0) { + throw std::runtime_error("数组维度必须为正整数"); + } + dims.push_back(dim); + } + + if (!dims.empty()) { + return ir::Type::GetArrayType(base_type, dims); + } + return base_type; +} +// 从 FuncDefContext 构造函数类型 +std::shared_ptr SymbolTable::getTypeFromFuncDef(SysYParser::FuncDefContext* ctx) { + // 1. 返回类型 + std::shared_ptr ret_type; + auto funcType = ctx->funcType(); + if (funcType->Void()) { + ret_type = ir::Type::GetVoidType(); + } else if (funcType->Int()) { + ret_type = ir::Type::GetInt32Type(); + } else if (funcType->Float()) { + ret_type = ir::Type::GetFloatType(); + } else { + ret_type = ir::Type::GetInt32Type(); // fallback + } + + // 2. 
参数类型 + std::vector> param_types; + auto fParams = ctx->funcFParams(); + if (fParams) { + for (auto param : fParams->funcFParam()) { + std::shared_ptr param_type; + auto bType = param->bType(); + if (bType->Int()) { + param_type = ir::Type::GetInt32Type(); + } else if (bType->Float()) { + param_type = ir::Type::GetFloatType(); + } else { + param_type = ir::Type::GetInt32Type(); + } + + // 处理数组参数:如果存在 [ ] 或 [ exp ],退化为指针 + if (param->L_BRACK().size() > 0) { + if (param_type->IsInt32()) { + param_type = ir::Type::GetPtrInt32Type(); + } else if (param_type->IsFloat()) { + param_type = ir::Type::GetPtrFloatType(); + } + } + param_types.push_back(param_type); + } + } + + return ir::Type::GetFunctionType(ret_type, param_types); +} + +// ----- 注册内置库函数----- +void SymbolTable::registerBuiltinFunctions() { + // 确保当前处于全局作用域(scopes_ 只有一层) + // 1. getint: int getint() + Symbol getint; + getint.name = "getint"; + getint.kind = SymbolKind::Function; + getint.type = ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {}); // 无参数 + getint.param_types = {}; + getint.scope_level = 0; + getint.is_builtin = true; + addSymbol(getint); + + // 2. getfloat: float getfloat() + Symbol getfloat; + getfloat.name = "getfloat"; + getfloat.kind = SymbolKind::Function; + getfloat.type = ir::Type::GetFunctionType(ir::Type::GetFloatType(), {}); + getfloat.param_types = {}; + getfloat.scope_level = 0; + getfloat.is_builtin = true; + addSymbol(getfloat); + + // 3. getch: int getch() + Symbol getch; + getch.name = "getch"; + getch.kind = SymbolKind::Function; + getch.type = ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {}); + getch.param_types = {}; + getch.scope_level = 0; + getch.is_builtin = true; + addSymbol(getch); + + // 4. 
putint: void putint(int) + std::vector> putint_params = { ir::Type::GetInt32Type() }; + Symbol putint; + putint.name = "putint"; + putint.kind = SymbolKind::Function; + putint.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putint_params); + putint.param_types = putint_params; + putint.scope_level = 0; + putint.is_builtin = true; + addSymbol(putint); + + // 5. putfloat: void putfloat(float) + std::vector> putfloat_params = { ir::Type::GetFloatType() }; + Symbol putfloat; + putfloat.name = "putfloat"; + putfloat.kind = SymbolKind::Function; + putfloat.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putfloat_params); + putfloat.param_types = putfloat_params; + putfloat.scope_level = 0; + putfloat.is_builtin = true; + addSymbol(putfloat); + + // 6. putch: void putch(int) + std::vector> putch_params = { ir::Type::GetInt32Type() }; + Symbol putch; + putch.name = "putch"; + putch.kind = SymbolKind::Function; + putch.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putch_params); + putch.param_types = putch_params; + putch.scope_level = 0; + putch.is_builtin = true; + addSymbol(putch); + + // 7. getarray: int getarray(int a[]) + // 参数类型: int a[] 退化为 int* 即 PtrInt32 + std::vector> getarray_params = { ir::Type::GetPtrInt32Type() }; + Symbol getarray; + getarray.name = "getarray"; + getarray.kind = SymbolKind::Function; + getarray.type = ir::Type::GetFunctionType(ir::Type::GetInt32Type(), getarray_params); + getarray.param_types = getarray_params; + getarray.scope_level = 0; + getarray.is_builtin = true; + addSymbol(getarray); + + // 8. 
putarray: void putarray(int n, int a[]) + // 参数: int n, int a[] -> 实际类型: int, int* + std::vector> putarray_params = { ir::Type::GetInt32Type(), ir::Type::GetPtrInt32Type() }; + Symbol putarray; + putarray.name = "putarray"; + putarray.kind = SymbolKind::Function; + putarray.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putarray_params); + putarray.param_types = putarray_params; + putarray.scope_level = 0; + putarray.is_builtin = true; + addSymbol(putarray); + + // starttime: void starttime() + Symbol starttime; + starttime.name = "starttime"; + starttime.kind = SymbolKind::Function; + starttime.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), {}); // 无参数,返回 void + starttime.param_types = {}; + starttime.scope_level = 0; + starttime.is_builtin = true; + addSymbol(starttime); + + // stoptime: void stoptime() + Symbol stoptime; + stoptime.name = "stoptime"; + stoptime.kind = SymbolKind::Function; + stoptime.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), {}); // 无参数,返回 void + stoptime.param_types = {}; + stoptime.scope_level = 0; + stoptime.is_builtin = true; + addSymbol(stoptime); + + // getfarray: int getfarray(float arr[]) + std::vector> getfarray_params = { ir::Type::GetPtrFloatType() }; + Symbol getfarray; + getfarray.name = "getfarray"; + getfarray.kind = SymbolKind::Function; + getfarray.type = ir::Type::GetFunctionType(ir::Type::GetInt32Type(), getfarray_params); + getfarray.param_types = getfarray_params; + getfarray.scope_level = 0; + getfarray.is_builtin = true; + addSymbol(getfarray); + + // putfarray: void putfarray(int len, float arr[]) + std::vector> putfarray_params = { + ir::Type::GetInt32Type(), + ir::Type::GetPtrFloatType() + }; + Symbol putfarray; + putfarray.name = "putfarray"; + putfarray.kind = SymbolKind::Function; + putfarray.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putfarray_params); + putfarray.param_types = putfarray_params; + putfarray.scope_level = 0; + putfarray.is_builtin = true; + 
addSymbol(putfarray); +} + +// ==================== 常量表达式求值实现 ==================== + +static long long ParseIntegerLiteral(const std::string& text) { + // 处理前缀:0x/0X 十六进制,0 八进制,否则十进制 + if (text.size() > 2 && (text[0] == '0' && (text[1] == 'x' || text[1] == 'X'))) { + return std::stoll(text.substr(2), nullptr, 16); + } else if (text.size() > 1 && text[0] == '0') { + return std::stoll(text, nullptr, 8); + } else { + return std::stoll(text, nullptr, 10); + } +} + +static float ParseFloatLiteral(const std::string& text) { + return std::stof(text); +} + +SymbolTable::ConstValue SymbolTable::EvaluatePrimaryExp(SysYParser::PrimaryExpContext* ctx) const { + if (!ctx) throw std::runtime_error("常量表达式求值:无效 PrimaryExp"); + + if (ctx->lVal()) { + auto lval = ctx->lVal(); + if (!lval->Ident()) throw std::runtime_error("常量表达式求值:无效左值"); + std::string name = lval->Ident()->getText(); + const Symbol* sym = lookup(name); + if (!sym) throw std::runtime_error("常量表达式求值:未定义的标识符 " + name); + if (sym->kind != SymbolKind::Constant) + throw std::runtime_error("常量表达式求值:标识符 " + name + " 不是常量"); + + ConstValue val; + if (sym->is_int_const) { + val.kind = ConstValue::INT; + val.int_val = sym->const_value.i32; + } else { + val.kind = ConstValue::FLOAT; + val.float_val = sym->const_value.f32; + } + return val; + } + else if (ctx->HEX_FLOAT() || ctx->DEC_FLOAT()) { + std::string text; + if (ctx->HEX_FLOAT()) text = ctx->HEX_FLOAT()->getText(); + else text = ctx->DEC_FLOAT()->getText(); + ConstValue val; + val.kind = ConstValue::FLOAT; + val.float_val = ParseFloatLiteral(text); + return val; + } + else if (ctx->HEX_INT() || ctx->OCTAL_INT() || ctx->DECIMAL_INT() || ctx->ZERO()) { + std::string text; + if (ctx->HEX_INT()) text = ctx->HEX_INT()->getText(); + else if (ctx->OCTAL_INT()) text = ctx->OCTAL_INT()->getText(); + else if (ctx->DECIMAL_INT()) text = ctx->DECIMAL_INT()->getText(); + else text = ctx->ZERO()->getText(); + ConstValue val; + val.kind = ConstValue::INT; + val.int_val = 
static_cast(ParseIntegerLiteral(text)); + return val; + } + else if (ctx->exp()) { + return EvaluateAddExp(ctx->exp()->addExp()); + } + else { + throw std::runtime_error("常量表达式求值:不支持的 PrimaryExp 类型"); + } +} + +SymbolTable::ConstValue SymbolTable::EvaluateUnaryExp(SysYParser::UnaryExpContext* ctx) const { + if (!ctx) throw std::runtime_error("常量表达式求值:无效 UnaryExp"); + + if (ctx->primaryExp()) { + return EvaluatePrimaryExp(ctx->primaryExp()); + } + else if (ctx->unaryOp()) { + ConstValue operand = EvaluateUnaryExp(ctx->unaryExp()); + std::string op = ctx->unaryOp()->getText(); + + if (op == "+") { + return operand; + } + else if (op == "-") { + if (operand.kind == ConstValue::INT) { + operand.int_val = -operand.int_val; + } else { + operand.float_val = -operand.float_val; + } + return operand; + } + else if (op == "!") { + if (operand.kind != ConstValue::INT) { + throw std::runtime_error("常量表达式求值:逻辑非操作数必须是整数"); + } + ConstValue res; + res.kind = ConstValue::INT; + res.int_val = (operand.int_val == 0) ? 1 : 0; + return res; + } + else { + throw std::runtime_error("常量表达式求值:未知一元运算符 " + op); + } + } + else { + // 函数调用在常量表达式中不允许 + throw std::runtime_error("常量表达式求值:不允许函数调用"); + } +} + +SymbolTable::ConstValue SymbolTable::EvaluateMulExp(SysYParser::MulExpContext* ctx) const { + if (!ctx) throw std::runtime_error("常量表达式求值:无效 MulExp"); + + if (ctx->mulExp()) { + ConstValue left = EvaluateMulExp(ctx->mulExp()); + ConstValue right = EvaluateUnaryExp(ctx->unaryExp()); + + std::string op; + if (ctx->MulOp()) op = "*"; + else if (ctx->DivOp()) op = "/"; + else if (ctx->QuoOp()) op = "%"; + else throw std::runtime_error("常量表达式求值:未知乘法运算符"); + + bool is_float = (left.kind == ConstValue::FLOAT || right.kind == ConstValue::FLOAT); + if (is_float) { + float l = (left.kind == ConstValue::INT) ? static_cast(left.int_val) : left.float_val; + float r = (right.kind == ConstValue::INT) ? 
static_cast(right.int_val) : right.float_val; + ConstValue res; + res.kind = ConstValue::FLOAT; + if (op == "*") res.float_val = l * r; + else if (op == "/") res.float_val = l / r; + else if (op == "%") throw std::runtime_error("常量表达式求值:浮点数不支持取模运算"); + return res; + } else { + int l = left.int_val; + int r = right.int_val; + ConstValue res; + res.kind = ConstValue::INT; + if (op == "*") res.int_val = l * r; + else if (op == "/") { + if (r == 0) throw std::runtime_error("常量表达式求值:除零错误"); + res.int_val = l / r; + } + else if (op == "%") { + if (r == 0) throw std::runtime_error("常量表达式求值:模零错误"); + res.int_val = l % r; + } + return res; + } + } + else { + return EvaluateUnaryExp(ctx->unaryExp()); + } +} + +SymbolTable::ConstValue SymbolTable::EvaluateAddExp(SysYParser::AddExpContext* ctx) const { + if (!ctx) throw std::runtime_error("常量表达式求值:无效 AddExp"); + + if (ctx->addExp()) { + ConstValue left = EvaluateAddExp(ctx->addExp()); + ConstValue right = EvaluateMulExp(ctx->mulExp()); + + std::string op; + if (ctx->AddOp()) op = "+"; + else if (ctx->SubOp()) op = "-"; + else throw std::runtime_error("常量表达式求值:未知加法运算符"); + + bool is_float = (left.kind == ConstValue::FLOAT || right.kind == ConstValue::FLOAT); + if (is_float) { + float l = (left.kind == ConstValue::INT) ? static_cast(left.int_val) : left.float_val; + float r = (right.kind == ConstValue::INT) ? 
static_cast(right.int_val) : right.float_val; + ConstValue res; + res.kind = ConstValue::FLOAT; + if (op == "+") res.float_val = l + r; + else res.float_val = l - r; + return res; + } else { + int l = left.int_val; + int r = right.int_val; + ConstValue res; + res.kind = ConstValue::INT; + if (op == "+") res.int_val = l + r; + else res.int_val = l - r; + return res; + } + } + else { + return EvaluateMulExp(ctx->mulExp()); + } +} + +int SymbolTable::EvaluateConstExp(SysYParser::ConstExpContext* ctx) const { + if (!ctx || !ctx->addExp()) + throw std::runtime_error("常量表达式求值:无效 ConstExp"); + ConstValue val = EvaluateAddExp(ctx->addExp()); + if (val.kind == ConstValue::INT) { + return val.int_val; + } else { + float f = val.float_val; + int i = static_cast(f); + if (std::abs(f - i) > 1e-6) { + throw std::runtime_error("常量表达式求值:浮点常量不能隐式转换为整数"); + } + return i; + } +} + +float SymbolTable::EvaluateConstExpFloat(SysYParser::ConstExpContext* ctx) const { + if (!ctx || !ctx->addExp()) + throw std::runtime_error("常量表达式求值:无效 ConstExp"); + ConstValue val = EvaluateAddExp(ctx->addExp()); + if (val.kind == ConstValue::INT) { + return static_cast(val.int_val); + } else { + return val.float_val; + } +} + +void SymbolTable::flattenInit(SysYParser::ConstInitValContext* ctx, + std::vector& out, + std::shared_ptr base_type) const { + if (!ctx) return; + + // 获取当前初始化列表的文本(用于调试) + std::string ctxText; + if (ctx->constExp()) { + ctxText = ctx->constExp()->getText(); + } else { + ctxText = "{ ... }"; + } + + if (ctx->constExp()) { + ConstValue val = EvaluateAddExp(ctx->constExp()->addExp()); + + DEBUG_MSG("处理常量表达式: " << ctxText + << " 类型=" << (val.kind == ConstValue::INT ? "INT" : "FLOAT") + << " 值=" << (val.kind == ConstValue::INT ? std::to_string(val.int_val) : std::to_string(val.float_val)) + << " 目标类型=" << (base_type->IsInt32() ? 
"Int32" : "Float")); + + // 整型数组不能接受浮点常量 + if (base_type->IsInt32() && val.kind == ConstValue::FLOAT) { + DEBUG_MSG("错误:整型数组遇到浮点常量,值=" << val.float_val); + throw std::runtime_error("常量初始化:整型数组不能使用浮点常量"); + } + // 浮点数组接受整型常量,并隐式转换 + if (base_type->IsFloat() && val.kind == ConstValue::INT) { + DEBUG_MSG("浮点数组接收整型常量,隐式转换为浮点: " << val.int_val); + val.kind = ConstValue::FLOAT; + val.float_val = static_cast(val.int_val); + } + out.push_back(val); + } else { + DEBUG_MSG("进入花括号初始化列表: " << ctxText); + // 花括号初始化列表:递归展开所有子项 + for (auto* sub : ctx->constInitVal()) { + flattenInit(sub, out, base_type); + } + DEBUG_MSG("退出花括号初始化列表"); + } +} + +std::vector SymbolTable::EvaluateConstInitVal( + SysYParser::ConstInitValContext* ctx, + const std::vector& dims, + std::shared_ptr base_type) const { + + // ========== 1. 标量常量(dims 为空)========== + if (dims.empty()) { + if (!ctx || !ctx->constExp()) { + throw std::runtime_error("标量常量初始化必须使用单个表达式"); + } + ConstValue val = EvaluateAddExp(ctx->constExp()->addExp()); + + // 类型兼容性检查 + /* + if (base_type->IsInt32() && val.kind == ConstValue::FLOAT) { + throw std::runtime_error("整型常量不能使用浮点常量初始化"); + } + */ + // 隐式类型转换 + if (base_type->IsInt32() && val.kind == ConstValue::FLOAT) { + val.kind = ConstValue::INT; + val.int_val = static_cast(val.float_val); + } + if (base_type->IsFloat() && val.kind == ConstValue::INT) { + val.kind = ConstValue::FLOAT; + val.float_val = static_cast(val.int_val); + } + return {val}; // 返回包含单个值的向量 + } + + // ========== 2. 
数组常量(dims 非空)========== + size_t total = 1; + for (int d : dims) total *= d; + + ConstValue zero; + if (base_type->IsInt32()) { + zero.kind = ConstValue::INT; + zero.int_val = 0; + } else { + zero.kind = ConstValue::FLOAT; + zero.float_val = 0.0f; + } + + // 先整体补零,再按 C 语言花括号规则覆盖显式初始化项。 + std::vector flat(total, zero); + + auto convert_value = [&](ConstValue v) -> ConstValue { + if (base_type->IsInt32()) { + if (v.kind == ConstValue::FLOAT) { + throw std::runtime_error("常量初始化:整型数组不能使用浮点常量"); + } + return v; + } + if (v.kind == ConstValue::INT) { + ConstValue t; + t.kind = ConstValue::FLOAT; + t.float_val = static_cast(v.int_val); + return t; + } + return v; + }; + + auto subarray_span = [&](size_t depth) -> size_t { + size_t span = 1; + for (size_t i = depth + 1; i < dims.size(); ++i) span *= static_cast(dims[i]); + return span; + }; + + std::function fill; + fill = [&](SysYParser::ConstInitValContext* node, + size_t depth, + size_t begin, + size_t end) -> size_t { + if (!node || begin >= end) return begin; + + // 标量初始化项 + if (node->constExp()) { + ConstValue v = convert_value(EvaluateAddExp(node->constExp()->addExp())); + if (begin < flat.size()) flat[begin] = v; + return std::min(begin + 1, end); + } + + size_t cursor = begin; + for (auto* child : node->constInitVal()) { + if (cursor >= end) break; + + if (child->constExp()) { + ConstValue v = convert_value(EvaluateAddExp(child->constExp()->addExp())); + if (cursor < flat.size()) flat[cursor] = v; + ++cursor; + continue; + } + + // 花括号子列表:在非最内层需要按子聚合边界对齐。 + if (depth + 1 < dims.size()) { + const size_t span = subarray_span(depth); + const size_t rel = (cursor - begin) % span; + if (rel != 0) cursor += (span - rel); + if (cursor >= end) break; + + const size_t sub_end = std::min(cursor + span, end); + fill(child, depth + 1, cursor, sub_end); + // 一个带花括号的子初始化器会消费当前层的一个子聚合。 + cursor = sub_end; + } else { + // 最内层(标量数组)遇到额外花括号,按同层顺序展开。 + cursor = fill(child, depth, cursor, end); + } + } + return cursor; + }; + + 
fill(ctx, 0, 0, total); + return flat; } + +int SymbolTable::EvaluateConstExpression(SysYParser::ExpContext* ctx) const { + if (!ctx || !ctx->addExp()) { + throw std::runtime_error("常量表达式求值:无效 ExpContext"); + } + ConstValue val = EvaluateAddExp(ctx->addExp()); + if (val.kind == ConstValue::INT) { + return val.int_val; + } else { + float f = val.float_val; + int i = static_cast(f); + if (std::abs(f - i) > 1e-6) { + throw std::runtime_error("常量表达式求值:浮点常量不能隐式转换为整数"); + } + return i; + } +} \ No newline at end of file diff --git a/sylib/sylib.c b/sylib/sylib.c index 7f26d0b..7237ef1 100644 --- a/sylib/sylib.c +++ b/sylib/sylib.c @@ -1,4 +1,162 @@ -// SysY 运行库实现: -// - 按实验/评测规范提供 I/O 等函数实现 -// - 与编译器生成的目标代码链接,支撑运行时行为 +#include "sylib.h" + +#include +#include + +extern int scanf(const char* format, ...); +extern int printf(const char* format, ...); +extern int getchar(void); +extern int putchar(int c); + +int getint(void) { + int x = 0; + scanf("%d", &x); + return x; +} + +int getch(void) { + return getchar(); +} + +int getarray(int a[]) { + int n; + scanf("%d", &n); + int i = 0; + for (; i < n; ++i) { + scanf("%d", &a[i]); + } + return n; +} + +float getfloat(void) { + float x = 0.0f; + scanf("%f", &x); + return x; +} + +int getfarray(float a[]) { + int n = 0; + if (scanf("%d", &n) != 1) { + return 0; + } + int i = 0; + for (; i < n; ++i) { + if (scanf("%f", &a[i]) != 1) { + return i; + } + } + return n; +} + +void putint(int x) { + printf("%d", x); +} + +void putch(int x) { + putchar(x); +} + +void putarray(int n, int a[]) { + int i = 0; + printf("%d:", n); + for (; i < n; ++i) { + printf(" %d", a[i]); + } + putchar('\n'); +} + +void putfloat(float x) { + printf("%a", x); +} + +void putfarray(int n, float a[]) { + int i = 0; + printf("%d:", n); + for (; i < n; ++i) { + printf(" %a", a[i]); + } + putchar('\n'); +} + +void puts(int s[]) { + if (!s) return; + while (*s) { + putchar(*s); + ++s; + } +} + +void _sysy_starttime(int lineno) { + (void)lineno; +} + +void 
_sysy_stoptime(int lineno) {  // No-op stub; lineno is intentionally unused.
+  (void)lineno;
+}
+
+void starttime(void) {  // Calls _sysy_starttime with line number 0.
+  _sysy_starttime(0);
+}
+
+void stoptime(void) {  // Calls _sysy_stoptime with line number 0.
+  _sysy_stoptime(0);
+}
+
+int* memset(int* ptr, int value, int count) {  // Byte fill: writes count BYTES (not ints) starting at ptr and returns ptr. NOTE(review): same name as the C library memset but a different signature - confirm this is intended.
+  unsigned char* p = (unsigned char*)ptr;
+  unsigned char byte = (unsigned char)(value & 0xFF);  // Only the low 8 bits of value are stored.
+  int i = 0;
+  for (; i < count; ++i) {  // Runs zero times when count <= 0.
+    p[i] = byte;
+  }
+  return ptr;
+}
+
+int* sysy_alloc_i32(int count) {  // Returns malloc'ed storage for count ints, or 0 when count <= 0; the result of malloc is not checked here.
+  if (count <= 0) {
+    return 0;
+  }
+  return (int*)malloc((size_t)count * sizeof(int));  // The size is computed in size_t.
+}
+
+float* sysy_alloc_f32(int count) {  // Returns malloc'ed storage for count floats, or 0 when count <= 0; the result of malloc is not checked here.
+  if (count <= 0) {
+    return 0;
+  }
+  return (float*)malloc((size_t)count * sizeof(float));
+}
+
+void sysy_free_i32(int* ptr) {  // Passes ptr to free(); a null ptr returns early.
+  if (!ptr) {
+    return;
+  }
+  free(ptr);
+}
+
+void sysy_free_f32(float* ptr) {  // Passes ptr to free(); a null ptr returns early.
+  if (!ptr) {
+    return;
+  }
+  free(ptr);
+}
+
+void sysy_zero_i32(int* ptr, int count) {  // Sets the first count ints at ptr to 0; does nothing for a null ptr or count <= 0.
+  int i = 0;
+  if (!ptr || count <= 0) {
+    return;
+  }
+  for (; i < count; ++i) {
+    ptr[i] = 0;
+  }
+}
+
+void sysy_zero_f32(float* ptr, int count) {  // Sets the first count floats at ptr to 0.0f; does nothing for a null ptr or count <= 0.
+  int i = 0;
+  if (!ptr || count <= 0) {
+    return;
+  }
+  for (; i < count; ++i) {
+    ptr[i] = 0.0f;
+  }
+}
 
diff --git a/sylib/sylib.h b/sylib/sylib.h
index 502d488..0d81b83 100644
--- a/sylib/sylib.h
+++ b/sylib/sylib.h
@@ -1,4 +1,29 @@
-// SysY 运行库头文件:
-// - 声明运行库函数原型(供编译器生成 call 或链接阶段引用)
-// - 与 sylib.c 配套,按规范逐步补齐声明
+#pragma once
+
+int getint(void);  // Input routines: each reads from stdin.
+int getch(void);
+int getarray(int a[]);  // Returns the element count that was read.
+float getfloat(void);
+int getfarray(float a[]);  // Returns the element count that was read.
+
+void putint(int x);  // Output routines: each writes to stdout.
+void putch(int x);
+void putarray(int n, int a[]);
+void putfloat(float x);
+void putfarray(int n, float a[]);
+void puts(int s[]);  // Takes an int array, one character per element, terminated by a 0 element.
+
+void _sysy_starttime(int lineno);  // Timing hooks.
+void _sysy_stoptime(int lineno);
+void starttime(void);
+void stoptime(void);
+
+int read_map(void);  // NOTE(review): no definition of read_map appears in the sylib.c added by this patch - confirm it is provided elsewhere.
+int* memset(int* ptr, int value, int count);  // count is in bytes.
+int* sysy_alloc_i32(int count);  // Allocation helpers: count is in elements.
+float* sysy_alloc_f32(int count);
+void sysy_free_i32(int* ptr);
+void sysy_free_f32(float* ptr);
+void sysy_zero_i32(int* ptr, int count);  // Zero-fill helpers: count is in elements.
+void sysy_zero_f32(float* ptr, int count);
 
diff --git
a/test.c b/test.c
new file mode 100644
index 0000000..76e8197
--- /dev/null
+++ b/test.c
@@ -0,0 +1 @@
+int main() { return 0; }  // Minimal program: main returns 0.
diff --git a/test.sy b/test.sy
new file mode 100644
index 0000000..76e8197
--- /dev/null
+++ b/test.sy
@@ -0,0 +1 @@
+int main() { return 0; }  // Minimal program: main returns 0.
diff --git a/test2.sy b/test2.sy
new file mode 100644
index 0000000..05e7a96
--- /dev/null
+++ b/test2.sy
@@ -0,0 +1 @@
+int add(int a, int b) { return a + b; } int main() { return add(1, 2); }  // add returns a + b; main returns add(1, 2).