diff --git a/include/ir/IR.h b/include/ir/IR.h index 20a2e64..ea62b05 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -109,18 +109,25 @@ class Context { std::string NextTemp(); private: - std::unordered_map> const_ints_; + // 数组常量缓存需要添加到类中 + struct ArrayKey { + std::shared_ptr type; + std::vector elements; + + bool operator==(const ArrayKey& other) const; + }; - // 浮点常量:使用整数表示浮点数位模式作为键(避免浮点精度问题) - std::unordered_map> const_floats_; + struct ArrayKeyHash { + size_t operator()(const ArrayKey& key) const; + }; + + std::unordered_map, ArrayKeyHash> array_cache_; - // 零常量缓存(按类型指针) + // 其他现有成员... + std::unordered_map> const_ints_; + std::unordered_map> const_floats_; std::unordered_map> zero_constants_; std::unordered_map> aggregate_zeros_; - - // 数组常量简单存储,不去重(因为数组常量通常组合多样,去重成本高) - std::vector> const_arrays_; - int temp_index_ = -1; }; @@ -357,18 +364,18 @@ class User : public Value { // GlobalValue 是全局值/全局变量体系的空壳占位类。 // 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。 -// ir/IR.h - 修正 GlobalValue 定义 -// ir/IR.h - 修正 GlobalValue 定义 +// ir/IR.h - GlobalValue 类定义需要添加这些方法 + class GlobalValue : public User { private: - std::vector initializer_; // 初始化值列表 - bool is_constant_ = false; // 是否为常量(如const变量) - bool is_extern_ = false; // 是否为外部声明 + std::vector initializer_; + bool is_constant_ = false; + bool is_extern_ = false; public: GlobalValue(std::shared_ptr ty, std::string name); - // 初始化器相关 - 使用 ConstantValue* + // 初始化器相关 void SetInitializer(ConstantValue* init); void SetInitializer(const std::vector& init); const std::vector& GetInitializer() const { return initializer_; } @@ -382,17 +389,28 @@ public: void SetExtern(bool is_extern) { is_extern_ = is_extern; } bool IsExtern() const { return is_extern_; } - // 类型判断 - 使用 Type 的方法 + // 类型判断 bool IsArray() const { return GetType()->IsArray(); } bool IsScalar() const { return GetType()->IsInt32() || GetType()->IsFloat(); } + // 数组常量相关方法 + bool IsArrayConstant() const; + ConstantValue* GetArrayElement(size_t index) const; + size_t GetArraySize() const; + // 获取数组大小(如果是数组类型) - int GetArraySize() const { + int GetArraySizeInElements() const { if (auto* array_ty = dynamic_cast(GetType().get())) { return array_ty->GetElementCount(); } return 0; } + +private: + // 辅助方法 + std::shared_ptr GetValueType() const; + bool CheckTypeCompatibility(std::shared_ptr value_type, + ConstantValue* init) const; }; class Instruction : public User { @@ -742,6 +760,20 @@ class BasicBlock : public Value { return ptr; } + template + T* InsertBeforeTerminator(Args&&... args) { + auto inst = std::make_unique(std::forward(args)...); + auto* ptr = inst.get(); + ptr->SetParent(this); + + auto pos = instructions_.end(); + if (HasTerminator()) { + pos = instructions_.end() - 1; + } + instructions_.insert(pos, std::move(inst)); + return ptr; + } + private: Function* parent_ = nullptr; std::vector> instructions_; @@ -812,6 +844,7 @@ class IRBuilder { BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); + AllocaInst* CreateAlloca(std::shared_ptr ty, const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name); AllocaInst* CreateAllocaFloat(const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name); diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 947ffa5..5f57b4d 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "SysYBaseVisitor.h" #include "SysYParser.h" @@ -22,7 +23,10 @@ class Value; class IRGenImpl final : public SysYBaseVisitor { public: - IRGenImpl(ir::Module& module, const SemanticContext& sema); + // 修改构造函数,添加 SymbolTable 参数 + IRGenImpl(ir::Module& module, + const SemanticContext& sema, + const SymbolTable& sym_table); // 新增 // 顶层 std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; @@ -67,9 +71,21 @@ public: ir::Value* EvalCond(SysYParser::CondContext& cond); std::any visitCallExp(SysYParser::UnaryExpContext* ctx); std::vector ProcessNestedInitVals(SysYParser::InitValContext* ctx); + // 带维度感知的展平:按 C 语言花括号对齐规则填充 total_size 个槽位 + // dims[0] 是最外层维度,dims.back() 是最内层维度(元素层) + // 返回已展平并补零的 total_size 大小的向量 + std::vector FlattenInitVal(SysYParser::InitValContext* ctx, + const std::vector& dims, + bool is_float); int TryEvaluateConstInt(SysYParser::ConstExpContext* ctx); void AddRuntimeFunctions(); ir::Function* CreateRuntimeFunctionDecl(const std::string& funcName); + ir::BasicBlock* EnsureCleanupBlock(); + void RegisterCleanup(ir::Function* free_func, ir::Value* ptr); + ir::AllocaInst* CreateEntryAlloca(std::shared_ptr ty, + const std::string& name); + ir::AllocaInst* CreateEntryAllocaI32(const std::string& name); + ir::AllocaInst* CreateEntryAllocaFloat(const std::string& name); private: // 辅助函数声明 enum class BlockFlow{ @@ -108,6 +124,7 @@ private: ir::Module& module_; const SemanticContext& sema_; + const SymbolTable& symbol_table_; // 新增成员 ir::Function* func_; ir::IRBuilder builder_; ir::Value* EvalAssign(SysYParser::StmtContext* ctx); @@ -119,6 +136,8 @@ private: std::unordered_map local_var_map_; // 局部变量 std::unordered_map global_map_; // 全局变量 std::unordered_map param_map_; // 函数参数 + std::unordered_set pointer_param_names_; // 指针/数组形参名 + std::unordered_set heap_local_array_names_; // 堆分配的局部数组名 // 常量映射:常量名 -> 常量值(标量常量) std::unordered_map const_value_map_; @@ -131,21 +150,23 @@ private: std::unordered_map array_info_map_; + std::string current_function_name_; + bool current_function_is_recursive_ = false; + ir::AllocaInst* function_return_slot_ = nullptr; + ir::BasicBlock* function_cleanup_block_ = nullptr; + std::vector> function_cleanup_actions_; + // 新增:处理全局和局部变量的辅助函数 + // 修改处理函数的签名,使用 Symbol* 参数 std::any HandleGlobalVariable(SysYParser::VarDefContext* ctx, - const std::string& varName, - bool is_array); - std::any HandleLocalVariable(SysYParser::VarDefContext* ctx, - const std::string& varName, - bool is_array); + const std::string& varName, + const Symbol* sym); - // 常量求值辅助函数 - int EvaluateConstAddExp(SysYParser::AddExpContext* ctx); - int EvaluateConstMulExp(SysYParser::MulExpContext* ctx); - int EvaluateConstUnaryExp(SysYParser::UnaryExpContext* ctx); - int EvaluateConstPrimaryExp(SysYParser::PrimaryExpContext* ctx); - int EvaluateConstExp(SysYParser::ExpContext* ctx); + std::any HandleLocalVariable(SysYParser::VarDefContext* ctx, + const std::string& varName, + const Symbol* sym); }; +// 修改 GenerateIR 函数签名 std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, - const SemanticContext& sema); + const SemaResult& sema_result); \ No newline at end of file diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 10e4d8e..c053428 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -7,7 +7,7 @@ #include "SysYParser.h" #include "ir/IR.h" - +#include "sem/SymbolTable.h" // 表达式信息结构 struct ExprInfo { std::shared_ptr type = nullptr; @@ -91,4 +91,12 @@ private: // 目前仅检查: // - 变量先声明后使用 // - 局部变量不允许重复定义 -SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); \ No newline at end of file +// SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); +// 新增:语义分析结果结构体 +struct SemaResult { + SemanticContext context; + SymbolTable symbol_table; +}; + +// 修改 RunSema 的返回类型 +SemaResult RunSema(SysYParser::CompUnitContext& comp_unit); \ No newline at end of file diff --git a/include/sem/SymbolTable.h b/include/sem/SymbolTable.h index 2dfbc89..ed986f1 100644 --- a/include/sem/SymbolTable.h +++ b/include/sem/SymbolTable.h @@ -1,6 +1,7 @@ // 极简符号表:记录局部变量定义点。 #pragma once +#include #include #include #include @@ -17,46 +18,113 @@ enum class SymbolKind { Constant }; +// 符号条目 // 符号条目 struct Symbol { + // 基本信息 std::string name; SymbolKind kind; - std::shared_ptr type; // 指向 Type 对象的智能指针 - int scope_level = 0; // 定义时的作用域深度 - int stack_offset = -1; // 局部变量/参数栈偏移(全局变量为 -1) - bool is_initialized = false; // 是否已初始化 - bool is_builtin = false; // 是否为库函数 + std::shared_ptr type; + int scope_level = 0; + int stack_offset = -1; + bool is_initialized = false; + bool is_builtin = false; - // 对于数组参数,存储维度信息 - std::vector array_dims; // 数组各维长度(参数数组的第一维可能为0表示省略) - bool is_array_param = false; // 是否是数组参数 + // 数组参数相关 + std::vector array_dims; + bool is_array_param = false; - // 对于函数,额外存储参数列表(类型已包含在函数类型中,这里仅用于快速访问) + // 函数相关 std::vector> param_types; - // 对于常量,存储常量值(这里支持 int32 和 float) + // 常量值存储 union ConstantValue { int i32; float f32; - } const_value; - bool is_int_const = true; // 标记常量类型,用于区分 int 和 float + }; + + // 标量常量 + bool is_int_const = true; + ConstantValue const_value; + + // 数组常量(扁平化存储) + bool is_array_const = false; + std::vector array_const_values; - // 关联的语法树节点(用于报错位置或进一步分析) + // 语法树节点 SysYParser::VarDefContext* var_def_ctx = nullptr; SysYParser::ConstDefContext* const_def_ctx = nullptr; SysYParser::FuncFParamContext* param_def_ctx = nullptr; SysYParser::FuncDefContext* func_def_ctx = nullptr; + + // 辅助方法 + bool IsScalarConstant() const { + return kind == SymbolKind::Constant && !type->IsArray(); + } + + bool IsArrayConstant() const { + return kind == SymbolKind::Constant && type->IsArray(); + } + + int GetIntConstant() const { + if (!IsScalarConstant()) { + throw std::runtime_error("不是标量常量"); + } + if (!is_int_const) { + throw std::runtime_error("不是整型常量"); + } + return const_value.i32; + } + + float GetFloatConstant() const { + if (!IsScalarConstant()) { + throw std::runtime_error("不是标量常量"); + } + if (is_int_const) { + return static_cast(const_value.i32); + } + return const_value.f32; + } + + ConstantValue GetArrayElement(size_t index) const { + if (!IsArrayConstant()) { + throw std::runtime_error("不是数组常量"); + } + if (index >= array_const_values.size()) { + throw std::runtime_error("数组下标越界"); + } + return array_const_values[index]; + } + + size_t GetArraySize() const { + if (!IsArrayConstant()) return 0; + return array_const_values.size(); + } }; - class SymbolTable { public: SymbolTable(); ~SymbolTable() = default; - + // 添加调试方法 + size_t getScopeCount() const { return active_scope_stack_.size(); } + + void dump() const { + std::cerr << "=== SymbolTable Dump ===" << std::endl; + for (size_t i = 0; i < scopes_.size(); ++i) { + std::cerr << "Scope " << i << " (depth=" << i << ")"; + bool active = std::find(active_scope_stack_.begin(), active_scope_stack_.end(), i) != active_scope_stack_.end(); + std::cerr << (active ? " [active]" : " [inactive]") << std::endl; + for (const auto& [name, sym] : scopes_[i]) { + std::cerr << " " << name + << " (kind=" << (int)sym.kind + << ", level=" << sym.scope_level << ")" << std::endl; + } + } + } // ----- 作用域管理 ----- void enterScope(); // 进入新作用域 void exitScope(); // 退出当前作用域 - int currentScopeLevel() const { return static_cast(scopes_.size()) - 1; } + int currentScopeLevel() const { return static_cast(active_scope_stack_.size()) - 1; } // ----- 符号操作(推荐使用)----- bool addSymbol(const Symbol& sym); // 添加符号到当前作用域 @@ -64,6 +132,9 @@ class SymbolTable { Symbol* lookupCurrent(const std::string& name); // 仅在当前作用域查找 const Symbol* lookup(const std::string& name) const; const Symbol* lookupCurrent(const std::string& name) const; + const Symbol* lookupAll(const std::string& name) const; // 所有作用域查找,包括已结束的作用域 + const Symbol* lookupByVarDef(const SysYParser::VarDefContext* decl) const; // 通过定义节点查找符号 + const Symbol* lookupByConstDef(const SysYParser::ConstDefContext* decl) const; // 通过常量定义节点查找符号 // ----- 与原接口兼容(保留原有功能)----- void Add(const std::string& name, SysYParser::VarDefContext* decl); @@ -103,6 +174,7 @@ class SymbolTable { private: // 作用域栈:每个元素是一个从名字到符号的映射 std::vector> scopes_; + std::vector active_scope_stack_; static constexpr int GLOBAL_SCOPE = 0; // 全局作用域索引 diff --git a/optimized.bc b/optimized.bc new file mode 100644 index 0000000..479f919 Binary files /dev/null and b/optimized.bc differ diff --git a/optimized.ll b/optimized.ll new file mode 100644 index 0000000..7f41f77 --- /dev/null +++ b/optimized.ll @@ -0,0 +1,492 @@ +; ModuleID = 'optimized.bc' +source_filename = "./build/test_compiler/performance/03_sort1.ll" + +@a = global [30000010 x i32] zeroinitializer +@ans = local_unnamed_addr global i32 0 + +declare i32 @getarray(ptr) local_unnamed_addr + +declare void @putint(i32) local_unnamed_addr + +declare void @putch(i32) local_unnamed_addr + +declare void @starttime() local_unnamed_addr + +declare void @stoptime() local_unnamed_addr + +declare ptr @sysy_alloc_i32(i32) local_unnamed_addr + +declare void @sysy_free_i32(ptr) local_unnamed_addr + +declare void @sysy_zero_i32(ptr, i32) local_unnamed_addr + +; Function Attrs: nofree norecurse nosync nounwind memory(argmem: read) +define i32 @getMaxNum(i32 %n, ptr nocapture readonly %arr) local_unnamed_addr #0 { +entry: + %t95 = icmp sgt i32 %n, 0 + br i1 %t95, label %while.body.t5, label %while.exit.t6 + +while.body.t5: ; preds = %entry, %while.body.t5 + %t3_i.07 = phi i32 [ %t21, %while.body.t5 ], [ 0, %entry ] + %t2_ret.06 = phi i32 [ %spec.select, %while.body.t5 ], [ 0, %entry ] + %0 = zext nneg i32 %t3_i.07 to i64 + %t13 = getelementptr i32, ptr %arr, i64 %0 + %t14 = load i32, ptr %t13, align 4 + %spec.select = tail call i32 @llvm.smax.i32(i32 %t14, i32 %t2_ret.06) + %t21 = add nuw nsw i32 %t3_i.07, 1 + %t9 = icmp slt i32 %t21, %n + br i1 %t9, label %while.body.t5, label %while.exit.t6 + +while.exit.t6: ; preds = %while.body.t5, %entry + %t2_ret.0.lcssa = phi i32 [ 0, %entry ], [ %spec.select, %while.body.t5 ] + ret i32 %t2_ret.0.lcssa +} + +; Function Attrs: nofree norecurse nosync nounwind memory(none) +define i32 @getNumPos(i32 %num, i32 %pos) local_unnamed_addr #1 { +entry: + %t333 = icmp sgt i32 %pos, 0 + br i1 %t333, label %while.body.t29, label %while.exit.t30 + +while.body.t29: ; preds = %entry, %while.body.t29 + %t27_i.05 = phi i32 [ %t37, %while.body.t29 ], [ 0, %entry ] + %t24.04 = phi i32 [ %t35, %while.body.t29 ], [ %num, %entry ] + %t35 = sdiv i32 %t24.04, 16 + %t37 = add nuw nsw i32 %t27_i.05, 1 + %t33 = icmp slt i32 %t37, %pos + br i1 %t33, label %while.body.t29, label %while.exit.t30 + +while.exit.t30: ; preds = %while.body.t29, %entry + %t24.0.lcssa = phi i32 [ %num, %entry ], [ %t35, %while.body.t29 ] + %t39 = srem i32 %t24.0.lcssa, 16 + ret i32 %t39 +} + +define void @radixSort(i32 %bitround, ptr nocapture %a, i32 %l, i32 %r) local_unnamed_addr { +entry: + %t43 = tail call ptr @sysy_alloc_i32(i32 16) + tail call void @sysy_zero_i32(ptr %t43, i32 16) + %t46 = tail call ptr @sysy_alloc_i32(i32 16) + tail call void @sysy_zero_i32(ptr %t46, i32 16) + %t48 = tail call ptr @sysy_alloc_i32(i32 16) + tail call void @sysy_zero_i32(ptr %t48, i32 16) + %t54 = icmp eq i32 %bitround, -1 + %t56 = add i32 %l, 1 + %t58 = icmp sge i32 %t56, %r + %t59 = or i1 %t54, %t58 + br i1 %t59, label %cleanup.t44, label %while.cond.t62.preheader + +while.cond.t62.preheader: ; preds = %entry + %t6796 = icmp slt i32 %l, %r + br i1 %t6796, label %while.body.t63.lr.ph, label %while.exit.t64 + +while.body.t63.lr.ph: ; preds = %while.cond.t62.preheader + %t333.i = icmp sgt i32 %bitround, 0 + br label %while.body.t63 + +cleanup.t44: ; preds = %merge.t196, %entry + tail call void @sysy_free_i32(ptr %t48) + tail call void @sysy_free_i32(ptr %t46) + tail call void @sysy_free_i32(ptr %t43) + ret void + +while.body.t63: ; preds = %while.body.t63.lr.ph, %getNumPos.exit28 + %storemerge97 = phi i32 [ %l, %while.body.t63.lr.ph ], [ %t83, %getNumPos.exit28 ] + %0 = sext i32 %storemerge97 to i64 + %t69 = getelementptr i32, ptr %a, i64 %0 + %t70 = load i32, ptr %t69, align 4 + br i1 %t333.i, label %while.body.t29.i, label %getNumPos.exit.thread + +getNumPos.exit.thread: ; preds = %while.body.t63 + %t39.i80 = srem i32 %t70, 16 + %1 = sext i32 %t39.i80 to i64 + %t7381 = getelementptr i32, ptr %t48, i64 %1 + %t7482 = load i32, ptr %t7381, align 4 + br label %getNumPos.exit28 + +while.body.t29.i: ; preds = %while.body.t63, %while.body.t29.i + %t27_i.05.i = phi i32 [ %t37.i, %while.body.t29.i ], [ 0, %while.body.t63 ] + %t24.04.i = phi i32 [ %t35.i, %while.body.t29.i ], [ %t70, %while.body.t63 ] + %t35.i = sdiv i32 %t24.04.i, 16 + %t37.i = add nuw nsw i32 %t27_i.05.i, 1 + %t33.i = icmp slt i32 %t37.i, %bitround + br i1 %t33.i, label %while.body.t29.i, label %getNumPos.exit + +getNumPos.exit: ; preds = %while.body.t29.i + %t39.i = srem i32 %t35.i, 16 + %2 = sext i32 %t39.i to i64 + %t73 = getelementptr i32, ptr %t48, i64 %2 + %t74 = load i32, ptr %t73, align 4 + br label %while.body.t29.i22 + +while.body.t29.i22: ; preds = %getNumPos.exit, %while.body.t29.i22 + %t27_i.05.i23 = phi i32 [ %t37.i26, %while.body.t29.i22 ], [ 0, %getNumPos.exit ] + %t24.04.i24 = phi i32 [ %t35.i25, %while.body.t29.i22 ], [ %t70, %getNumPos.exit ] + %t35.i25 = sdiv i32 %t24.04.i24, 16 + %t37.i26 = add nuw nsw i32 %t27_i.05.i23, 1 + %t33.i27 = icmp slt i32 %t37.i26, %bitround + br i1 %t33.i27, label %while.body.t29.i22, label %getNumPos.exit28.loopexit + +getNumPos.exit28.loopexit: ; preds = %while.body.t29.i22 + %.pre114 = srem i32 %t35.i25, 16 + %.pre115 = sext i32 %.pre114 to i64 + br label %getNumPos.exit28 + +getNumPos.exit28: ; preds = %getNumPos.exit28.loopexit, %getNumPos.exit.thread + %.pre-phi116 = phi i64 [ %.pre115, %getNumPos.exit28.loopexit ], [ %1, %getNumPos.exit.thread ] + %t7584.in = phi i32 [ %t74, %getNumPos.exit28.loopexit ], [ %t7482, %getNumPos.exit.thread ] + %t7584 = add i32 %t7584.in, 1 + %t81 = getelementptr i32, ptr %t48, i64 %.pre-phi116 + store i32 %t7584, ptr %t81, align 4 + %t83 = add nsw i32 %storemerge97, 1 + %t67 = icmp slt i32 %t83, %r + br i1 %t67, label %while.body.t63, label %while.exit.t64 + +while.exit.t64: ; preds = %getNumPos.exit28, %while.cond.t62.preheader + store i32 %l, ptr %t43, align 4 + %t88 = load i32, ptr %t48, align 4 + %t89 = add i32 %t88, %l + store i32 %t89, ptr %t46, align 4 + %invariant.gep = getelementptr i32, ptr %t46, i64 -1 + %t101 = getelementptr i32, ptr %t43, i64 1 + store i32 %t89, ptr %t101, align 4 + %t106 = getelementptr i32, ptr %t48, i64 1 + %t107 = load i32, ptr %t106, align 4 + %t108 = add i32 %t107, %t89 + %t110 = getelementptr i32, ptr %t46, i64 1 + store i32 %t108, ptr %t110, align 4 + %t101.1 = getelementptr i32, ptr %t43, i64 2 + store i32 %t108, ptr %t101.1, align 4 + %t106.1 = getelementptr i32, ptr %t48, i64 2 + %t107.1 = load i32, ptr %t106.1, align 4 + %t108.1 = add i32 %t107.1, %t108 + %t110.1 = getelementptr i32, ptr %t46, i64 2 + store i32 %t108.1, ptr %t110.1, align 4 + %t101.2 = getelementptr i32, ptr %t43, i64 3 + store i32 %t108.1, ptr %t101.2, align 4 + %t106.2 = getelementptr i32, ptr %t48, i64 3 + %t107.2 = load i32, ptr %t106.2, align 4 + %t108.2 = add i32 %t107.2, %t108.1 + %t110.2 = getelementptr i32, ptr %t46, i64 3 + store i32 %t108.2, ptr %t110.2, align 4 + %t101.3 = getelementptr i32, ptr %t43, i64 4 + store i32 %t108.2, ptr %t101.3, align 4 + %t106.3 = getelementptr i32, ptr %t48, i64 4 + %t107.3 = load i32, ptr %t106.3, align 4 + %t108.3 = add i32 %t107.3, %t108.2 + %t110.3 = getelementptr i32, ptr %t46, i64 4 + store i32 %t108.3, ptr %t110.3, align 4 + %t101.4 = getelementptr i32, ptr %t43, i64 5 + store i32 %t108.3, ptr %t101.4, align 4 + %t106.4 = getelementptr i32, ptr %t48, i64 5 + %t107.4 = load i32, ptr %t106.4, align 4 + %t108.4 = add i32 %t107.4, %t108.3 + %t110.4 = getelementptr i32, ptr %t46, i64 5 + store i32 %t108.4, ptr %t110.4, align 4 + %t101.5 = getelementptr i32, ptr %t43, i64 6 + store i32 %t108.4, ptr %t101.5, align 4 + %t106.5 = getelementptr i32, ptr %t48, i64 6 + %t107.5 = load i32, ptr %t106.5, align 4 + %t108.5 = add i32 %t107.5, %t108.4 + %t110.5 = getelementptr i32, ptr %t46, i64 6 + store i32 %t108.5, ptr %t110.5, align 4 + %t101.6 = getelementptr i32, ptr %t43, i64 7 + store i32 %t108.5, ptr %t101.6, align 4 + %t106.6 = getelementptr i32, ptr %t48, i64 7 + %t107.6 = load i32, ptr %t106.6, align 4 + %t108.6 = add i32 %t107.6, %t108.5 + %t110.6 = getelementptr i32, ptr %t46, i64 7 + store i32 %t108.6, ptr %t110.6, align 4 + %t101.7 = getelementptr i32, ptr %t43, i64 8 + store i32 %t108.6, ptr %t101.7, align 4 + %t106.7 = getelementptr i32, ptr %t48, i64 8 + %t107.7 = load i32, ptr %t106.7, align 4 + %t108.7 = add i32 %t107.7, %t108.6 + %t110.7 = getelementptr i32, ptr %t46, i64 8 + store i32 %t108.7, ptr %t110.7, align 4 + %t101.8 = getelementptr i32, ptr %t43, i64 9 + store i32 %t108.7, ptr %t101.8, align 4 + %t106.8 = getelementptr i32, ptr %t48, i64 9 + %t107.8 = load i32, ptr %t106.8, align 4 + %t108.8 = add i32 %t107.8, %t108.7 + %t110.8 = getelementptr i32, ptr %t46, i64 9 + store i32 %t108.8, ptr %t110.8, align 4 + %t101.9 = getelementptr i32, ptr %t43, i64 10 + store i32 %t108.8, ptr %t101.9, align 4 + %t106.9 = getelementptr i32, ptr %t48, i64 10 + %t107.9 = load i32, ptr %t106.9, align 4 + %t108.9 = add i32 %t107.9, %t108.8 + %t110.9 = getelementptr i32, ptr %t46, i64 10 + store i32 %t108.9, ptr %t110.9, align 4 + %t101.10 = getelementptr i32, ptr %t43, i64 11 + store i32 %t108.9, ptr %t101.10, align 4 + %t106.10 = getelementptr i32, ptr %t48, i64 11 + %t107.10 = load i32, ptr %t106.10, align 4 + %t108.10 = add i32 %t107.10, %t108.9 + %t110.10 = getelementptr i32, ptr %t46, i64 11 + store i32 %t108.10, ptr %t110.10, align 4 + %t101.11 = getelementptr i32, ptr %t43, i64 12 + store i32 %t108.10, ptr %t101.11, align 4 + %t106.11 = getelementptr i32, ptr %t48, i64 12 + %t107.11 = load i32, ptr %t106.11, align 4 + %t108.11 = add i32 %t107.11, %t108.10 + %t110.11 = getelementptr i32, ptr %t46, i64 12 + store i32 %t108.11, ptr %t110.11, align 4 + %t101.12 = getelementptr i32, ptr %t43, i64 13 + store i32 %t108.11, ptr %t101.12, align 4 + %t106.12 = getelementptr i32, ptr %t48, i64 13 + %t107.12 = load i32, ptr %t106.12, align 4 + %t108.12 = add i32 %t107.12, %t108.11 + %t110.12 = getelementptr i32, ptr %t46, i64 13 + store i32 %t108.12, ptr %t110.12, align 4 + %t101.13 = getelementptr i32, ptr %t43, i64 14 + store i32 %t108.12, ptr %t101.13, align 4 + %t106.13 = getelementptr i32, ptr %t48, i64 14 + %t107.13 = load i32, ptr %t106.13, align 4 + %t108.13 = add i32 %t107.13, %t108.12 + %t110.13 = getelementptr i32, ptr %t46, i64 14 + store i32 %t108.13, ptr %t110.13, align 4 + %t101.14 = getelementptr i32, ptr %t43, i64 15 + store i32 %t108.13, ptr %t101.14, align 4 + %t106.14 = getelementptr i32, ptr %t48, i64 15 + %t107.14 = load i32, ptr %t106.14, align 4 + %t108.14 = add i32 %t107.14, %t108.13 + %t110.14 = getelementptr i32, ptr %t46, i64 15 + store i32 %t108.14, ptr %t110.14, align 4 + %t333.i29 = icmp sgt i32 %bitround, 0 + br label %while.cond.t118.preheader + +while.cond.t118.preheader: ; preds = %while.exit.t64, %while.exit.t120 + %storemerge17104 = phi i32 [ 0, %while.exit.t64 ], [ %t180, %while.exit.t120 ] + %3 = zext nneg i32 %storemerge17104 to i64 + %t122 = getelementptr i32, ptr %t43, i64 %3 + %t125 = getelementptr i32, ptr %t46, i64 %3 + %t123100 = load i32, ptr %t122, align 4 + %t126101 = load i32, ptr %t125, align 4 + %t127102 = icmp slt i32 %t123100, %t126101 + br i1 %t127102, label %while.body.t119, label %while.exit.t120 + +while.body.t191.peel.next: ; preds = %while.exit.t120 + store i32 %l, ptr %t43, align 4 + %t187 = load i32, ptr %t48, align 4 + %t188 = add i32 %t187, %l + store i32 %t188, ptr %t46, align 4 + %t215 = add i32 %bitround, -1 + %t218.peel.pre = load i32, ptr %t43, align 4 + tail call void @radixSort(i32 %t215, ptr %a, i32 %t218.peel.pre, i32 %t188) + br label %merge.t196 + +while.body.t119: ; preds = %while.cond.t118.preheader, %while.exit.t136 + %t123103 = phi i32 [ %t176, %while.exit.t136 ], [ %t123100, %while.cond.t118.preheader ] + %4 = sext i32 %t123103 to i64 + %t132 = getelementptr i32, ptr %a, i64 %4 + %t133 = load i32, ptr %t132, align 4 + br label %while.cond.t134 + +while.exit.t120: ; preds = %while.exit.t136, %while.cond.t118.preheader + %t180 = add nuw nsw i32 %storemerge17104, 1 + %t117 = icmp ult i32 %storemerge17104, 15 + br i1 %t117, label %while.cond.t118.preheader, label %while.body.t191.peel.next + +while.cond.t134: ; preds = %getNumPos.exit78, %while.body.t119 + %t15099 = phi i32 [ %t150129, %getNumPos.exit78 ], [ %t133, %while.body.t119 ] + br i1 %t333.i29, label %while.body.t29.i32, label %getNumPos.exit38.thread + +while.body.t29.i32: ; preds = %while.cond.t134, %while.body.t29.i32 + %t27_i.05.i33 = phi i32 [ %t37.i36, %while.body.t29.i32 ], [ 0, %while.cond.t134 ] + %t24.04.i34 = phi i32 [ %t35.i35, %while.body.t29.i32 ], [ %t15099, %while.cond.t134 ] + %t35.i35 = sdiv i32 %t24.04.i34, 16 + %t37.i36 = add nuw nsw i32 %t27_i.05.i33, 1 + %t33.i37 = icmp slt i32 %t37.i36, %bitround + br i1 %t33.i37, label %while.body.t29.i32, label %getNumPos.exit38 + +getNumPos.exit38: ; preds = %while.body.t29.i32 + %t39.i31 = srem i32 %t35.i35, 16 + %t141.not = icmp eq i32 %t39.i31, %storemerge17104 + br i1 %t141.not, label %while.exit.t136, label %while.body.t29.i42 + +getNumPos.exit38.thread: ; preds = %while.cond.t134 + %t39.i3186 = srem i32 %t15099, 16 + %t141.not87 = icmp eq i32 %t39.i3186, %storemerge17104 + br i1 %t141.not87, label %while.exit.t136, label %getNumPos.exit48.thread + +getNumPos.exit48.thread: ; preds = %getNumPos.exit38.thread + %5 = sext i32 %t39.i3186 to i64 + %t147125 = getelementptr i32, ptr %t43, i64 %5 + %t148126 = load i32, ptr %t147125, align 4 + %6 = sext i32 %t148126 to i64 + %t149127 = getelementptr i32, ptr %a, i64 %6 + %t150128 = load i32, ptr %t149127, align 4 + br label %getNumPos.exit68.thread + +while.body.t29.i42: ; preds = %getNumPos.exit38, %while.body.t29.i42 + %t27_i.05.i43 = phi i32 [ %t37.i46, %while.body.t29.i42 ], [ 0, %getNumPos.exit38 ] + %t24.04.i44 = phi i32 [ %t35.i45, %while.body.t29.i42 ], [ %t15099, %getNumPos.exit38 ] + %t35.i45 = sdiv i32 %t24.04.i44, 16 + %t37.i46 = add nuw nsw i32 %t27_i.05.i43, 1 + %t33.i47 = icmp slt i32 %t37.i46, %bitround + br i1 %t33.i47, label %while.body.t29.i42, label %getNumPos.exit48 + +getNumPos.exit48: ; preds = %while.body.t29.i42 + %.pre117 = srem i32 %t35.i45, 16 + %7 = sext i32 %.pre117 to i64 + %t147 = getelementptr i32, ptr %t43, i64 %7 + %t148 = load i32, ptr %t147, align 4 + %8 = sext i32 %t148 to i64 + %t149 = getelementptr i32, ptr %a, i64 %8 + %t150 = load i32, ptr %t149, align 4 + br i1 %t333.i29, label %while.body.t29.i52, label %getNumPos.exit68.thread + +while.body.t29.i52: ; preds = %getNumPos.exit48, %while.body.t29.i52 + %t27_i.05.i53 = phi i32 [ %t37.i56, %while.body.t29.i52 ], [ 0, %getNumPos.exit48 ] + %t24.04.i54 = phi i32 [ %t35.i55, %while.body.t29.i52 ], [ %t15099, %getNumPos.exit48 ] + %t35.i55 = sdiv i32 %t24.04.i54, 16 + %t37.i56 = add nuw nsw i32 %t27_i.05.i53, 1 + %t33.i57 = icmp slt i32 %t37.i56, %bitround + br i1 %t33.i57, label %while.body.t29.i52, label %while.body.t29.i62.preheader + +while.body.t29.i62.preheader: ; preds = %while.body.t29.i52 + %t39.i51 = srem i32 %t35.i55, 16 + %9 = sext i32 %t39.i51 to i64 + %t155 = getelementptr i32, ptr %t43, i64 %9 + %t156 = load i32, ptr %t155, align 4 + %10 = sext i32 %t156 to i64 + %t157 = getelementptr i32, ptr %a, i64 %10 + store i32 %t15099, ptr %t157, align 4 + br label %while.body.t29.i62 + +getNumPos.exit68.thread: ; preds = %getNumPos.exit48, %getNumPos.exit48.thread + %t150130 = phi i32 [ %t150128, %getNumPos.exit48.thread ], [ %t150, %getNumPos.exit48 ] + %t39.i51.c = srem i32 %t15099, 16 + %11 = sext i32 %t39.i51.c to i64 + %t155.c = getelementptr i32, ptr %t43, i64 %11 + %t156.c = load i32, ptr %t155.c, align 4 + %12 = sext i32 %t156.c to i64 + %t157.c = getelementptr i32, ptr %a, i64 %12 + store i32 %t15099, ptr %t157.c, align 4 + %t16191 = getelementptr i32, ptr %t43, i64 %11 + %t16292 = load i32, ptr %t16191, align 4 + br label %getNumPos.exit78 + +while.body.t29.i62: ; preds = %while.body.t29.i62.preheader, %while.body.t29.i62 + %t27_i.05.i63 = phi i32 [ %t37.i66, %while.body.t29.i62 ], [ 0, %while.body.t29.i62.preheader ] + %t24.04.i64 = phi i32 [ %t35.i65, %while.body.t29.i62 ], [ %t15099, %while.body.t29.i62.preheader ] + %t35.i65 = sdiv i32 %t24.04.i64, 16 + %t37.i66 = add nuw nsw i32 %t27_i.05.i63, 1 + %t33.i67 = icmp slt i32 %t37.i66, %bitround + br i1 %t33.i67, label %while.body.t29.i62, label %getNumPos.exit68 + +getNumPos.exit68: ; preds = %while.body.t29.i62 + %t39.i61 = srem i32 %t35.i65, 16 + %13 = sext i32 %t39.i61 to i64 + %t161 = getelementptr i32, ptr %t43, i64 %13 + %t162 = load i32, ptr %t161, align 4 + br label %while.body.t29.i72 + +while.body.t29.i72: ; preds = %getNumPos.exit68, %while.body.t29.i72 + %t27_i.05.i73 = phi i32 [ %t37.i76, %while.body.t29.i72 ], [ 0, %getNumPos.exit68 ] + %t24.04.i74 = phi i32 [ %t35.i75, %while.body.t29.i72 ], [ %t15099, %getNumPos.exit68 ] + %t35.i75 = sdiv i32 %t24.04.i74, 16 + %t37.i76 = add nuw nsw i32 %t27_i.05.i73, 1 + %t33.i77 = icmp slt i32 %t37.i76, %bitround + br i1 %t33.i77, label %while.body.t29.i72, label %getNumPos.exit78.loopexit + +getNumPos.exit78.loopexit: ; preds = %while.body.t29.i72 + %.pre118 = srem i32 %t35.i75, 16 + %.pre119 = sext i32 %.pre118 to i64 + br label %getNumPos.exit78 + +getNumPos.exit78: ; preds = %getNumPos.exit78.loopexit, %getNumPos.exit68.thread + %t150129 = phi i32 [ %t150, %getNumPos.exit78.loopexit ], [ %t150130, %getNumPos.exit68.thread ] + %.pre-phi120 = phi i64 [ %.pre119, %getNumPos.exit78.loopexit ], [ %11, %getNumPos.exit68.thread ] + %t16394.in = phi i32 [ %t162, %getNumPos.exit78.loopexit ], [ %t16292, %getNumPos.exit68.thread ] + %t16394 = add i32 %t16394.in, 1 + %t167 = getelementptr i32, ptr %t43, i64 %.pre-phi120 + store i32 %t16394, ptr %t167, align 4 + br label %while.cond.t134 + +while.exit.t136: ; preds = %getNumPos.exit38.thread, %getNumPos.exit38 + %t171 = load i32, ptr %t122, align 4 + %14 = sext i32 %t171 to i64 + %t172 = getelementptr i32, ptr %a, i64 %14 + store i32 %t15099, ptr %t172, align 4 + %t175 = load i32, ptr %t122, align 4 + %t176 = add i32 %t175, 1 + store i32 %t176, ptr %t122, align 4 + %t126 = load i32, ptr %t125, align 4 + %t127 = icmp slt i32 %t176, %t126 + br i1 %t127, label %while.body.t119, label %while.exit.t120 + +merge.t196: ; preds = %while.body.t191.peel.next, %merge.t196 + %storemerge18107 = phi i32 [ 1, %while.body.t191.peel.next ], [ %t224, %merge.t196 ] + %15 = zext nneg i32 %storemerge18107 to i64 + %gep106 = getelementptr i32, ptr %invariant.gep, i64 %15 + %t202 = load i32, ptr %gep106, align 4 + %t204 = getelementptr i32, ptr %t43, i64 %15 + store i32 %t202, ptr %t204, align 4 + %t209 = getelementptr i32, ptr %t48, i64 %15 + %t210 = load i32, ptr %t209, align 4 + %t211 = add i32 %t210, %t202 + %t213 = getelementptr i32, ptr %t46, i64 %15 + store i32 %t211, ptr %t213, align 4 + %t218.pre = load i32, ptr %t204, align 4 + tail call void @radixSort(i32 %t215, ptr %a, i32 %t218.pre, i32 %t211) + %t224 = add nuw nsw i32 %storemerge18107, 1 + %t194 = icmp ult i32 %storemerge18107, 15 + br i1 %t194, label %merge.t196, label %cleanup.t44, !llvm.loop !0 +} + +define noundef i32 @main() local_unnamed_addr { +entry: + %t231 = tail call i32 @getarray(ptr nonnull @a) + tail call void @starttime() + tail call void @radixSort(i32 8, ptr nonnull @a, i32 0, i32 %t231) + %ans.promoted = load i32, ptr @ans, align 4 + %t2427 = icmp sgt i32 %t231, 0 + br i1 %t2427, label %while.body.t238, label %while.exit.t239 + +while.body.t238: ; preds = %entry, %while.body.t238 + %t236_i.09 = phi i32 [ %t254, %while.body.t238 ], [ 0, %entry ] + %t25268 = phi i32 [ %t252, %while.body.t238 ], [ %ans.promoted, %entry ] + %0 = zext nneg i32 %t236_i.09 to i64 + %t246 = getelementptr [30000010 x i32], ptr @a, i64 0, i64 %0 + %t247 = load i32, ptr %t246, align 4 + %t249 = add nuw i32 %t236_i.09, 2 + %t250 = srem i32 %t247, %t249 + %t251 = mul i32 %t250, %t236_i.09 + %t252 = add i32 %t251, %t25268 + %t254 = add nuw nsw i32 %t236_i.09, 1 + %t242 = icmp slt i32 %t254, %t231 + br i1 %t242, label %while.body.t238, label %while.cond.t237.while.exit.t239_crit_edge + +while.cond.t237.while.exit.t239_crit_edge: ; preds = %while.body.t238 + store i32 %t252, ptr @ans, align 4 + br label %while.exit.t239 + +while.exit.t239: ; preds = %while.cond.t237.while.exit.t239_crit_edge, %entry + %t257 = phi i32 [ %t252, %while.cond.t237.while.exit.t239_crit_edge ], [ %ans.promoted, %entry ] + %t258 = icmp slt i32 %t257, 0 + br i1 %t258, label %then.t255, label %merge.t256 + +then.t255: ; preds = %while.exit.t239 + %t260 = sub i32 0, %t257 + store i32 %t260, ptr @ans, align 4 + br label %merge.t256 + +merge.t256: ; preds = %then.t255, %while.exit.t239 + tail call void @stoptime() + %t262 = load i32, ptr @ans, align 4 + tail call void @putint(i32 %t262) + tail call void @putch(i32 10) + ret i32 0 +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.smax.i32(i32, i32) #2 + +attributes #0 = { nofree norecurse nosync nounwind memory(argmem: read) } +attributes #1 = { nofree norecurse nosync nounwind memory(none) } +attributes #2 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } + +!0 = distinct !{!0, !1} +!1 = !{!"llvm.loop.peeled.count", i32 1} diff --git a/scripts/test_compiler.sh b/scripts/test_compiler.sh index 5045b56..2c9b672 100755 --- a/scripts/test_compiler.sh +++ b/scripts/test_compiler.sh @@ -5,6 +5,11 @@ ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" COMPILER="$ROOT_DIR/build/bin/compiler" TMP_DIR="$ROOT_DIR/build/test_compiler" TEST_DIRS=("$ROOT_DIR/test/test_case/functional" "$ROOT_DIR/test/test_case/performance") +CC_BIN="${CC:-cc}" +RUNTIME_SRC="$ROOT_DIR/sylib/sylib.c" +RUNTIME_OBJ="$TMP_DIR/sylib.o" +LLC_BIN="${LLC:-llc}" +CLANG_BIN="${CLANG:-clang}" if [[ ! -x "$COMPILER" ]]; then echo "未找到编译器: $COMPILER" @@ -14,6 +19,30 @@ fi mkdir -p "$TMP_DIR" +if ! command -v "$LLC_BIN" >/dev/null 2>&1; then + echo "未找到 llc: $LLC_BIN" + echo "请安装 LLVM,或通过 LLC 环境变量指定 llc 路径" + exit 1 +fi + +if ! command -v "$CLANG_BIN" >/dev/null 2>&1; then + echo "未找到 clang: $CLANG_BIN" + echo "请安装 Clang,或通过 CLANG 环境变量指定 clang 路径" + exit 1 +fi + +# 编译运行库(供链接生成的可执行文件) +runtime_ready=0 +if [[ -f "$RUNTIME_SRC" ]]; then + if "$CC_BIN" -c "$RUNTIME_SRC" -o "$RUNTIME_OBJ" >/dev/null 2>&1; then + runtime_ready=1 + else + echo "[WARN] 运行库编译失败,生成的可执行文件将不链接 sylib: $RUNTIME_SRC" + fi +else + echo "[WARN] 未找到运行库源码: $RUNTIME_SRC" +fi + ir_total=0 ir_pass=0 result_total=0 @@ -65,10 +94,15 @@ for test_dir in "${TEST_DIRS[@]}"; do continue fi - # 检查是否生成了有效的函数定义(在过滤后的内容中检查) - # 先过滤一下看看是否有define - filtered_content=$(sed -E '/^\[DEBUG\]|^SymbolTable::|^Check|^绑定|^保存|^dim_count:/d' "$raw_ll") - if ! echo "$filtered_content" | grep -qE '^define '; then + # 从混杂输出中提取 IR: + # - 顶层实体:define/declare/@global + # - 基本块标签 + # - 缩进的指令行 + # - 函数结束花括号 + grep -E '^(define |declare |@|[[:space:]]|})|^[A-Za-z_.$%][A-Za-z0-9_.$%]*:$' "$raw_ll" > "$ll_file" + + # 检查是否生成了有效函数定义 + if ! grep -qE '^define ' "$ll_file"; then echo " [IR] 失败: 未生成有效函数定义" ir_failures+=("$input: invalid IR output") # 失败:保留原始输出 @@ -76,17 +110,7 @@ for test_dir in "${TEST_DIRS[@]}"; do rm -f "$raw_ll" continue fi - - # 编译成功:过滤掉所有调试输出,只保留IR - # 过滤规则: - # 1. 以 [DEBUG] 开头的行 - # 2. SymbolTable:: 开头的行 - # 3. CheckLValue: 开头的行 - # 4. 绑定变量: 开头的行 - # 5. dim_count: 开头的行 - # 6. 空行(可选) - sed -E '/^(\[DEBUG|SymbolTable::|Check|绑定|保存|dim_)/d' "$raw_ll" > "$ll_file" - + # 可选:删除多余的空行 sed -i '/^$/N;/\n$/D' "$ll_file" @@ -96,30 +120,72 @@ for test_dir in "${TEST_DIRS[@]}"; do echo " [IR] 生成成功 (IR已保存到: $ll_file)" # 运行测试 + # 运行测试部分 if [[ -f "$expected_file" ]]; then result_total=$((result_total+1)) - # 运行LLVM IR + # 运行生成的可执行文件(优先链接运行库) run_status=0 + obj_file="$out_dir/$stem.o" + exe_file="$out_dir/$stem" + if ! "$LLC_BIN" -filetype=obj "$ll_file" -o "$obj_file" > "$stdout_file" 2>&1; then + echo " [RUN] llc 失败" + result_failures+=("$input: llc failed") + continue + fi + + if [[ $runtime_ready -eq 1 ]]; then + if ! "$CLANG_BIN" "$obj_file" "$RUNTIME_OBJ" -o "$exe_file" >> "$stdout_file" 2>&1; then + echo " [RUN] clang 链接失败" + result_failures+=("$input: clang link failed") + continue + fi + else + if ! "$CLANG_BIN" "$obj_file" -o "$exe_file" >> "$stdout_file" 2>&1; then + echo " [RUN] clang 链接失败" + result_failures+=("$input: clang link failed") + continue + fi + fi + if [[ -f "$stdin_file" ]]; then - lli "$ll_file" < "$stdin_file" > "$stdout_file" 2>&1 || run_status=$? + "$exe_file" < "$stdin_file" > "$stdout_file" 2>&1 || run_status=$? else - lli "$ll_file" > "$stdout_file" 2>&1 || run_status=$? + "$exe_file" > "$stdout_file" 2>&1 || run_status=$? fi - # 读取预期返回值 - expected=$(normalize_file "$expected_file") + # 读取预期文件内容 + expected_content=$(normalize_file "$expected_file") - # 比较返回值 - if [[ "$run_status" -eq "$expected" ]]; then - result_pass=$((result_pass+1)) - echo " [RUN] 返回值匹配: $run_status" - # 成功:保留已清理的.ll文件,删除输出文件 - rm -f "$stdout_file" + # 判断预期文件是只包含退出码,还是包含输出+退出码 + if [[ "$expected_content" =~ ^[0-9]+$ ]]; then + # 只包含退出码 + expected=$expected_content + if [[ "$run_status" -eq "$expected" ]]; then + result_pass=$((result_pass+1)) + echo " [RUN] 返回值匹配: $run_status" + rm -f "$stdout_file" + else + echo " [RUN] 返回值不匹配: got $run_status, expected $expected" + result_failures+=("$input: exit code mismatch (got $run_status, expected $expected)") + fi else - echo " [RUN] 返回值不匹配: got $run_status, expected $expected" - result_failures+=("$input: exit code mismatch (got $run_status, expected $expected)") - # 失败:.ll文件已经保留,输出文件也保留用于调试 + # 包含输出和退出码(最后一行是退出码) + expected_output=$(head -n -1 <<< "$expected_content") + expected_exit=$(tail -n 1 <<< "$expected_content") + actual_output=$(cat "$stdout_file") + + if [[ "$run_status" -eq "$expected_exit" ]] && [[ "$actual_output" == "$expected_output" ]]; then + result_pass=$((result_pass+1)) + echo " [RUN] 成功: 退出码和输出都匹配" + rm -f "$stdout_file" + else + echo " [RUN] 不匹配: 退出码 got $run_status, expected $expected_exit" + if [[ "$actual_output" != "$expected_output" ]]; then + echo " 输出不匹配" + fi + result_failures+=("$input: mismatch") + fi fi else echo " [RUN] 未找到预期返回值文件 $expected_file,跳过结果验证" diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index 2059cde..bcd356b 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -1,9 +1,8 @@ -// 管理基础类型、整型常量池和临时名生成。 // ir/IR.cpp - #include "ir/IR.h" -#include // for memcpy +#include #include +#include namespace ir { @@ -17,9 +16,7 @@ ConstantInt* Context::GetConstInt(int v) { return inserted->second.get(); } -// 新增:获取浮点常量 ConstantFloat* Context::GetConstFloat(float v) { - // 使用浮点数的二进制表示作为键,避免精度问题 uint32_t key; std::memcpy(&key, &v, sizeof(float)); @@ -35,16 +32,68 @@ ConstantFloat* Context::GetConstFloat(float v) { return ptr; } -// 新增:创建数组常量 ConstantArray* Context::GetConstArray(std::shared_ptr ty, std::vector elements) { + // 验证数组常量 + size_t expected_size = ty->GetElementCount(); + if (elements.size() != expected_size) { + // 如果元素数量不匹配,可能需要补零或报错 + // 这里根据需求处理 + if (elements.size() < expected_size) { + // 补零 + auto elem_type = ty->GetElementType(); + while (elements.size() < expected_size) { + if (elem_type->IsInt32()) { + elements.push_back(GetConstInt(0)); + } else if (elem_type->IsFloat()) { + elements.push_back(GetConstFloat(0.0f)); + } + } + } else { + throw std::runtime_error("Array constant size mismatch"); + } + } + + // 构建缓存键 + struct ArrayKey { + std::shared_ptr type; + std::vector elements; + + bool operator==(const ArrayKey& other) const { + if (type != other.type) return false; + if (elements.size() != other.elements.size()) return false; + for (size_t i = 0; i < elements.size(); ++i) { + if (elements[i] != other.elements[i]) return false; + } + return true; + } + }; + + struct ArrayKeyHash { + size_t operator()(const ArrayKey& key) const { + size_t hash = std::hash{}(key.type.get()); + for (auto* elem : key.elements) { + hash ^= std::hash{}(elem) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + } + return hash; + } + }; + + // 使用静态缓存(需要作为成员变量) + static std::unordered_map, ArrayKeyHash> cache; + + ArrayKey key{ty, elements}; + auto it = cache.find(key); + if (it != cache.end()) { + return it->second.get(); + } + auto constant = std::make_unique(ty, std::move(elements)); auto* ptr = constant.get(); - const_arrays_.push_back(std::move(constant)); + cache[std::move(key)] = std::move(constant); return ptr; } -// 新增:获取零常量 ConstantZero* Context::GetZeroConstant(std::shared_ptr ty) { auto it = zero_constants_.find(ty.get()); if (it != zero_constants_.end()) { @@ -57,7 +106,6 @@ ConstantZero* Context::GetZeroConstant(std::shared_ptr ty) { return ptr; } -// 新增:获取聚合类型的零常量 ConstantAggregateZero* Context::GetAggregateZero(std::shared_ptr ty) { auto it = aggregate_zeros_.find(ty.get()); if (it != aggregate_zeros_.end()) { @@ -76,5 +124,4 @@ std::string Context::NextTemp() { return oss.str(); } - } // namespace ir \ No newline at end of file diff --git a/src/ir/GlobalValue.cpp b/src/ir/GlobalValue.cpp index 5b68e0b..24686b1 100644 --- a/src/ir/GlobalValue.cpp +++ b/src/ir/GlobalValue.cpp @@ -1,9 +1,30 @@ // ir/GlobalValue.cpp #include "ir/IR.h" + #include namespace ir { +namespace { + +ConstantValue* GetScalarZeroConstant(const Type& type) { + if (type.IsInt32()) { + static ConstantInt* zero_i32 = new ConstantInt(Type::GetInt32Type(), 0); + return zero_i32; + } + if (type.IsFloat()) { + static ConstantFloat* zero_f32 = new ConstantFloat(Type::GetFloatType(), 0.0f); + return zero_f32; + } + if (type.IsInt1()) { + static ConstantInt* zero_i1 = new ConstantInt(Type::GetInt1Type(), 0); + return zero_i1; + } + return nullptr; +} + +} // namespace + GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) : User(std::move(ty), std::move(name)) {} @@ -13,42 +34,10 @@ void GlobalValue::SetInitializer(ConstantValue* init) { } // 获取实际的值类型(用于类型检查) - std::shared_ptr value_type; - - // 如果当前类型是指针,获取指向的值类型 - if (GetType()->IsPtrInt32()) { - value_type = Type::GetInt32Type(); - } else if (GetType()->IsPtrFloat()) { - value_type = Type::GetFloatType(); - } else if (GetType()->IsPtrInt1()) { - value_type = Type::GetInt1Type(); - } else { - // 非指针类型:直接使用当前类型 - value_type = GetType(); - } + std::shared_ptr value_type = GetValueType(); // 类型检查 - bool type_match = false; - - // 检查标量类型 - if (value_type->IsInt32() && init->GetType()->IsInt32()) { - type_match = true; - } else if (value_type->IsFloat() && init->GetType()->IsFloat()) { - type_match = true; - } else if (value_type->IsInt1() && init->GetType()->IsInt1()) { - type_match = true; - } - // 检查数组类型:允许用单个标量初始化整个数组 - else if (value_type->IsArray()) { - auto* array_ty = static_cast(value_type.get()); - auto* elem_type = array_ty->GetElementType().get(); - - if (elem_type->IsInt32() && init->GetType()->IsInt32()) { - type_match = true; - } else if (elem_type->IsFloat() && init->GetType()->IsFloat()) { - type_match = true; - } - } + bool type_match = CheckTypeCompatibility(value_type, init); if (!type_match) { throw std::runtime_error("GlobalValue::SetInitializer: type mismatch"); @@ -60,23 +49,14 @@ void GlobalValue::SetInitializer(ConstantValue* init) { void GlobalValue::SetInitializer(const std::vector& init) { if (init.empty()) { + initializer_.clear(); return; } // 获取实际的值类型 - std::shared_ptr value_type; + std::shared_ptr value_type = GetValueType(); - if (GetType()->IsPtrInt32()) { - value_type = Type::GetInt32Type(); - } else if (GetType()->IsPtrFloat()) { - value_type = Type::GetFloatType(); - } else if (GetType()->IsPtrInt1()) { - value_type = Type::GetInt1Type(); - } else { - value_type = GetType(); - } - - // 检查类型 + // 类型检查 if (value_type->IsArray()) { auto* array_ty = static_cast(value_type.get()); size_t array_size = array_ty->GetElementCount(); @@ -87,16 +67,23 @@ void GlobalValue::SetInitializer(const std::vector& init) { // 检查每个初始化值的类型 auto* elem_type = array_ty->GetElementType().get(); - for (auto* elem : init) { + for (size_t i = 0; i < init.size(); ++i) { + auto* elem = init[i]; + if (!elem) { + throw std::runtime_error("GlobalValue::SetInitializer: null initializer at index " + std::to_string(i)); + } + bool elem_match = false; if (elem_type->IsInt32() && elem->GetType()->IsInt32()) { elem_match = true; } else if (elem_type->IsFloat() && elem->GetType()->IsFloat()) { elem_match = true; + } else if (elem_type->IsInt1() && elem->GetType()->IsInt1()) { + elem_match = true; } if (!elem_match) { - throw std::runtime_error("GlobalValue::SetInitializer: element type mismatch"); + throw std::runtime_error("GlobalValue::SetInitializer: element type mismatch at index " + std::to_string(i)); } } } @@ -105,6 +92,10 @@ void GlobalValue::SetInitializer(const std::vector& init) { throw std::runtime_error("GlobalValue::SetInitializer: scalar requires exactly one initializer"); } + if (!init[0]) { + throw std::runtime_error("GlobalValue::SetInitializer: null initializer"); + } + if ((value_type->IsInt32() && !init[0]->GetType()->IsInt32()) || (value_type->IsFloat() && !init[0]->GetType()->IsFloat()) || (value_type->IsInt1() && !init[0]->GetType()->IsInt1())) { @@ -118,4 +109,87 @@ void GlobalValue::SetInitializer(const std::vector& init) { initializer_ = init; } +// 辅助方法:获取实际的值类型(处理指针包装) +std::shared_ptr GlobalValue::GetValueType() const { + if (GetType()->IsPtrInt32()) { + return Type::GetInt32Type(); + } else if (GetType()->IsPtrFloat()) { + return Type::GetFloatType(); + } else if (GetType()->IsPtrInt1()) { + return Type::GetInt1Type(); + } + return GetType(); +} + +// 辅助方法:检查类型兼容性 +bool GlobalValue::CheckTypeCompatibility(std::shared_ptr value_type, + ConstantValue* init) const { + // 检查标量类型 + if (value_type->IsInt32() && init->GetType()->IsInt32()) { + return true; + } else if (value_type->IsFloat() && init->GetType()->IsFloat()) { + return true; + } else if (value_type->IsInt1() && init->GetType()->IsInt1()) { + return true; + } + // 检查数组类型:允许用单个标量初始化整个数组 + else if (value_type->IsArray()) { + auto* array_ty = static_cast(value_type.get()); + auto* elem_type = array_ty->GetElementType().get(); + + if (elem_type->IsInt32() && init->GetType()->IsInt32()) { + return true; + } else if (elem_type->IsFloat() && init->GetType()->IsFloat()) { + return true; + } else if (elem_type->IsInt1() && init->GetType()->IsInt1()) { + return true; + } + // 也可以允许 ConstantArray 作为初始化器 + else if (init->GetType()->IsArray()) { + auto* init_array = static_cast(init); + return init_array->IsValid(); + } + } + // 检查指针类型(用于数组参数) + else if (value_type->IsPtrInt32() && init->GetType()->IsInt32()) { + return true; + } else if (value_type->IsPtrFloat() && init->GetType()->IsFloat()) { + return true; + } + + return false; +} + +// 添加获取数组元素的便捷方法 +ConstantValue* GlobalValue::GetArrayElement(size_t index) const { + if (!GetType()->IsArray()) { + return nullptr; + } + + auto* array_ty = dynamic_cast(GetType().get()); + if (!array_ty) { + return nullptr; + } + if (index >= static_cast(array_ty->GetElementCount())) { + return nullptr; + } + if (index >= initializer_.size()) { + return GetScalarZeroConstant(*array_ty->GetElementType()); + } + return initializer_[index]; +} + +// 添加获取数组元素数量的方法 +size_t GlobalValue::GetArraySize() const { + if (!IsArrayConstant()) { + return 0; + } + return initializer_.size(); +} + +// 添加判断是否为数组常量的方法 +bool GlobalValue::IsArrayConstant() const { + return GetType()->IsArray() && !initializer_.empty(); +} + } // namespace ir \ No newline at end of file diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 18b7169..9b7c545 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -119,6 +119,17 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, return CreateBinary(Opcode::Add, lhs, rhs, name); } +AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!ty) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateAlloca 缺少类型")); + } + return insert_block_->Append(ty, name); +} + AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); @@ -190,18 +201,21 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { } } else if (ptr_ty->IsPtrFloat()) { if (!val_ty->IsFloat()) { - throw std::runtime_error(FormatError("ir", "存储类型不匹配:期望 float")); + throw std::runtime_error( + FormatError("ir", "存储类型不匹配:期望 float, 实际 kind=" + + std::to_string(static_cast(val_ty->GetKind())))); } } else if (ptr_ty->IsArray()) { - // 数组存储:检查元素类型 - auto* array_ty = dynamic_cast(ptr_ty.get()); - if (array_ty) { - auto elem_ty = array_ty->GetElementType(); - if (elem_ty->IsInt32() && !val_ty->IsInt32()) { - throw std::runtime_error(FormatError("ir", "数组元素类型不匹配:期望 int32")); - } else if (elem_ty->IsFloat() && !val_ty->IsFloat()) { - throw std::runtime_error(FormatError("ir", "数组元素类型不匹配:期望 float")); - } + // 数组存储支持两种形式: + // 1. 标量元素写入(通常配合 GEP 后落到元素指针,不会走到这里) + // 2. 聚合数组整体写入,例如 `store [16 x i32] zeroinitializer, [16 x i32]* %arr` + if (!val_ty->IsArray()) { + throw std::runtime_error( + FormatError("ir", "数组地址仅支持聚合数组整体存储")); + } + if (val_ty->GetKind() != ptr_ty->GetKind()) { + throw std::runtime_error( + FormatError("ir", "聚合数组存储类型不匹配")); } } @@ -212,10 +226,6 @@ ReturnInst* IRBuilder::CreateRet(Value* v) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - if (!v) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateRet 缺少返回值")); - } return insert_block_->Append(Type::GetVoidType(), v); } @@ -386,7 +396,8 @@ ZExtInst* IRBuilder::CreateZExt(Value* value, std::shared_ptr target_ty, FormatError("ir", "ZExt 目标类型必须是整数类型")); } - return insert_block_->Append(value, target_ty, name); + const std::string inst_name = name.empty() ? ctx_.NextTemp() : name; + return insert_block_->Append(value, target_ty, inst_name); } // 创建截断指令 @@ -416,7 +427,8 @@ TruncInst* IRBuilder::CreateTrunc(Value* value, std::shared_ptr target_ty, FormatError("ir", "Trunc 目标类型必须是整数类型")); } - return insert_block_->Append(value, target_ty, name); + const std::string inst_name = name.empty() ? ctx_.NextTemp() : name; + return insert_block_->Append(value, target_ty, inst_name); } // 便捷方法:i1 转 i32 @@ -466,7 +478,9 @@ BinaryInst* IRBuilder::CreateAnd(Value* lhs, Value* rhs, const std::string& name if (!rhs) { throw std::runtime_error(FormatError("ir", "IRBuilder::CreateAnd 缺少 rhs")); } - return insert_block_->Append(Opcode::And, Type::GetInt32Type(), lhs, rhs, name); + auto result_ty = lhs->GetType()->IsInt1() ? Type::GetInt1Type() + : Type::GetInt32Type(); + return insert_block_->Append(Opcode::And, result_ty, lhs, rhs, name); } BinaryInst* IRBuilder::CreateOr(Value* lhs, Value* rhs, const std::string& name) { @@ -479,7 +493,9 @@ BinaryInst* IRBuilder::CreateOr(Value* lhs, Value* rhs, const std::string& name) if (!rhs) { throw std::runtime_error(FormatError("ir", "IRBuilder::CreateOr 缺少 rhs")); } - return insert_block_->Append(Opcode::Or, Type::GetInt32Type(), lhs, rhs, name); + auto result_ty = lhs->GetType()->IsInt1() ? Type::GetInt1Type() + : Type::GetInt32Type(); + return insert_block_->Append(Opcode::Or, result_ty, lhs, rhs, name); } IcmpInst* IRBuilder::CreateNot(Value* val, const std::string& name) { @@ -489,7 +505,12 @@ IcmpInst* IRBuilder::CreateNot(Value* val, const std::string& name) { if (!val) { throw std::runtime_error(FormatError("ir", "IRBuilder::CreateNot 缺少 operand")); } - auto zero = CreateConstInt(0); + if (val->GetType()->IsInt1()) { + auto* ext = CreateZExtI1ToI32(val, ""); + auto* zero = CreateConstInt(0); + return CreateICmpEQ(ext, zero, name); + } + auto* zero = CreateConstInt(0); return CreateICmpEQ(val, zero, name); } @@ -511,8 +532,29 @@ GEPInst* IRBuilder::CreateGEP(Value* base, } } - // GEP返回指针类型,假设与base类型相同 - return insert_block_->Append(base->GetType(), base, indices, name); + // 结果类型推断: + // - 对 i32*/float* 基址,结果仍分别为 i32*/float* + // - 对数组基址,按多索引向下剥离元素类型;若到达标量则返回对应标量指针 + // (本项目没有“指向数组的指针类型”,未完全剥离时退回数组类型) + std::shared_ptr result_ty = base->GetType(); + if (base->GetType()->IsPtrInt32()) { + result_ty = Type::GetPtrInt32Type(); + } else if (base->GetType()->IsPtrFloat()) { + result_ty = Type::GetPtrFloatType(); + } else if (base->GetType()->IsArray()) { + std::shared_ptr cur = base->GetType(); + for (size_t i = 1; i < indices.size(); ++i) { + auto* at = dynamic_cast(cur.get()); + if (!at) break; + cur = at->GetElementType(); + } + + if (cur->IsInt32()) result_ty = Type::GetPtrInt32Type(); + else if (cur->IsFloat()) result_ty = Type::GetPtrFloatType(); + else result_ty = cur; + } + + return insert_block_->Append(result_ty, base, indices, name); } @@ -609,12 +651,20 @@ FcmpInst* IRBuilder::CreateFCmpOGE(Value* lhs, Value* rhs, const std::string& na // 类型转换 SIToFPInst* IRBuilder::CreateSIToFP(Value* value, std::shared_ptr target_ty, const std::string& name) { - return insert_block_->Append(value, target_ty, name); + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + const std::string inst_name = name.empty() ? ctx_.NextTemp() : name; + return insert_block_->Append(value, target_ty, inst_name); } FPToSIInst* IRBuilder::CreateFPToSI(Value* value, std::shared_ptr target_ty, const std::string& name) { - return insert_block_->Append(value, target_ty, name); + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + const std::string inst_name = name.empty() ? ctx_.NextTemp() : name; + return insert_block_->Append(value, target_ty, inst_name); } } // namespace ir diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index c4b0624..a3a6278 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -4,6 +4,9 @@ #include "ir/IR.h" +#include +#include +#include #include #include #include @@ -12,7 +15,109 @@ namespace ir { -static const char* TypeToString(const Type& ty) { +static std::string TypeToString(const Type& ty); + +static std::string ArrayTypeToStringFrom(const Type& base_ty, + const std::vector& dims, + size_t begin) { + std::string s = TypeToString(base_ty); + for (size_t i = dims.size(); i-- > begin;) { + s = "[" + std::to_string(dims[i]) + " x " + s + "]"; + } + return s; +} + +static bool IsZeroConstant(const ConstantValue* value) { + if (!value) { + return true; + } + if (auto* ci = dynamic_cast(value)) { + return ci->GetValue() == 0; + } + if (auto* cf = dynamic_cast(value)) { + return cf->GetValue() == 0.0f; + } + if (dynamic_cast(value) || + dynamic_cast(value)) { + return true; + } + if (auto* arr = dynamic_cast(value)) { + for (auto* elem : arr->GetElements()) { + if (!IsZeroConstant(elem)) { + return false; + } + } + return true; + } + return false; +} + +static size_t AggregateSpan(const std::vector& dims, size_t level) { + size_t span = 1; + for (size_t i = level; i < dims.size(); ++i) { + span *= static_cast(dims[i]); + } + return span; +} + +static bool IsZeroRange(const std::vector& init, + size_t begin, + size_t count) { + for (size_t i = 0; i < count; ++i) { + const size_t index = begin + i; + if (index >= init.size()) { + continue; + } + if (!IsZeroConstant(init[index])) { + return false; + } + } + return true; +} + +static void PrintFlatArrayBody(std::ostream& os, + const Type& base_ty, + const std::vector& dims, + size_t level, + const std::vector& init, + size_t& flat_index) { + const size_t span = AggregateSpan(dims, level); + if (IsZeroRange(init, flat_index, span)) { + os << "zeroinitializer"; + flat_index += span; + return; + } + + os << "["; + for (int i = 0; i < dims[level]; ++i) { + if (i > 0) os << ", "; + + if (level + 1 < dims.size()) { + os << ArrayTypeToStringFrom(base_ty, dims, level + 1) << " "; + PrintFlatArrayBody(os, base_ty, dims, level + 1, init, flat_index); + continue; + } + + os << TypeToString(base_ty) << " "; + if (flat_index < init.size() && init[flat_index]) { + if (auto* ci = dynamic_cast(init[flat_index])) { + os << ci->GetValue(); + } else if (auto* cf = dynamic_cast(init[flat_index])) { + os << cf->GetValue(); + } else if (IsZeroConstant(init[flat_index])) { + os << "0"; + } else { + os << "0"; + } + } else { + os << "0"; + } + ++flat_index; + } + os << "]"; +} + +static std::string TypeToString(const Type& ty) { switch (ty.GetKind()) { case Type::Kind::Void: return "void"; case Type::Kind::Int32: return "i32"; @@ -20,10 +125,23 @@ static const char* TypeToString(const Type& ty) { case Type::Kind::PtrInt32: return "i32*"; case Type::Kind::PtrFloat: return "float*"; case Type::Kind::Label: return "label"; - case Type::Kind::Array: return "array"; case Type::Kind::Function: return "function"; case Type::Kind::Int1: return "i1"; case Type::Kind::PtrInt1: return "i1*"; + case Type::Kind::Array: { + // 打印数组类型为 LLVM 风格,如 [4 x [2 x i32]] + auto* at = dynamic_cast(&ty); + if (!at) return "array"; + // 递归构建类型字符串 + std::string elem = TypeToString(*at->GetElementType()); + const auto& dims = at->GetDimensions(); + // 从外到内构建 + std::string s = elem; + for (auto it = dims.rbegin(); it != dims.rend(); ++it) { + s = "[" + std::to_string(*it) + " x " + s + "]"; + } + return s; + } default: return "unknown"; } throw std::runtime_error(FormatError("ir", "未知类型")); @@ -54,9 +172,9 @@ static const char* OpcodeToString(Opcode op) { case Opcode::Icmp: return "icmp"; case Opcode::Div: - return "div"; + return "sdiv"; case Opcode::Mod: - return "mod"; + return "srem"; case Opcode::ZExt: return "zext"; case Opcode::Trunc: @@ -82,12 +200,29 @@ static const char* OpcodeToString(Opcode op) { return "?"; } +// 将 float 值转为 LLVM IR 接受的 64-bit 十六进制浮点格式 +static std::string FloatToLLVMHex(float f) { + double d = static_cast(f); + uint64_t bits; + memcpy(&bits, &d, sizeof(bits)); + char buf[20]; + snprintf(buf, sizeof(buf), "0x%016llX", (unsigned long long)bits); + return buf; +} + static std::string ValueToString(const Value* v) { + if (!v) { + return ""; + } + if (dynamic_cast(v) || + dynamic_cast(v)) { + return "zeroinitializer"; + } if (auto* ci = dynamic_cast(v)) { return std::to_string(ci->GetValue()); } - if (!v) { - return ""; + if (auto* cf = dynamic_cast(v)) { + return FloatToLLVMHex(cf->GetValue()); } const auto& name = v->GetName(); if (name.empty()) { @@ -102,27 +237,107 @@ static std::string ValueToString(const Value* v) { return "%" + name; } +static std::string MemoryTypeToString(const Type& ty) { + std::string text = TypeToString(ty); + if (ty.IsArray()) { + text += "*"; + } + return text; +} + void IRPrinter::Print(const Module& module, std::ostream& os) { for (const auto& global : module.GetGlobals()) { if (!global) continue; - os << "@" << global->GetName() << " = global "; + os << "@" << global->GetName() << " = " + << (global->IsConstant() ? "constant " : "global "); + if (global->GetType()->IsPtrInt32()) { - os << "i32 0\n"; - } else if (global->GetType()->IsPtrFloat()) { - os << "float 0.0\n"; - } else { - os << TypeToString(*global->GetType()) << " zeroinitializer\n"; + os << "i32 "; + if (global->HasInitializer()) { + auto* ci = dynamic_cast(global->GetInitializer().front()); + os << (ci ? ci->GetValue() : 0); + } else { + os << "0"; + } + os << "\n"; + continue; + } + + if (global->GetType()->IsPtrFloat()) { + os << "float "; + if (global->HasInitializer()) { + auto* cf = dynamic_cast(global->GetInitializer().front()); + os << (cf ? ValueToString(cf) : FloatToLLVMHex(0.0f)); + } else { + os << FloatToLLVMHex(0.0f); + } + os << "\n"; + continue; + } + + if (global->GetType()->IsArray()) { + auto* at = dynamic_cast(global->GetType().get()); + os << TypeToString(*global->GetType()) << " "; + if (!at || !global->HasInitializer() || + IsZeroRange(global->GetInitializer(), 0, AggregateSpan(at->GetDimensions(), 0))) { + os << "zeroinitializer\n"; + continue; + } + + size_t flat_index = 0; + PrintFlatArrayBody(os, + *at->GetElementType(), + at->GetDimensions(), + 0, + global->GetInitializer(), + flat_index); + os << "\n"; + continue; } + + os << TypeToString(*global->GetType()) << " zeroinitializer\n"; } - for (const auto& func : module.GetFunctions()) { - auto* func_ty = static_cast(func->GetType().get()); - os << "define " << TypeToString(*func_ty->GetReturnType()) << " @" << func->GetName() << "("; + + auto print_func_params = [&](const Function* func, + const FunctionType* func_ty) { bool first = true; - for (const auto& arg : func->GetArguments()) { + if (!func->GetArguments().empty()) { + for (const auto& arg : func->GetArguments()) { + if (!first) os << ", "; + first = false; + os << TypeToString(*arg->GetType()) << " %" << arg->GetName(); + } + return; + } + + for (const auto& pty : func_ty->GetParamTypes()) { if (!first) os << ", "; first = false; - os << TypeToString(*arg->GetType()) << " %" << arg->GetName(); + os << TypeToString(*pty); } + }; + + auto is_declaration_only = [](const Function* func) { + const auto& blocks = func->GetBlocks(); + if (blocks.size() != 1) return false; + const auto& only = blocks.front(); + if (!only) return false; + return only->GetInstructions().empty(); + }; + + for (const auto& func : module.GetFunctions()) { + auto* func_ty = static_cast(func->GetType().get()); + if (is_declaration_only(func.get())) { + os << "declare " << TypeToString(*func_ty->GetReturnType()) << " @" + << func->GetName() << "("; + print_func_params(func.get(), func_ty); + os << ")\n"; + continue; + } + + os << "define " << TypeToString(*func_ty->GetReturnType()) << " @" + << func->GetName() << "("; + print_func_params(func.get(), func_ty); os << ") {\n"; for (const auto& bb : func->GetBlocks()) { if (!bb) { @@ -139,7 +354,11 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { case Opcode::Mod: case Opcode::And: case Opcode::Not: - case Opcode::Or: + case Opcode::Or: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: { auto* bin = static_cast(inst); os << " " << bin->GetName() << " = " @@ -166,7 +385,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { auto* load = static_cast(inst); os << " " << load->GetName() << " = load " << TypeToString(*load->GetType()) << ", " - << TypeToString(*load->GetPtr()->GetType()) << " " + << MemoryTypeToString(*load->GetPtr()->GetType()) << " " << ValueToString(load->GetPtr()) << "\n"; break; } @@ -174,21 +393,29 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { auto* store = static_cast(inst); os << " store " << TypeToString(*store->GetValue()->GetType()) << " " << ValueToString(store->GetValue()) - << ", " << TypeToString(*store->GetPtr()->GetType()) << " " + << ", " << MemoryTypeToString(*store->GetPtr()->GetType()) << " " << ValueToString(store->GetPtr()) << "\n"; break; } case Opcode::Ret: { auto* ret = static_cast(inst); - os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " - << ValueToString(ret->GetValue()) << "\n"; + if (!ret->GetValue()) { + os << " ret void\n"; + } else { + os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " + << ValueToString(ret->GetValue()) << "\n"; + } break; } // CallInst类在 include/ir/IR.h 中定义 case Opcode::Call: { auto* call = static_cast(inst); - os << " " << call->GetName() << " = call " - << TypeToString(*call->GetType()) << " @" << call->GetCallee()->GetName() << "("; + os << " "; + if (!call->GetType()->IsVoid()) { + os << call->GetName() << " = "; + } + os << "call " << TypeToString(*call->GetType()) << " @" + << call->GetCallee()->GetName() << "("; bool first = true; for (auto* arg : call->GetArgs()) { if (!first) os << ", "; @@ -248,16 +475,81 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { break; } case Opcode::GEP:{ - // 简化打印:只打印基本信息和操作数数量 + // 打印为类似 LLVM 的 getelementptr 形式: + // getelementptr , , i32 , i32 , ... os << " " << inst->GetName() << " = getelementptr "; - os << TypeToString(*inst->GetType()) << " ("; - for (size_t i = 0; i < inst->GetNumOperands(); ++i) { - if (i > 0) os << ", "; - os << ValueToString(inst->GetOperand(i)); + // 基地址类型使用第一个操作数的类型 + Value* base = inst->GetOperand(0); + // GEP 的第一个类型参数应是基址指向的元素类型(pointee)。 + std::string elem_ty; + if (base->GetType()->IsPtrInt32()) elem_ty = "i32"; + else if (base->GetType()->IsPtrFloat()) elem_ty = "float"; + else if (base->GetType()->IsArray()) elem_ty = TypeToString(*base->GetType()); + else elem_ty = TypeToString(*inst->GetType()); + + std::string base_ty = TypeToString(*base->GetType()); + if (base->GetType()->IsArray()) { + base_ty += "*"; } - os << ")\n"; + + os << elem_ty << ", " << base_ty << " " << ValueToString(base); + + // 后续操作数为索引,按照 i32 打印 + // 特殊处理:如果 base 是标量指针(i32*/float*)且第一个索引是常量 0 + // 且后续还有索引,则丢弃第一个 0(对 T* 来说多余且会导致无效 IR)。 + size_t start_idx = 1; + if ((base->GetType()->IsPtrInt32() || base->GetType()->IsPtrFloat()) && + inst->GetNumOperands() >= 3) { + // 检查第一个索引是否为常量 0 + auto* first_idx = inst->GetOperand(1); + if (auto* ci = dynamic_cast(first_idx)) { + if (ci->GetValue() == 0) { + start_idx = 2; // 跳过第一个 0 + } + } + } + for (size_t i = start_idx; i < inst->GetNumOperands(); ++i) { + os << ", i32 " << ValueToString(inst->GetOperand(i)); + } + os << "\n"; + break; + } + case Opcode::FCmp: { + auto* fcmp = static_cast(inst); + os << " " << fcmp->GetName() << " = fcmp "; + switch (fcmp->GetPredicate()) { + case FcmpInst::Predicate::OEQ: os << "oeq"; break; + case FcmpInst::Predicate::ONE: os << "one"; break; + case FcmpInst::Predicate::OLT: os << "olt"; break; + case FcmpInst::Predicate::OLE: os << "ole"; break; + case FcmpInst::Predicate::OGT: os << "ogt"; break; + case FcmpInst::Predicate::OGE: os << "oge"; break; + default: os << "oeq"; break; + } + os << " " << TypeToString(*fcmp->GetLhs()->GetType()) + << " " << ValueToString(fcmp->GetLhs()) + << ", " << ValueToString(fcmp->GetRhs()) << "\n"; break; } + + case Opcode::SIToFP: { + auto* sitofp = static_cast(inst); + os << " " << sitofp->GetName() << " = sitofp " + << TypeToString(*sitofp->GetValue()->GetType()) << " " + << ValueToString(sitofp->GetValue()) << " to " + << TypeToString(*sitofp->GetType()) << "\n"; + break; + } + + case Opcode::FPToSI: { + auto* fptosi = static_cast(inst); + os << " " << fptosi->GetName() << " = fptosi " + << TypeToString(*fptosi->GetValue()->GetType()) << " " + << ValueToString(fptosi->GetValue()) << " to " + << TypeToString(*fptosi->GetType()) << "\n"; + break; + } + default: { // 处理未知操作码 os << " ; 未知指令: " << OpcodeToString(inst->GetOpcode()) << "\n"; @@ -286,10 +578,10 @@ void IRPrinter::PrintConstant(const ConstantValue* constant, std::ostream& os) { } os << "]"; } - else if (auto* zero = dynamic_cast(constant)) { + else if (dynamic_cast(constant)) { os << "zero"; } - else if (auto* agg_zero = dynamic_cast(constant)) { + else if (dynamic_cast(constant)) { os << "zeroinitializer"; } } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 76ea812..d0f280a 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -73,6 +73,10 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, case Opcode::Mod: case Opcode::And: case Opcode::Or: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: // 有效的二元操作符 break; case Opcode::Not: @@ -96,21 +100,26 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, throw std::runtime_error(FormatError("ir", "BinaryInst 操作数类型不匹配")); } + bool is_logical = (op == Opcode::And || op == Opcode::Or); + // 检查操作数类型是否支持 - if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsFloat()) { - throw std::runtime_error( - FormatError("ir", "BinaryInst 只支持 int32 和 float 类型")); + if (is_logical) { + if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsInt1()) { + throw std::runtime_error( + FormatError("ir", "逻辑运算仅支持 i32/i1")); + } + } else { + if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsFloat()) { + throw std::runtime_error( + FormatError("ir", "BinaryInst 只支持 int32 和 float 类型")); + } } - // 对于算术运算,结果类型应与操作数类型相同 - - bool is_logical = (op == Opcode::And || op == Opcode::Or); - if (is_logical) { - // 比较和逻辑运算的结果应该是整数类型 - if (!type_->IsInt32()) { + // 逻辑运算结果类型应与操作数一致(i1 或 i32)。 + if (type_->GetKind() != lhs->GetType()->GetKind()) { throw std::runtime_error( - FormatError("ir", "比较和逻辑运算的结果类型必须是 int32")); + FormatError("ir", "逻辑运算结果类型与操作数类型不匹配")); } } else { // 算术运算的结果类型应与操作数类型相同 @@ -130,21 +139,27 @@ Value* BinaryInst::GetRhs() const { return GetOperand(1); } ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) : Instruction(Opcode::Ret, std::move(void_ty), "") { - if (!val) { - throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值")); - } if (!type_ || !type_->IsVoid()) { throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void")); } - AddOperand(val); + if (val) { + AddOperand(val); + } } -Value* ReturnInst::GetValue() const { return GetOperand(0); } +Value* ReturnInst::GetValue() const { + if (GetNumOperands() == 0) { + return nullptr; + } + return GetOperand(0); +} AllocaInst::AllocaInst(std::shared_ptr ptr_ty, std::string name) : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { - if (!type_ || !type_->IsPtrInt32()) { - throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*")); + if (!type_ || + (!type_->IsPtrInt32() && !type_->IsPtrFloat() && !type_->IsArray())) { + throw std::runtime_error( + FormatError("ir", "AllocaInst 仅支持 i32* / float* / array")); } } @@ -153,12 +168,15 @@ LoadInst::LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name) if (!ptr) { throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr")); } - if (!type_ || !type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32")); + if (!type_ || (!type_->IsInt32() && !type_->IsFloat() && !type_->IsInt1())) { + throw std::runtime_error( + FormatError("ir", "LoadInst 仅支持加载 i32/float/i1")); } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { + if (!ptr->GetType() || + (!ptr->GetType()->IsPtrInt32() && !ptr->GetType()->IsPtrFloat() && + !ptr->GetType()->IsArray() && !ptr->GetType()->IsPtrInt1())) { throw std::runtime_error( - FormatError("ir", "LoadInst 当前只支持从 i32* 加载")); + FormatError("ir", "LoadInst 仅支持从指针或数组地址加载")); } AddOperand(ptr); } @@ -176,13 +194,25 @@ StoreInst::StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr) if (!type_ || !type_->IsVoid()) { throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void")); } - if (!val->GetType() || !val->GetType()->IsInt32()) { - throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32")); - } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { + if (!val->GetType() || + (!val->GetType()->IsInt32() && !val->GetType()->IsFloat() && + !val->GetType()->IsInt1() && !val->GetType()->IsArray())) { throw std::runtime_error( - FormatError("ir", "StoreInst 当前只支持写入 i32*")); + FormatError("ir", "StoreInst 仅支持存储 i32/float/i1/array")); } + if (!ptr->GetType() || + (!ptr->GetType()->IsPtrInt32() && !ptr->GetType()->IsPtrFloat() && + !ptr->GetType()->IsArray() && !ptr->GetType()->IsPtrInt1())) { + throw std::runtime_error(FormatError("ir", "StoreInst 仅支持写入指针或数组地址")); + } + if (ptr->GetType()->IsArray()) { + if (!val->GetType()->IsArray() || + val->GetType()->GetKind() != ptr->GetType()->GetKind()) { + throw std::runtime_error( + FormatError("ir", "StoreInst 聚合存储要求 value/ptr 具有相同数组类型")); + } + } + AddOperand(val); AddOperand(ptr); } diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index bc324f1..e2143bc 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -1,5 +1,7 @@ +// IRGenDecl.cpp #include "irgen/IRGen.h" +#include #include #include "SysYParser.h" @@ -8,27 +10,95 @@ namespace { -// 使用 LValContext 而不是 LValueContext -std::string GetLValueName(SysYParser::LValContext& lvalue) { - if (!lvalue.Ident()) { - throw std::runtime_error(FormatError("irgen", "非法左值")); +constexpr size_t kLocalArrayHeapThresholdBytes = 1024 * 1024; + +size_t GetArrayElementByteWidth(const ir::ArrayType& array_ty) { + auto* elem_ty = array_ty.GetElementType().get(); + if (elem_ty->IsInt1()) { + return 1; } - return lvalue.Ident()->getText(); + return 4; } -int TryGetConstInt(SysYParser::ConstExpContext* ctx) { - // 这里是一个简化的版本,实际上应该调用语义分析的常量求值 - // 暂时假设所有常量表达式都是整数常量 - // 实际实现需要更复杂的逻辑 - - // 简化为返回10 - return 10; +size_t GetArrayStorageBytes(const ir::ArrayType& array_ty) { + return static_cast(array_ty.GetElementCount()) * + GetArrayElementByteWidth(array_ty); } -} // namespace +bool IsZeroIRValue(const ir::Value* value) { + if (!value) { + return true; + } + if (auto* ci = dynamic_cast(value)) { + return ci->GetValue() == 0; + } + if (auto* cf = dynamic_cast(value)) { + return cf->GetValue() == 0.0f; + } + if (dynamic_cast(value) || + dynamic_cast(value)) { + return true; + } + return false; +} -// 注意:visitBlock 已经在 IRGenFunc.cpp 中实现,这里不要重复定义 +bool IsZeroConstantValue(const ir::ConstantValue* value) { + return IsZeroIRValue(value); +} + +std::vector TrimTrailingZeroConstants( + std::vector values) { + while (!values.empty() && IsZeroConstantValue(values.back())) { + values.pop_back(); + } + return values; +} + +std::vector BuildArrayIndices(ir::IRBuilder& builder, + const std::vector& dims, + size_t flat_idx) { + std::vector indices; + indices.reserve(dims.size() + 1); + indices.push_back(builder.CreateConstInt(0)); + + size_t rem = flat_idx; + for (size_t i = 0; i < dims.size(); ++i) { + size_t stride = 1; + for (size_t j = i + 1; j < dims.size(); ++j) { + stride *= static_cast(dims[j]); + } + const int idx = static_cast(rem / stride); + rem %= stride; + indices.push_back(builder.CreateConstInt(idx)); + } + return indices; +} +std::vector BuildZeroIndices(ir::IRBuilder& builder, + size_t dims_count) { + std::vector indices; + indices.reserve(dims_count + 1); + for (size_t i = 0; i <= dims_count; ++i) { + indices.push_back(builder.CreateConstInt(0)); + } + return indices; +} + +std::string MakeStaticArrayName(const ir::Function& func, + const std::string& var_name, + std::string suffix) { + for (char& ch : suffix) { + if (ch == '%') { + ch = '_'; + } + } + return "__static_array." + func.GetName() + "." + var_name + "." + suffix; +} + + +} // namespace + +// visitDecl: 处理声明 std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { std::cerr << "[DEBUG] visitDecl: 开始处理声明" << std::endl; if (!ctx) { @@ -38,13 +108,8 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { // 处理 varDecl if (auto* varDecl = ctx->varDecl()) { std::cerr << "[DEBUG] visitDecl: 处理变量声明" << std::endl; - // 检查类型 - if (varDecl->bType() && varDecl->bType()->Int()) { - for (auto* varDef : varDecl->varDef()) { - varDef->accept(this); - } - } else { - throw std::runtime_error(FormatError("irgen", "当前仅支持 int 类型变量")); + for (auto* varDef : varDecl->varDef()) { + varDef->accept(this); } } @@ -52,355 +117,291 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { if (ctx->constDecl()) { std::cerr << "[DEBUG] visitDecl: 处理常量声明" << std::endl; auto* constDecl = ctx->constDecl(); - - if (constDecl->bType() && constDecl->bType()->Int()) { - for (auto* constDef : constDecl->constDef()) { - constDef->accept(this); - } - } else if (constDecl->bType() && constDecl->bType()->Float()) { - throw std::runtime_error(FormatError("irgen", "float常量暂未实现")); - } else { - throw std::runtime_error(FormatError("irgen", "未知的常量类型")); + for (auto* constDef : constDecl->constDef()) { + constDef->accept(this); } } + std::cerr << "[DEBUG] visitDecl: 声明处理完成" << std::endl; return {}; } -// 在 IRGenDecl.cpp 中确保有这个函数 +// visitConstDecl: 处理常量声明 std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) { std::cerr << "[DEBUG] visitConstDecl: 开始处理常量声明" << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法常量声明")); } - std::cerr << "[DEBUG] visitConstDecl: processing constant declaration" << std::endl; - - // 检查类型 - if (ctx->bType()) { - if (ctx->bType()->Int()) { - // int 类型常量 - for (auto* constDef : ctx->constDef()) { - if (constDef) { - constDef->accept(this); - } - } - } else if (ctx->bType()->Float()) { - // float 类型常量(暂不支持) - throw std::runtime_error(FormatError("irgen", "float常量暂未实现")); - } else { - throw std::runtime_error(FormatError("irgen", "未知的常量类型")); + for (auto* constDef : ctx->constDef()) { + if (constDef) { + constDef->accept(this); } - } else { - throw std::runtime_error(FormatError("irgen", "常量声明缺少类型")); } std::cerr << "[DEBUG] visitConstDecl: 常量声明处理完成" << std::endl; return {}; } +// visitConstDef: 处理常量定义 - 从符号表获取常量值 std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { + std::cerr << "[DEBUG] visitConstDef: 开始处理常量定义" << std::endl; if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "非法常量定义")); } std::string const_name = ctx->Ident()->getText(); - // 检查是否为数组 - bool is_array = !ctx->constExp().empty(); - - // 获取常量类型(int 或 float) - bool is_float = false; - auto* constDecl = dynamic_cast(ctx->parent); - if (constDecl && constDecl->bType()) { - if (constDecl->bType()->Float()) { - is_float = true; - std::cerr << "[DEBUG] visitConstDef: 常量 " << const_name << " 是 float 类型" << std::endl; - } + // 从符号表获取常量符号 + const Symbol* sym = symbol_table_.lookupByConstDef(ctx); + if (!sym || sym->kind != SymbolKind::Constant) { + throw std::runtime_error(FormatError("irgen", "常量符号未找到: " + const_name)); } - if (is_array) { - // 数组常量处理 - std::vector dimensions; - for (auto* const_exp : ctx->constExp()) { - int dim_size = TryEvaluateConstInt(const_exp); - if (dim_size <= 0) dim_size = 1; - dimensions.push_back(dim_size); - } - - // 创建数组类型 - std::shared_ptr element_type; - if (is_float) { - element_type = ir::Type::GetFloatType(); - } else { - element_type = ir::Type::GetInt32Type(); + std::cerr << "[DEBUG] visitConstDef: 从符号表获取常量 " << const_name + << ", is_array_const: " << sym->IsArrayConstant() << std::endl; + + // 根据符号表中的常量值创建 IR 常量 + if (sym->IsArrayConstant()) { + auto* array_ty = dynamic_cast(sym->type.get()); + if (!array_ty) { + throw std::runtime_error(FormatError("irgen", "数组类型转换失败")); } - - auto array_type = ir::Type::GetArrayType(element_type, dimensions); - ir::GlobalValue* global_array = module_.CreateGlobal(const_name, array_type); - - // 处理初始化值 + + const bool is_global_scope = (func_ == nullptr) || (sym->scope_level == 0); + const bool use_heap_storage = !is_global_scope && + (current_function_is_recursive_ || + GetArrayStorageBytes(*array_ty) > kLocalArrayHeapThresholdBytes); + + // 先把符号表中的扁平初始化值转换为 IR 常量值 std::vector init_consts; - if (auto* const_init_val = ctx->constInitVal()) { - auto result = const_init_val->accept(this); - if (result.has_value()) { - try { - auto init_vec = std::any_cast>(result); - for (auto* val : init_vec) { - if (is_float) { - if (auto* const_float = dynamic_cast(val)) { - init_consts.push_back(const_float); - } else if (auto* const_int = dynamic_cast(val)) { - // 整数转浮点 - float float_val = static_cast(const_int->GetValue()); - init_consts.push_back(builder_.CreateConstFloat(float_val)); - } else { - init_consts.push_back(builder_.CreateConstFloat(0.0f)); - } - } else { - if (auto* const_int = dynamic_cast(val)) { - init_consts.push_back(const_int); - } else if (auto* const_float = dynamic_cast(val)) { - // 浮点转整数 - int int_val = static_cast(const_float->GetValue()); - init_consts.push_back(builder_.CreateConstInt(int_val)); - } else { - init_consts.push_back(builder_.CreateConstInt(0)); - } - } - } - } catch (const std::bad_any_cast&) { - try { - ir::Value* single_val = std::any_cast(result); - if (is_float) { - if (auto* const_float = dynamic_cast(single_val)) { - init_consts.push_back(const_float); - } else { - init_consts.push_back(builder_.CreateConstFloat(0.0f)); - } - } else { - if (auto* const_int = dynamic_cast(single_val)) { - init_consts.push_back(const_int); - } else { - init_consts.push_back(builder_.CreateConstInt(0)); - } - } - } catch (...) {} - } - } - } - - // 补0 - int total_size = 1; - for (int dim : dimensions) total_size *= dim; - while (init_consts.size() < static_cast(total_size)) { - if (is_float) { - init_consts.push_back(builder_.CreateConstFloat(0.0f)); - } else { - init_consts.push_back(builder_.CreateConstInt(0)); + init_consts.reserve(sym->GetArraySize()); + for (size_t i = 0; i < sym->GetArraySize(); ++i) { + auto elem = sym->GetArrayElement(i); + if (array_ty->GetElementType()->IsInt32()) { + init_consts.push_back(builder_.CreateConstInt(elem.i32)); + } else if (array_ty->GetElementType()->IsFloat()) { + init_consts.push_back(builder_.CreateConstFloat(elem.f32)); } } - - global_array->SetInitializer(init_consts); - global_array->SetConstant(true); - - const_storage_map_[ctx] = global_array; - const_global_map_[const_name] = global_array; - - } else { - // 标量常量处理 - if (!ctx->constInitVal()) { - throw std::runtime_error(FormatError("irgen", "常量缺少初始值")); - } - - ir::ConstantValue* const_value = nullptr; - auto* const_init_val = ctx->constInitVal(); - - if (const_init_val->constExp()) { - // 对于常量表达式,我们可以尝试直接求值 - if (is_float) { - // TODO: 实现浮点常量表达式的求值 - const_value = module_.GetContext().GetConstFloat(0.0f); - } else { - int value = TryEvaluateConstInt(const_init_val->constExp()); - const_value = module_.GetContext().GetConstInt(value); + + init_consts = TrimTrailingZeroConstants(std::move(init_consts)); + + if (is_global_scope) { + // 全局 const 数组:全局存储,并标记为 constant + ir::GlobalValue* global_array = module_.CreateGlobal(const_name, sym->type); + if (!init_consts.empty()) { + global_array->SetInitializer(init_consts); } + global_array->SetConstant(true); + const_global_map_[const_name] = global_array; + const_storage_map_[ctx] = global_array; } else { - if (is_float) { - const_value = module_.GetContext().GetConstFloat(0.0f); + ir::Value* array_slot = nullptr; + if (use_heap_storage) { + const bool is_float_array = array_ty->GetElementType()->IsFloat(); + const std::string alloc_name = is_float_array ? "sysy_alloc_f32" : "sysy_alloc_i32"; + const std::string free_name = is_float_array ? "sysy_free_f32" : "sysy_free_i32"; + ir::Function* alloc_func = module_.FindFunction(alloc_name); + if (!alloc_func) alloc_func = CreateRuntimeFunctionDecl(alloc_name); + ir::Function* free_func = module_.FindFunction(free_name); + if (!free_func) free_func = CreateRuntimeFunctionDecl(free_name); + array_slot = builder_.CreateCall( + alloc_func, + {builder_.CreateConstInt(static_cast(array_ty->GetElementCount()))}, + module_.GetContext().NextTemp()); + heap_local_array_names_.insert(const_name); + RegisterCleanup(free_func, array_slot); } else { - const_value = module_.GetContext().GetConstInt(0); + array_slot = CreateEntryAlloca(sym->type, + module_.GetContext().NextTemp() + "_" + const_name); + } + + const auto& dims = array_ty->GetDimensions(); + const size_t total_size = array_ty->GetElementCount(); + + if (init_consts.empty()) { + if (use_heap_storage) { + const std::string zero_name = array_ty->GetElementType()->IsFloat() + ? "sysy_zero_f32" + : "sysy_zero_i32"; + ir::Function* zero_func = module_.FindFunction(zero_name); + if (!zero_func) zero_func = CreateRuntimeFunctionDecl(zero_name); + builder_.CreateCall(zero_func, + {array_slot, builder_.CreateConstInt(static_cast(total_size))}, + module_.GetContext().NextTemp()); + } else { + builder_.CreateStore(module_.GetContext().GetAggregateZero(sym->type), array_slot); + } } + + for (size_t i = 0; i < total_size; ++i) { + ir::Value* init = nullptr; + if (i < init_consts.size()) { + init = init_consts[i]; + } else if (array_ty->GetElementType()->IsFloat()) { + init = builder_.CreateConstFloat(0.0f); + } else { + init = builder_.CreateConstInt(0); + } + + ir::Value* elem_ptr = nullptr; + if (use_heap_storage) { + if (IsZeroIRValue(init)) { + continue; + } + elem_ptr = builder_.CreateGEP( + array_slot, {builder_.CreateConstInt(static_cast(i))}, + module_.GetContext().NextTemp()); + } else { + elem_ptr = builder_.CreateGEP( + array_slot, BuildArrayIndices(builder_, dims, i), + module_.GetContext().NextTemp()); + } + builder_.CreateStore(init, elem_ptr); + } + + local_var_map_[const_name] = array_slot; + const_storage_map_[ctx] = array_slot; + } + + } else if (sym->IsScalarConstant()) { + // 标量常量:存储常量值 + ir::ConstantValue* const_value = nullptr; + if (sym->type->IsInt32()) { + const_value = builder_.CreateConstInt(sym->GetIntConstant()); + std::cerr << "[DEBUG] visitConstDef: 整型常量 " << const_name + << " = " << sym->GetIntConstant() << std::endl; + } else if (sym->type->IsFloat()) { + const_value = builder_.CreateConstFloat(sym->GetFloatConstant()); + std::cerr << "[DEBUG] visitConstDef: 浮点常量 " << const_name + << " = " << sym->GetFloatConstant() << std::endl; } - // 存储常量值到映射 const_value_map_[const_name] = const_value; + const_storage_map_[ctx] = const_value; } return {}; } - -// TO DO:visitVarDef来区分全局和局部变量,并且正确处理数组变量的定义和初始化 +// visitVarDef: 处理变量定义 - 从符号表获取类型信息 std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { std::cerr << "[DEBUG] visitVarDef: 开始处理变量定义" << std::endl; - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少变量定义")); - } - - if (!ctx->Ident()) { - throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("irgen", "非法变量定义")); } std::string varName = ctx->Ident()->getText(); std::cerr << "[DEBUG] visitVarDef: 变量名称: " << varName << std::endl; - // 防止同一个变量被多次分配存储空间。 + + // 防止重复分配 if (storage_map_.find(ctx) != storage_map_.end()) { throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位: " + varName)); } - bool is_array = !ctx->constExp().empty(); - std::cerr << "[DEBUG] visitVarDef: 是否为数组: " << (is_array ? "是" : "否") << std::endl; + // 从符号表获取变量信息 + const Symbol* sym = symbol_table_.lookupByVarDef(ctx); + if (!sym) { + throw std::runtime_error(FormatError("irgen", "变量符号未找到: " + varName)); + } + + std::cerr << "[DEBUG] visitVarDef: 变量类型: " + << (sym->type->IsInt32() ? "int" : + sym->type->IsFloat() ? "float" : + sym->type->IsArray() ? "array" : "unknown") << std::endl; - // 使用 func_ 来判断:func_ == nullptr 表示在全局作用域 + // 根据作用域处理 if (func_ == nullptr) { std::cerr << "[DEBUG] visitVarDef: 处理全局变量" << std::endl; - // 全局变量处理 - return HandleGlobalVariable(ctx, varName, is_array); + return HandleGlobalVariable(ctx, varName, sym); } else { std::cerr << "[DEBUG] visitVarDef: 处理局部变量" << std::endl; - // 局部变量处理 - return HandleLocalVariable(ctx, varName, is_array); + return HandleLocalVariable(ctx, varName, sym); } } +// HandleGlobalVariable: 处理全局变量 std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx, - const std::string& varName, - bool is_array) { + const std::string& varName, + const Symbol* sym) { std::cerr << "[DEBUG] HandleGlobalVariable: 开始处理全局变量 " << varName << std::endl; - // 获取变量类型(int 或 float) - bool is_float = false; - auto* varDecl = dynamic_cast(ctx->parent); - if (varDecl && varDecl->bType()) { - if (varDecl->bType()->Float()) { - is_float = true; - std::cerr << "[DEBUG] HandleGlobalVariable: 变量 " << varName << " 是 float 类型" << std::endl; - } else if (varDecl->bType()->Int()) { - std::cerr << "[DEBUG] HandleGlobalVariable: 变量 " << varName << " 是 int 类型" << std::endl; + if (!sym) { + throw std::runtime_error(FormatError("irgen", "符号表信息缺失: " + varName)); + } + + bool is_array = sym->type->IsArray(); + bool is_float = sym->type->IsFloat(); + if (is_array) { + if (auto* array_ty = dynamic_cast(sym->type.get())) { + is_float = array_ty->GetElementType()->IsFloat(); } } if (is_array) { - // 全局数组变量 - int total_size = 1; - std::vector dimensions; - - // 计算总大小 - for (auto* const_exp : ctx->constExp()) { - int dim_size = TryEvaluateConstInt(const_exp); - if (dim_size <= 0) { - dim_size = 1; - std::cerr << "[WARNING] HandleGlobalVariable: 无法确定数组维度大小,使用1" << std::endl; - } - dimensions.push_back(dim_size); - total_size *= dim_size; + // 从符号表获取数组类型和维度 + auto* array_ty = dynamic_cast(sym->type.get()); + if (!array_ty) { + throw std::runtime_error(FormatError("irgen", "数组类型转换失败")); } - std::cerr << "[DEBUG] HandleGlobalVariable: 数组总大小: " << total_size << std::endl; + const auto& dimensions = array_ty->GetDimensions(); + size_t total_size = array_ty->GetElementCount(); - // 创建数组类型 - std::shared_ptr element_type; - if (is_float) { - element_type = ir::Type::GetFloatType(); - } else { - element_type = ir::Type::GetInt32Type(); - } + std::cerr << "[DEBUG] HandleGlobalVariable: 全局数组 " << varName << " 维度: "; + for (int d : dimensions) std::cerr << d << " "; + std::cerr << ", 总大小: " << total_size << std::endl; - auto array_type = ir::Type::GetArrayType(element_type, dimensions); - ir::GlobalValue* global_array = module_.CreateGlobal(varName, array_type); - std::cerr << "[DEBUG] HandleGlobalVariable: 创建全局数组: " << varName << std::endl; + // 创建全局数组 + ir::GlobalValue* global_array = module_.CreateGlobal(varName, sym->type); - // 处理初始化值 + // 处理初始化值(使用带维度感知的展平) std::vector init_consts; if (auto* initVal = ctx->initVal()) { std::cerr << "[DEBUG] HandleGlobalVariable: 处理初始化值" << std::endl; - auto result = initVal->accept(this); - if (result.has_value()) { - try { - auto init_vec = std::any_cast>(result); - std::cerr << "[DEBUG] HandleGlobalVariable: 获取到初始化值列表, 大小: " << init_vec.size() << std::endl; - for (auto* val : init_vec) { - if (auto* const_int = dynamic_cast(val)) { - init_consts.push_back(const_int); - } else if (auto* const_float = dynamic_cast(val)) { - init_consts.push_back(const_float); - } else { - // 非常量表达式,使用0 - if (is_float) { - init_consts.push_back(builder_.CreateConstFloat(0.0f)); - } else { - init_consts.push_back(builder_.CreateConstInt(0)); - } - } + // 全局变量的初始化必须是常量表达式(语义检查已保证) + std::vector flat_vals = FlattenInitVal( + initVal, dimensions, is_float); + for (auto* val : flat_vals) { + if (is_float) { + if (auto* cf = dynamic_cast(val)) { + init_consts.push_back(cf); + } else if (auto* ci = dynamic_cast(val)) { + init_consts.push_back(builder_.CreateConstFloat(static_cast(ci->GetValue()))); + } else { + init_consts.push_back(builder_.CreateConstFloat(0.0f)); } - } catch (const std::bad_any_cast&) { - try { - ir::Value* single_val = std::any_cast(result); - std::cerr << "[DEBUG] HandleGlobalVariable: 获取到单个初始化值" << std::endl; - if (auto* const_int = dynamic_cast(single_val)) { - init_consts.push_back(const_int); - } else if (auto* const_float = dynamic_cast(single_val)) { - init_consts.push_back(const_float); - } else { - if (is_float) { - init_consts.push_back(builder_.CreateConstFloat(0.0f)); - } else { - init_consts.push_back(builder_.CreateConstInt(0)); - } - } - } catch (const std::bad_any_cast&) { - std::cerr << "[WARNING] HandleGlobalVariable: 无法解析数组初始化值" << std::endl; + } else { + if (auto* ci = dynamic_cast(val)) { + init_consts.push_back(ci); + } else if (auto* cf = dynamic_cast(val)) { + init_consts.push_back(builder_.CreateConstInt(static_cast(cf->GetValue()))); + } else { + init_consts.push_back(builder_.CreateConstInt(0)); } } } } + + init_consts = TrimTrailingZeroConstants(std::move(init_consts)); - // 如果初始化值不足,补0 - while (init_consts.size() < static_cast(total_size)) { - if (is_float) { - init_consts.push_back(builder_.CreateConstFloat(0.0f)); - } else { - init_consts.push_back(builder_.CreateConstInt(0)); - } - } - - // 设置全局数组的初始化器 + // 设置初始化器 if (!init_consts.empty()) { global_array->SetInitializer(init_consts); - std::cerr << "[DEBUG] HandleGlobalVariable: 设置全局数组初始化器" << std::endl; } - // 存储全局变量引用 storage_map_[ctx] = global_array; global_map_[varName] = global_array; } else { // 全局标量变量 - std::shared_ptr var_type; - if (is_float) { - var_type = ir::Type::GetFloatType(); - } else { - var_type = ir::Type::GetInt32Type(); - } - + std::shared_ptr var_type = sym->type; ir::GlobalValue* global_var = module_.CreateGlobal(varName, var_type); - std::cerr << "[DEBUG] HandleGlobalVariable: 创建全局标量变量: " << varName << std::endl; // 处理初始化值 ir::ConstantValue* init_value = nullptr; if (auto* initVal = ctx->initVal()) { - std::cerr << "[DEBUG] HandleGlobalVariable: 处理标量初始化值" << std::endl; auto result = initVal->accept(this); if (result.has_value()) { try { @@ -409,191 +410,171 @@ std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx, if (auto* const_float = dynamic_cast(val)) { init_value = const_float; } else if (auto* const_int = dynamic_cast(val)) { - // 整数转浮点 - float float_val = static_cast(const_int->GetValue()); - init_value = builder_.CreateConstFloat(float_val); - } else { - init_value = builder_.CreateConstFloat(0.0f); + init_value = builder_.CreateConstFloat(static_cast(const_int->GetValue())); } } else { if (auto* const_int = dynamic_cast(val)) { init_value = const_int; } else if (auto* const_float = dynamic_cast(val)) { - // 浮点转整数 - int int_val = static_cast(const_float->GetValue()); - init_value = builder_.CreateConstInt(int_val); - } else { - init_value = builder_.CreateConstInt(0); + init_value = builder_.CreateConstInt(static_cast(const_float->GetValue())); } } } catch (const std::bad_any_cast&) { - if (is_float) { - init_value = builder_.CreateConstFloat(0.0f); - } else { - init_value = builder_.CreateConstInt(0); - } + // 使用默认值 } } } - // 如果没有初始化值,默认初始化 + //正确:只在没有初始化值时才设置默认值 if (!init_value) { if (is_float) { - init_value = builder_.CreateConstFloat(0.0f); + init_value = builder_.CreateConstFloat(0.0f); } else { - init_value = builder_.CreateConstInt(0); + init_value = builder_.CreateConstInt(0); } } - // 设置全局变量的初始化器 global_var->SetInitializer(init_value); - - // 存储全局变量引用 storage_map_[ctx] = global_var; global_map_[varName] = global_var; } + std::cerr << "[DEBUG] HandleGlobalVariable: 全局变量处理完成" << std::endl; return {}; } -// 修改 HandleLocalVariable 函数中的数组处理部分 - +// HandleLocalVariable: 处理局部变量 std::any IRGenImpl::HandleLocalVariable(SysYParser::VarDefContext* ctx, - const std::string& varName, - bool is_array) { + const std::string& varName, + const Symbol* sym) { std::cerr << "[DEBUG] HandleLocalVariable: 开始处理局部变量 " << varName << std::endl; - - // 获取变量类型 - bool is_float = false; - auto* varDecl = dynamic_cast(ctx->parent); - if (varDecl && varDecl->bType()) { - if (varDecl->bType()->Float()) { - is_float = true; - std::cerr << "[DEBUG] HandleLocalVariable: 变量 " << varName << " 是 float 类型" << std::endl; - } + + if (!sym) { + throw std::runtime_error(FormatError("irgen", "符号表信息缺失: " + varName)); } - + + bool is_array = sym->type->IsArray(); + bool is_float = sym->type->IsFloat(); if (is_array) { - // 局部数组变量 - int total_size = 1; - std::vector dimensions; - - // 获取数组维度 - for (auto* const_exp : ctx->constExp()) { - try { - int dim_size = TryEvaluateConstInt(const_exp); - if (dim_size <= 0) dim_size = 1; - dimensions.push_back(dim_size); - total_size *= dim_size; - } catch (const std::exception& e) { - std::cerr << "[WARNING] HandleLocalVariable: 无法获取数组维度,使用维度1" << std::endl; - dimensions.push_back(1); - total_size *= 1; - } + if (auto* array_ty = dynamic_cast(sym->type.get())) { + is_float = array_ty->GetElementType()->IsFloat(); } - - // 创建数组类型 - std::shared_ptr elem_type; - if (is_float) { - elem_type = ir::Type::GetFloatType(); - } else { - elem_type = ir::Type::GetInt32Type(); + } + + if (is_array) { + // 从符号表获取数组信息 + auto* array_ty = dynamic_cast(sym->type.get()); + if (!array_ty) { + throw std::runtime_error(FormatError("irgen", "数组类型转换失败")); } - // 修正:使用完整的维度列表创建数组类型 - auto array_type = ir::Type::GetArrayType(elem_type, dimensions); - - // 分配数组内存 - 为每个元素创建独立的 alloca - std::vector element_slots; - for (int i = 0; i < total_size; i++) { - ir::AllocaInst* slot; - if (is_float) { - slot = builder_.CreateAllocaFloat( - module_.GetContext().NextTemp() + "_" + varName + "_" + std::to_string(i)); - } else { - slot = builder_.CreateAllocaI32( - module_.GetContext().NextTemp() + "_" + varName + "_" + std::to_string(i)); + size_t total_size = array_ty->GetElementCount(); + const size_t total_bytes = GetArrayStorageBytes(*array_ty); + const bool use_heap_storage = + current_function_is_recursive_ || total_bytes > kLocalArrayHeapThresholdBytes; + + std::cerr << "[DEBUG] HandleLocalVariable: 局部数组 " << varName + << " 总大小: " << total_size << std::endl; + + ir::Value* array_slot = nullptr; + if (use_heap_storage) { + const std::string alloc_name = is_float ? "sysy_alloc_f32" : "sysy_alloc_i32"; + const std::string free_name = is_float ? "sysy_free_f32" : "sysy_free_i32"; + ir::Function* alloc_func = module_.FindFunction(alloc_name); + if (!alloc_func) { + alloc_func = CreateRuntimeFunctionDecl(alloc_name); + } + ir::Function* free_func = module_.FindFunction(free_name); + if (!free_func) { + free_func = CreateRuntimeFunctionDecl(free_name); } - element_slots.push_back(slot); + array_slot = builder_.CreateCall( + alloc_func, + {builder_.CreateConstInt(static_cast(total_size))}, + module_.GetContext().NextTemp()); + heap_local_array_names_.insert(varName); + RegisterCleanup(free_func, array_slot); + } else { + array_slot = CreateEntryAlloca( + sym->type, module_.GetContext().NextTemp() + "_" + varName); } - // 存储第一个元素的地址作为数组的基地址 - storage_map_[ctx] = element_slots[0]; - local_var_map_[varName] = element_slots[0]; + const auto& dims = array_ty->GetDimensions(); + + storage_map_[ctx] = array_slot; + local_var_map_[varName] = array_slot; // 处理初始化 if (auto* initVal = ctx->initVal()) { - auto result = initVal->accept(this); - if (result.has_value()) { - try { - std::vector init_values = - std::any_cast>(result); - - // 初始化数组元素 - for (size_t i = 0; i < init_values.size() && i < static_cast(total_size); i++) { - builder_.CreateStore(init_values[i], element_slots[i]); - } - - // 剩余元素初始化为0 - for (size_t i = init_values.size(); i < static_cast(total_size); i++) { - if (is_float) { - builder_.CreateStore(builder_.CreateConstFloat(0.0f), element_slots[i]); - } else { - builder_.CreateStore(builder_.CreateConstInt(0), element_slots[i]); - } - } - } catch (const std::bad_any_cast&) { - try { - ir::Value* single_value = std::any_cast(result); - // 只初始化第一个元素 - builder_.CreateStore(single_value, element_slots[0]); - - // 其他元素初始化为0 - for (int i = 1; i < total_size; i++) { - if (is_float) { - builder_.CreateStore(builder_.CreateConstFloat(0.0f), element_slots[i]); - } else { - builder_.CreateStore(builder_.CreateConstInt(0), element_slots[i]); - } - } - } catch (const std::bad_any_cast&) { - // 全部初始化为0 - for (int i = 0; i < total_size; i++) { - if (is_float) { - builder_.CreateStore(builder_.CreateConstFloat(0.0f), element_slots[i]); - } else { - builder_.CreateStore(builder_.CreateConstInt(0), element_slots[i]); - } - } - } + std::vector init_values = FlattenInitVal( + initVal, array_ty->GetDimensions(), is_float); + + bool is_all_zero_init = true; + for (auto* value : init_values) { + if (!IsZeroIRValue(value)) { + is_all_zero_init = false; + break; } - } else { - // 没有初始化值,全部初始化为0 - for (int i = 0; i < total_size; i++) { - if (is_float) { - builder_.CreateStore(builder_.CreateConstFloat(0.0f), element_slots[i]); - } else { - builder_.CreateStore(builder_.CreateConstInt(0), element_slots[i]); - } + } + + if (is_all_zero_init && !use_heap_storage) { + builder_.CreateStore(module_.GetContext().GetAggregateZero(sym->type), + array_slot); + std::cerr << "[DEBUG] HandleLocalVariable: aggregate zeroinitializer store for " + << varName << std::endl; + return {}; + } + + if (use_heap_storage) { + const std::string zero_name = is_float ? "sysy_zero_f32" : "sysy_zero_i32"; + ir::Function* zero_func = module_.FindFunction(zero_name); + if (!zero_func) { + zero_func = CreateRuntimeFunctionDecl(zero_name); + } + builder_.CreateCall(zero_func, + {array_slot, builder_.CreateConstInt(static_cast(total_size))}, + module_.GetContext().NextTemp()); + if (is_all_zero_init) { + return {}; } } - } else { - // 无初始化,所有元素初始化为0 - for (int i = 0; i < total_size; i++) { - if (is_float) { - builder_.CreateStore(builder_.CreateConstFloat(0.0f), element_slots[i]); + + for (size_t i = 0; i < total_size; i++) { + ir::Value* val = (i < init_values.size() && init_values[i]) + ? init_values[i] + : (is_float ? static_cast(builder_.CreateConstFloat(0.0f)) + : static_cast(builder_.CreateConstInt(0))); + if (use_heap_storage && IsZeroIRValue(val)) { + continue; + } + if (is_float && val->GetType()->IsInt32()) { + val = builder_.CreateSIToFP(val, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (!is_float && val->GetType()->IsFloat()) { + val = builder_.CreateFPToSI(val, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + ir::Value* elem_ptr = nullptr; + if (use_heap_storage) { + elem_ptr = builder_.CreateGEP( + array_slot, {builder_.CreateConstInt(static_cast(i))}, + module_.GetContext().NextTemp()); } else { - builder_.CreateStore(builder_.CreateConstInt(0), element_slots[i]); + elem_ptr = builder_.CreateGEP( + array_slot, BuildArrayIndices(builder_, dims, i), + module_.GetContext().NextTemp()); } + builder_.CreateStore(val, elem_ptr); } } + } else { // 局部标量变量 ir::AllocaInst* slot; if (is_float) { - slot = builder_.CreateAllocaFloat(module_.GetContext().NextTemp() + "_" + varName); + slot = CreateEntryAllocaFloat(module_.GetContext().NextTemp() + "_" + varName); } else { - slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp() + "_" + varName); + slot = CreateEntryAllocaI32(module_.GetContext().NextTemp() + "_" + varName); } storage_map_[ctx] = slot; @@ -608,47 +589,39 @@ std::any IRGenImpl::HandleLocalVariable(SysYParser::VarDefContext* ctx, init = std::any_cast(result); } catch (const std::bad_any_cast&) { try { - std::vector init_values = - std::any_cast>(result); - if (!init_values.empty()) { - init = init_values[0]; - } else { - if (is_float) { - init = builder_.CreateConstFloat(0.0f); - } else { - init = builder_.CreateConstInt(0); - } + auto init_vec = std::any_cast>(result); + if (!init_vec.empty()) { + init = init_vec[0]; } } catch (const std::bad_any_cast&) { - if (is_float) { - init = builder_.CreateConstFloat(0.0f); - } else { - init = builder_.CreateConstInt(0); - } + // 使用默认值 } } - } else { - if (is_float) { - init = builder_.CreateConstFloat(0.0f); - } else { - init = builder_.CreateConstInt(0); - } - } - } else { - if (is_float) { - init = builder_.CreateConstFloat(0.0f); - } else { - init = builder_.CreateConstInt(0); } } + if (!init) { + // SysY 局部变量未显式初始化时为未定义值:不生成默认 store。 + return {}; + } + + // 标量初始化支持 int/float 隐式转换。 + if (is_float && init->GetType()->IsInt32()) { + init = builder_.CreateSIToFP(init, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (!is_float && init->GetType()->IsFloat()) { + init = builder_.CreateFPToSI(init, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + builder_.CreateStore(init, slot); } + std::cerr << "[DEBUG] HandleLocalVariable: 局部变量处理完成" << std::endl; return {}; } - +// visitInitVal: 处理初始化值 std::any IRGenImpl::visitInitVal(SysYParser::InitValContext* ctx) { std::cerr << "[DEBUG] visitInitVal: 开始处理初始化值" << std::endl; if (!ctx) { @@ -663,16 +636,14 @@ std::any IRGenImpl::visitInitVal(SysYParser::InitValContext* ctx) { // 如果是聚合初始化(花括号列表) else if (!ctx->initVal().empty()) { std::cerr << "[DEBUG] visitInitVal: 处理聚合初始化" << std::endl; - // 处理嵌套聚合初始化 return ProcessNestedInitVals(ctx); } std::cerr << "[DEBUG] visitInitVal: 空初始化列表" << std::endl; - // 空初始化列表 return std::vector{}; } -// 新增:处理嵌套聚合初始化的辅助函数 +// ProcessNestedInitVals: 处理嵌套聚合初始化 std::vector IRGenImpl::ProcessNestedInitVals(SysYParser::InitValContext* ctx) { std::cerr << "[DEBUG] ProcessNestedInitVals: 开始处理嵌套初始化值" << std::endl; std::vector all_values; @@ -690,157 +661,107 @@ std::vector IRGenImpl::ProcessNestedInitVals(SysYParser::InitValCont // 尝试获取值列表(嵌套情况) std::vector nested_values = std::any_cast>(result); - std::cerr << "[DEBUG] ProcessNestedInitVals: 获取到嵌套值列表, 大小: " << nested_values.size() << std::endl; - // 展平嵌套的值 + std::cerr << "[DEBUG] ProcessNestedInitVals: 获取到嵌套值列表, 大小: " + << nested_values.size() << std::endl; all_values.insert(all_values.end(), nested_values.begin(), nested_values.end()); } catch (const std::bad_any_cast&) { - // 未知类型 std::cerr << "[ERROR] ProcessNestedInitVals: 不支持的初始化值类型" << std::endl; throw std::runtime_error( FormatError("irgen", "不支持的初始化值类型")); } } - } else { - std::cerr << "[DEBUG] ProcessNestedInitVals: 无初始化值结果" << std::endl; } } - std::cerr << "[DEBUG] ProcessNestedInitVals: 共获取 " << all_values.size() << " 个初始化值" << std::endl; + std::cerr << "[DEBUG] ProcessNestedInitVals: 共获取 " << all_values.size() + << " 个初始化值" << std::endl; return all_values; } -int IRGenImpl::TryEvaluateConstInt(SysYParser::ConstExpContext* ctx) { - std::cerr << "[DEBUG] TryEvaluateConstInt: 开始求值常量表达式" << std::endl; - if (!ctx) { - std::cerr << "[DEBUG] TryEvaluateConstInt: ctx is null" << std::endl; - return 0; - } - - // 直接访问常量表达式树,计算数值 - // 这里需要实现真正的常量求值逻辑 - // 简化版本:假设常量表达式是整数常量 - - if (ctx->addExp()) { - // 尝试从 addExp 求值 - return EvaluateConstAddExp(ctx->addExp()); - } - - return 0; -} +// FlattenInitVal:按 C 语言花括号对齐规则展平初始化列表 +// dims[0] 是最外层维度,dims.back() 是最内层维度(元素层) +// 总元素数 = prod(dims),结果向量长度恰好为总元素数(不足处补零) +std::vector IRGenImpl::FlattenInitVal( + SysYParser::InitValContext* ctx, + const std::vector& dims, + bool is_float) { -// 添加辅助函数来求值常量表达式 -int IRGenImpl::EvaluateConstAddExp(SysYParser::AddExpContext* ctx) { - if (!ctx) return 0; - - // 如果没有左操作数,直接求值右操作数 - if (!ctx->addExp()) { - return EvaluateConstMulExp(ctx->mulExp()); - } - - int left = EvaluateConstAddExp(ctx->addExp()); - int right = EvaluateConstMulExp(ctx->mulExp()); - - if (ctx->AddOp()) { - return left + right; - } else if (ctx->SubOp()) { - return left - right; - } - - return 0; -} + // 计算总元素数 + size_t total = 1; + for (int d : dims) total *= static_cast(d); -int IRGenImpl::EvaluateConstMulExp(SysYParser::MulExpContext* ctx) { - if (!ctx) return 0; - - // 如果没有左操作数,直接求值右操作数 - if (!ctx->mulExp()) { - return EvaluateConstUnaryExp(ctx->unaryExp()); - } - - int left = EvaluateConstMulExp(ctx->mulExp()); - int right = EvaluateConstUnaryExp(ctx->unaryExp()); - - if (ctx->MulOp()) { - return left * right; - } else if (ctx->DivOp()) { - return left / right; - } else if (ctx->QuoOp()) { - return left % right; - } - - return 0; -} + // 零值工厂 + auto make_zero = [&]() -> ir::Value* { + if (is_float) return builder_.CreateConstFloat(0.0f); + return builder_.CreateConstInt(0); + }; -int IRGenImpl::EvaluateConstUnaryExp(SysYParser::UnaryExpContext* ctx) { - if (!ctx) return 0; - - // 基本表达式(数字字面量) - if (ctx->primaryExp()) { - return EvaluateConstPrimaryExp(ctx->primaryExp()); - } - - // 一元运算 - if (ctx->unaryOp() && ctx->unaryExp()) { - int operand = EvaluateConstUnaryExp(ctx->unaryExp()); - std::string op = ctx->unaryOp()->getText(); - - if (op == "+") { - return operand; - } else if (op == "-") { - return -operand; - } else if (op == "!") { - return !operand; + // 先全部填零 + std::vector flat(total, nullptr); + + // 计算 depth 层从 begin 开始的子数组跨度 + // depth=0 → 每个子聚合占 dims[1]*dims[2]*... 个元素 + auto subspan = [&](size_t depth) -> size_t { + size_t span = 1; + for (size_t i = depth + 1; i < dims.size(); ++i) + span *= static_cast(dims[i]); + return span; + }; + + // 递归 fill_impl:将 node 的内容按 C 规则写入 flat[begin..end-1],返回填完后的光标 + std::function fill_impl; + fill_impl = [&](SysYParser::InitValContext* node, + size_t depth, + size_t begin, + size_t end) -> size_t { + if (!node || begin >= end) return begin; + + // 单标量初始化项(叶节点) + if (node->exp()) { + ir::Value* v = EvalExpr(*node->exp()); + if (begin < flat.size()) flat[begin] = v; + return std::min(begin + 1, end); } - } - - return 0; -} -int IRGenImpl::EvaluateConstPrimaryExp(SysYParser::PrimaryExpContext* ctx) { - if (!ctx) return 0; - - // 处理数字字面量 - if (ctx->DECIMAL_INT()) { - return std::stoi(ctx->DECIMAL_INT()->getText()); - } - - if (ctx->HEX_INT()) { - std::string hex = ctx->HEX_INT()->getText(); - return std::stoi(hex, nullptr, 16); - } - - if (ctx->OCTAL_INT()) { - std::string oct = ctx->OCTAL_INT()->getText(); - return std::stoi(oct, nullptr, 8); - } - - if (ctx->ZERO()) { - return 0; - } - - // 处理括号表达式 - if (ctx->L_PAREN() && ctx->exp()) { - return EvaluateConstExp(ctx->exp()); - } - - // 常量标识符(引用其他常量) - if (ctx->lVal()) { - std::string const_name = ctx->lVal()->Ident()->getText(); - auto it = const_value_map_.find(const_name); - if (it != const_value_map_.end()) { - if (auto* const_int = dynamic_cast(it->second)) { - return const_int->GetValue(); + // 聚合初始化(花括号列表) + size_t cursor = begin; + for (auto* child : node->initVal()) { + if (cursor >= end) break; + + if (child->exp()) { + // 标量子项 + ir::Value* v = EvalExpr(*child->exp()); + if (cursor < flat.size()) flat[cursor] = v; + ++cursor; + continue; + } + + // 花括号子列表 + if (depth + 1 < dims.size()) { + // 对齐到下一个子聚合边界 + const size_t span = subspan(depth); + const size_t rel = (cursor - begin) % span; + if (rel != 0) cursor += (span - rel); + if (cursor >= end) break; + + const size_t sub_end = std::min(cursor + span, end); + fill_impl(child, depth + 1, cursor, sub_end); + cursor = sub_end; // 消耗一个子聚合 + } else { + // 最内层遇到额外花括号,按顺序展开 + cursor = fill_impl(child, depth, cursor, end); } } - // 也可以从语义分析获取常量值 - // ... + return cursor; + }; + + fill_impl(ctx, 0, 0, total); + + // 把 nullptr(未显式初始化)替换为零值 + for (auto*& v : flat) { + if (!v) v = make_zero(); } - - return 0; -} -int IRGenImpl::EvaluateConstExp(SysYParser::ExpContext* ctx) { - if (!ctx || !ctx->addExp()) return 0; - return EvaluateConstAddExp(ctx->addExp()); + return flat; } \ No newline at end of file diff --git a/src/irgen/IRGenDriver.cpp b/src/irgen/IRGenDriver.cpp index ff94412..ac19a7b 100644 --- a/src/irgen/IRGenDriver.cpp +++ b/src/irgen/IRGenDriver.cpp @@ -6,10 +6,11 @@ #include "ir/IR.h" #include "utils/Log.h" +// 修改 GenerateIR 函数 std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, - const SemanticContext& sema) { + const SemaResult& sema_result) { auto module = std::make_unique(); - IRGenImpl gen(*module, sema); + IRGenImpl gen(*module, sema_result.context, sema_result.symbol_table); tree.accept(&gen); return module; } diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index 9e33971..5d8128b 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -21,8 +21,9 @@ // - 条件与比较表达式 // - ... +// 表达式生成 ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { - std::cout << "[DEBUG IRGEN] EvalExpr: " << expr.getText() << std::endl; + std::cerr << "[DEBUG IRGEN] EvalExpr: 开始处理表达式 " << expr.getText() << std::endl; try { auto result_any = expr.accept(this); @@ -38,15 +39,7 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { } catch (const std::bad_any_cast& e) { std::cerr << "[ERROR] EvalExpr: bad any_cast - " << e.what() << std::endl; std::cerr << " Type info: " << result_any.type().name() << std::endl; - - // 尝试其他可能的类型 - try { - // 检查是否是无值的any(可能来自visit函数返回{}) - std::cerr << "[DEBUG] EvalExpr: Trying to handle empty any" << std::endl; - return nullptr; - } catch (...) { - throw std::runtime_error(FormatError("irgen", "表达式求值返回了错误的类型")); - } + throw std::runtime_error(FormatError("irgen", "表达式求值返回了错误的类型")); } } catch (const std::exception& e) { std::cerr << "[ERROR] Exception in EvalExpr: " << e.what() << std::endl; @@ -54,16 +47,14 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { } } - ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) { + std::cerr << "[DEBUG IRGEN] EvalCond: 开始处理条件表达式 " << cond.getText() << std::endl; return std::any_cast(cond.accept(this)); } - - // 基本表达式:数字、变量、括号表达式 std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { - std::cout << "[DEBUG IRGEN] visitPrimaryExp: " << (ctx ? ctx->getText() : "") << std::endl; + std::cerr << "[DEBUG IRGEN] visitPrimaryExp: 开始处理基本表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少基本表达式")); } @@ -82,9 +73,7 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { if (ctx->HEX_FLOAT()) { std::string hex_float_str = ctx->HEX_FLOAT()->getText(); float value = 0.0f; - // 解析十六进制浮点数 try { - // C++11 的 std::stof 支持十六进制浮点数表示 value = std::stof(hex_float_str); } catch (const std::exception& e) { std::cerr << "[WARNING] 无法解析十六进制浮点数: " << hex_float_str @@ -97,7 +86,6 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { return static_cast(const_float); } - // 处理十进制浮点常量 if (ctx->DEC_FLOAT()) { std::string dec_float_str = ctx->DEC_FLOAT()->getText(); float value = 0.0f; @@ -118,6 +106,8 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { std::string hex = ctx->HEX_INT()->getText(); int value = std::stoi(hex, nullptr, 16); ir::Value* const_int = builder_.CreateConstInt(value); + std::cerr << "[DEBUG] visitPrimaryExp: constant hex int " << value + << " created as " << (void*)const_int << std::endl; return static_cast(const_int); } @@ -125,11 +115,14 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { std::string oct = ctx->OCTAL_INT()->getText(); int value = std::stoi(oct, nullptr, 8); ir::Value* const_int = builder_.CreateConstInt(value); + std::cerr << "[DEBUG] visitPrimaryExp: constant octal int " << value + << " created as " << (void*)const_int << std::endl; return static_cast(const_int); } if (ctx->ZERO()) { ir::Value* const_int = builder_.CreateConstInt(0); + std::cerr << "[DEBUG] visitPrimaryExp: constant zero int created" << std::endl; return static_cast(const_int); } @@ -149,12 +142,9 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型")); } - // 左值(变量)处理 -// 1. 先通过语义分析结果把变量使用绑定回声明; -// 2. 再通过 storage_map_ 找到该声明对应的栈槽位; -// 3. 最后生成 load,把内存中的值读出来。 std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { + std::cerr << "[DEBUG IRGEN] visitLVal: 开始处理左值 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "非法左值")); } @@ -162,42 +152,95 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { std::string varName = ctx->Ident()->getText(); std::cerr << "[DEBUG] visitLVal: " << varName << std::endl; - // 优先检查是否是常量 - auto const_it = const_value_map_.find(varName); - if (const_it != const_value_map_.end()) { - // 常量直接返回值,不需要load - std::cerr << "[DEBUG] visitLVal: constant " << varName << std::endl; - return static_cast(const_it->second); - } - - // 检查全局常量 - auto const_global_it = const_global_map_.find(varName); - if (const_global_it != const_global_map_.end()) { - // 全局常量,需要load - ir::Value* ptr = const_global_it->second; - if (!ctx->exp().empty()) { - // 数组访问 - std::vector indices; - indices.push_back(builder_.CreateConstInt(0)); - for (auto* exp : ctx->exp()) { - ir::Value* index = EvalExpr(*exp); - indices.push_back(index); + // 先检查语义分析中常量绑定 + const SysYParser::ConstDefContext* const_decl = sema_.ResolveConstUse(ctx); + const Symbol* sym = nullptr; + if (const_decl) { + sym = symbol_table_.lookupByConstDef(const_decl); + if (!sym) { + sym = symbol_table_.lookupAll(varName); + } + } else { + sym = symbol_table_.lookup(varName); + } + + // 如果是常量,直接返回常量值 + if (sym && sym->kind == SymbolKind::Constant) { + std::cerr << "[DEBUG] visitLVal: 找到常量 " << varName << std::endl; + + if (sym->IsScalarConstant()) { + if (sym->type->IsInt32()) { + ir::ConstantValue* const_val = builder_.CreateConstInt(sym->GetIntConstant()); + return static_cast(const_val); + } else if (sym->type->IsFloat()) { + ir::ConstantValue* const_val = builder_.CreateConstFloat(sym->GetFloatConstant()); + return static_cast(const_val); + } + } else if (sym->IsArrayConstant()) { + auto it = const_global_map_.find(varName); + if (it != const_global_map_.end()) { + ir::GlobalValue* global_array = it->second; + + // 尝试获取类型信息,用于维度判断与下标线性化 + auto* array_ty = dynamic_cast(sym->type.get()); + if (!array_ty) { + // 无法获取数组类型,退回返回全局对象 + return static_cast(global_array); + } + + size_t ndims = array_ty->GetDimensions().size(); + + // 有下标访问 + if (!ctx->exp().empty()) { + size_t provided = ctx->exp().size(); + + // 完全索引(所有维度都有下标)——直接返回常量元素,不生成 Load + if (provided == ndims) { + std::vector idxs; + idxs.reserve(provided); + for (auto* exp : ctx->exp()) { + ir::Value* v = EvalExpr(*exp); + if (!v || !v->IsConstant()) { + throw std::runtime_error(FormatError("irgen", "常量数组索引必须为常量整数: " + varName)); + } + auto* ci = dynamic_cast(v); + if (!ci) { + throw std::runtime_error(FormatError("irgen", "常量数组索引非整型常量: " + varName)); + } + idxs.push_back(ci->GetValue()); + } + + // 计算线性下标(行主序) + const auto& dims = array_ty->GetDimensions(); + int flat = idxs[0]; + for (size_t i = 1; i < ndims; ++i) { + flat = flat * dims[i] + idxs[i]; + } + + ir::ConstantValue* elem = global_array->GetArrayElement(static_cast(flat)); + return static_cast(elem); + } + + // 部分索引:返回指针(不做 Load),由上层按需处理 + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + for (auto* exp : ctx->exp()) { + indices.push_back(EvalExpr(*exp)); + } + return static_cast( + builder_.CreateGEP(global_array, indices, module_.GetContext().NextTemp())); + } else { + // 无下标,直接返回全局常量对象 + return static_cast(global_array); + } } - ir::Value* elem_ptr = builder_.CreateGEP( - ptr, indices, module_.GetContext().NextTemp()); - return static_cast( - builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp())); - } else { - return static_cast( - builder_.CreateLoad(ptr, module_.GetContext().NextTemp())); } } // 不是常量,按正常变量处理 - // ... 原有的变量查找代码 ... - auto* decl = sema_.ResolveVarUse(ctx); ir::Value* ptr = nullptr; + if (decl) { auto it = storage_map_.find(decl); if (it != storage_map_.end()) { @@ -234,27 +277,124 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { // 检查是否有数组下标 bool is_array_access = !ctx->exp().empty(); if (is_array_access) { - std::vector indices; - indices.push_back(builder_.CreateConstInt(0)); - + // 收集下标表达式(不含前导0) + std::vector idx_vals; for (auto* exp : ctx->exp()) { ir::Value* index = EvalExpr(*exp); - indices.push_back(index); + idx_vals.push_back(index); } - - ir::Value* elem_ptr = builder_.CreateGEP( - ptr, indices, module_.GetContext().NextTemp()); - - return static_cast( - builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp())); + + const Symbol* var_sym = sym; + if (!var_sym) { + var_sym = symbol_table_.lookup(varName); + } + if (!var_sym && decl) { + var_sym = symbol_table_.lookupByVarDef(decl); + } + if (!var_sym) { + var_sym = symbol_table_.lookupAll(varName); + } + + std::vector dims; + if (var_sym) { + if (var_sym->is_array_param && !var_sym->array_dims.empty()) { + dims = var_sym->array_dims; + } else if (var_sym->type && var_sym->type->IsArray()) { + auto* at = dynamic_cast(var_sym->type.get()); + if (at) dims = at->GetDimensions(); + } + } + + if (dims.empty() && ptr->GetType()->IsArray()) { + if (auto* at = dynamic_cast(ptr->GetType().get())) { + dims = at->GetDimensions(); + } + } + + // 兜底:从语法树声明提取维度,避免作用域关闭后符号查询不完整。 + if (dims.empty() && const_decl) { + auto* mutable_const_decl = const_cast(const_decl); + for (auto* cexp : mutable_const_decl->constExp()) { + dims.push_back(symbol_table_.EvaluateConstExp(cexp)); + } + } + if (dims.empty() && decl) { + for (auto* cexp : decl->constExp()) { + dims.push_back(symbol_table_.EvaluateConstExp(cexp)); + } + } + + const bool is_partial_array_access = + !dims.empty() && idx_vals.size() < dims.size(); + + // 如果 base 是标量指针(例如局部扁平数组或数组参数), + // 需要把多维下标折合为单一线性下标,然后用一个索引进行 GEP。 + if (ptr->GetType()->IsPtrInt32() || ptr->GetType()->IsPtrFloat()) { + // 如果没有维度信息,仍尝试用运行时算术合并下标(按后维乘积) + // flat = idx0 * (prod dims[1..]) + idx1 * (prod dims[2..]) + ... + ir::Value* flat = nullptr; + for (size_t i = 0; i < idx_vals.size(); ++i) { + ir::Value* term = idx_vals[i]; + if (!term) continue; + + // 计算乘数(后续维度乘积) + int mult = 1; + if (!dims.empty() && i + 1 < dims.size()) { + for (size_t j = i + 1; j < dims.size(); ++j) { + // 数组参数首维可能是 0(表示省略),不参与乘数。 + if (dims[j] > 0) mult *= dims[j]; + } + } + + if (mult != 1) { + auto* mval = builder_.CreateConstInt(mult); + term = builder_.CreateMul(term, mval, module_.GetContext().NextTemp()); + } + + if (!flat) flat = term; + else flat = builder_.CreateAdd(flat, term, module_.GetContext().NextTemp()); + } + + if (!flat) flat = builder_.CreateConstInt(0); + + // 使用单一索引创建 GEP + std::vector gep_indices = { flat }; + ir::Value* elem_ptr = builder_.CreateGEP(ptr, gep_indices, module_.GetContext().NextTemp()); + if (is_partial_array_access) { + return elem_ptr; + } + return static_cast(builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp())); + } + + std::vector indices; + // 标量指针(T*)使用单索引;数组对象使用前导0进入首层。 + if (ptr->GetType()->IsPtrInt32() || ptr->GetType()->IsPtrFloat()) { + for (auto* v : idx_vals) indices.push_back(v); + } else { + indices.push_back(builder_.CreateConstInt(0)); + for (auto* v : idx_vals) indices.push_back(v); + } + + ir::Value* elem_ptr = builder_.CreateGEP(ptr, indices, module_.GetContext().NextTemp()); + if (is_partial_array_access) { + return elem_ptr; + } + return static_cast(builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp())); } else { - return static_cast( - builder_.CreateLoad(ptr, module_.GetContext().NextTemp())); + if ((sym && sym->is_array_param) || + pointer_param_names_.find(varName) != pointer_param_names_.end() || + heap_local_array_names_.find(varName) != heap_local_array_names_.end()) { + return ptr; + } + if (ptr->GetType()->IsArray()) { + return ptr; + } + return static_cast(builder_.CreateLoad(ptr, module_.GetContext().NextTemp())); } } std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { - std::cout << "[DEBUG IRGEN] visitAddExp: " << (ctx ? ctx->getText() : "") << std::endl; + std::cerr << "[DEBUG IRGEN] visitAddExp: 开始处理加法表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法加法表达式")); } @@ -318,7 +458,7 @@ std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { - std::cout << "[DEBUG IRGEN] visitMulExp: " << (ctx ? ctx->getText() : "") << std::endl; + std::cerr << "[DEBUG IRGEN] visitMulExp: 开始处理乘法表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); } @@ -392,6 +532,7 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { // 逻辑与 std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { + std::cerr << "[DEBUG IRGEN] visitLAndExp: 开始处理逻辑与表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); if (!ctx->lAndExp()) { @@ -400,14 +541,28 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { ir::Value* left = std::any_cast(ctx->lAndExp()->accept(this)); ir::Value* right = std::any_cast(ctx->eqExp()->accept(this)); - auto zero = builder_.CreateConstInt(0); - auto left_bool = builder_.CreateICmpNE(left, zero, module_.GetContext().NextTemp()); - auto right_bool = builder_.CreateICmpNE(right, zero, module_.GetContext().NextTemp()); - return builder_.CreateAnd(left_bool, right_bool, module_.GetContext().NextTemp()); + + auto to_bool = [&](ir::Value* v) -> ir::Value* { + if (v->GetType()->IsInt1()) { + return v; + } + if (v->GetType()->IsFloat()) { + return builder_.CreateFCmpONE(v, builder_.CreateConstFloat(0.0f), + module_.GetContext().NextTemp()); + } + return builder_.CreateICmpNE(v, builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); + }; + + auto* left_bool = to_bool(left); + auto* right_bool = to_bool(right); + return static_cast( + builder_.CreateAnd(left_bool, right_bool, module_.GetContext().NextTemp())); } // 逻辑或 std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { + std::cerr << "[DEBUG IRGEN] visitLOrExp: 开始处理逻辑或表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); if (!ctx->lOrExp()) { @@ -416,23 +571,39 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { ir::Value* left = std::any_cast(ctx->lOrExp()->accept(this)); ir::Value* right = std::any_cast(ctx->lAndExp()->accept(this)); - auto zero = builder_.CreateConstInt(0); - auto left_bool = builder_.CreateICmpNE(left, zero, module_.GetContext().NextTemp()); - auto right_bool = builder_.CreateICmpNE(right, zero, module_.GetContext().NextTemp()); - return builder_.CreateOr(left_bool, right_bool, module_.GetContext().NextTemp()); + + auto to_bool = [&](ir::Value* v) -> ir::Value* { + if (v->GetType()->IsInt1()) { + return v; + } + if (v->GetType()->IsFloat()) { + return builder_.CreateFCmpONE(v, builder_.CreateConstFloat(0.0f), + module_.GetContext().NextTemp()); + } + return builder_.CreateICmpNE(v, builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); + }; + + auto* left_bool = to_bool(left); + auto* right_bool = to_bool(right); + return static_cast( + builder_.CreateOr(left_bool, right_bool, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) { + std::cerr << "[DEBUG IRGEN] visitExp: 开始处理表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) throw std::runtime_error(FormatError("irgen", "非法表达式")); return ctx->addExp()->accept(this); } std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { + std::cerr << "[DEBUG IRGEN] visitCond: 开始处理条件 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) throw std::runtime_error(FormatError("irgen", "非法条件表达式")); return ctx->lOrExp()->accept(this); } std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) { + std::cerr << "[DEBUG IRGEN] visitCallExp: 开始处理函数调用 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "非法函数调用")); } @@ -466,6 +637,33 @@ std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) { } } + // 按形参类型修正实参(数组衰减为指针等)。 + if (auto* fty = dynamic_cast(callee->GetType().get())) { + const auto& param_tys = fty->GetParamTypes(); + size_t n = std::min(param_tys.size(), args.size()); + for (size_t i = 0; i < n; ++i) { + if (!args[i] || !param_tys[i]) continue; + + // 数组实参传给指针形参时,执行数组到指针衰减。 + if (args[i]->GetType()->IsArray() && + (param_tys[i]->IsPtrInt32() || param_tys[i]->IsPtrFloat())) { + std::vector idx; + idx.push_back(builder_.CreateConstInt(0)); + idx.push_back(builder_.CreateConstInt(0)); + args[i] = builder_.CreateGEP(args[i], idx, module_.GetContext().NextTemp()); + } + + // 标量实参的隐式类型转换(int <-> float)。 + if (param_tys[i]->IsFloat() && args[i]->GetType()->IsInt32()) { + args[i] = builder_.CreateSIToFP(args[i], ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (param_tys[i]->IsInt32() && args[i]->GetType()->IsFloat()) { + args[i] = builder_.CreateFPToSI(args[i], ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + } + } + // 生成调用指令 ir::Value* callResult = builder_.CreateCall(callee, args, module_.GetContext().NextTemp()); @@ -481,7 +679,7 @@ std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) { // 动态创建运行时函数声明的辅助函数 ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName) { - std::cout << "[DEBUG IRGEN] CreateRuntimeFunctionDecl: " << funcName << std::endl; + std::cerr << "[DEBUG IRGEN] CreateRuntimeFunctionDecl: 开始创建运行时函数声明 " << funcName << std::endl; // 根据常见运行时函数名创建对应的函数类型 if (funcName == "getint" || funcName == "getch") { @@ -498,7 +696,7 @@ ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName) return module_.CreateFunction(funcName, ir::Type::GetFunctionType( ir::Type::GetInt32Type(), - {ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()})); + {ir::Type::GetPtrInt32Type()})); } else if (funcName == "putarray") { return module_.CreateFunction(funcName, @@ -522,15 +720,27 @@ ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName) ir::Type::GetVoidType(), {ir::Type::GetInt32Type()})); } - else if (funcName == "read_map") { - return module_.CreateFunction(funcName, - ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {})); + else if (funcName == "getfloat") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType(ir::Type::GetFloatType(), {})); } - else if (funcName == "float_eq") { - return module_.CreateFunction(funcName, + else if (funcName == "putfloat") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetFloatType()})); + } + else if (funcName == "getfarray") { + return module_.CreateFunction(funcName, ir::Type::GetFunctionType( ir::Type::GetInt32Type(), - {ir::Type::GetFloatType(), ir::Type::GetFloatType()})); + {ir::Type::GetPtrFloatType()})); + } + else if (funcName == "putfarray") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type(), ir::Type::GetPtrFloatType()})); } else if (funcName == "memset") { return module_.CreateFunction(funcName, @@ -540,12 +750,49 @@ ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName) ir::Type::GetInt32Type(), ir::Type::GetInt32Type()})); } + else if (funcName == "sysy_alloc_i32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetPtrInt32Type(), + {ir::Type::GetInt32Type()})); + } + else if (funcName == "sysy_alloc_f32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetPtrFloatType(), + {ir::Type::GetInt32Type()})); + } + else if (funcName == "sysy_free_i32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrInt32Type()})); + } + else if (funcName == "sysy_free_f32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrFloatType()})); + } + else if (funcName == "sysy_zero_i32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()})); + } + else if (funcName == "sysy_zero_f32") { + return module_.CreateFunction(funcName, + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrFloatType(), ir::Type::GetInt32Type()})); + } // 其他函数不支持动态创建 return nullptr; } std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { + std::cerr << "[DEBUG IRGEN] visitUnaryExp: 开始处理一元表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法一元表达式")); } @@ -587,14 +834,15 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { ir::Value* zero; if (operand->GetType()->IsFloat()) { zero = builder_.CreateConstFloat(0.0f); - // 浮点比较:不等于0 - ir::Value* cmp = builder_.CreateFCmpONE(operand, zero, module_.GetContext().NextTemp()); + // 浮点逻辑非:x == 0.0 + ir::Value* cmp = builder_.CreateFCmpOEQ(operand, zero, module_.GetContext().NextTemp()); // 将bool转换为int return static_cast( builder_.CreateZExt(cmp, ir::Type::GetInt32Type())); } else { zero = builder_.CreateConstInt(0); - return builder_.CreateNot(operand, module_.GetContext().NextTemp()); + return static_cast( + builder_.CreateNot(operand, module_.GetContext().NextTemp())); } } } @@ -604,6 +852,7 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { // 实现函数调用 std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) { + std::cerr << "[DEBUG IRGEN] visitFuncRParams: 开始处理函数参数 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) return std::vector{}; std::vector args; for (auto* exp : ctx->exp()) { @@ -612,67 +861,37 @@ std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) { return args; } -// 修改 visitConstExp 以支持常量表达式求值 +// visitConstExp - 处理常量表达式 std::any IRGenImpl::visitConstExp(SysYParser::ConstExpContext* ctx) { - if (!ctx) { + std::cerr << "[DEBUG IRGEN] visitConstExp: 开始处理常量表达式 " << (ctx ? ctx->getText() : "") << std::endl; + if (!ctx || !ctx->addExp()) { throw std::runtime_error(FormatError("irgen", "非法常量表达式")); } + auto result = ctx->addExp()->accept(this); + + if (!result.has_value()) { + throw std::runtime_error(FormatError("irgen", "常量表达式求值失败")); + } + try { - if (ctx->addExp()) { - // 尝试获取数值 - auto result = ctx->addExp()->accept(this); - if (result.has_value()) { - try { - ir::Value* value = std::any_cast(result); - // 尝试判断是否是 ConstantInt - // 暂时简化:返回 IR 值 - return static_cast(value); - } catch (const std::bad_any_cast&) { - // 可能是其他类型 - return static_cast(builder_.CreateConstInt(0)); - } - } - } - return static_cast(builder_.CreateConstInt(0)); - } catch (const std::exception& e) { - std::cerr << "[WARNING] visitConstExp: 常量表达式求值失败: " << e.what() - << ",返回0" << std::endl; - // 如果普通表达式求值失败,返回0 - return static_cast(builder_.CreateConstInt(0)); + return std::any_cast(result); + } catch (const std::bad_any_cast& e) { + throw std::runtime_error(FormatError("irgen", + "常量表达式返回类型错误: " + std::string(e.what()))); } } +// visitConstInitVal - 处理常量初始化值 std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) { + std::cerr << "[DEBUG IRGEN] visitConstInitVal: 开始处理常量初始化值 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法常量初始化值")); } // 如果是单个常量表达式 if (ctx->constExp()) { - try { - auto result = ctx->constExp()->accept(this); - if (result.has_value()) { - try { - ir::Value* value = std::any_cast(result); - // 尝试提取常量值 - if (auto* const_int = dynamic_cast(value)) { - return static_cast(const_int); - } else { - // 如果不是常量,尝试计算数值 - int int_val = TryEvaluateConstInt(ctx->constExp()); - return static_cast(builder_.CreateConstInt(int_val)); - } - } catch (const std::bad_any_cast&) { - int int_val = TryEvaluateConstInt(ctx->constExp()); - return static_cast(builder_.CreateConstInt(int_val)); - } - } - return static_cast(builder_.CreateConstInt(0)); - } catch (const std::exception& e) { - std::cerr << "[WARNING] visitConstInitVal: " << e.what() << std::endl; - return static_cast(builder_.CreateConstInt(0)); - } + return ctx->constExp()->accept(this); } // 如果是聚合初始化(花括号列表) else if (!ctx->constInitVal().empty()) { @@ -680,22 +899,24 @@ std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) { for (auto* init_val : ctx->constInitVal()) { auto result = init_val->accept(this); - if (result.has_value()) { + if (!result.has_value()) { + throw std::runtime_error(FormatError("irgen", "常量初始化值求值失败")); + } + + try { + // 尝试获取单个常量值 + ir::Value* value = std::any_cast(result); + all_values.push_back(value); + } catch (const std::bad_any_cast&) { try { - // 尝试获取单个常量值 - ir::Value* value = std::any_cast(result); - all_values.push_back(value); - } catch (const std::bad_any_cast&) { - try { - // 尝试获取值列表(嵌套情况) - std::vector nested_values = - std::any_cast>(result); - all_values.insert(all_values.end(), - nested_values.begin(), nested_values.end()); - } catch (const std::bad_any_cast&) { - throw std::runtime_error( - FormatError("irgen", "不支持的常量初始化值类型")); - } + // 尝试获取值列表(嵌套情况) + std::vector nested_values = + std::any_cast>(result); + all_values.insert(all_values.end(), + nested_values.begin(), nested_values.end()); + } catch (const std::bad_any_cast& e) { + throw std::runtime_error(FormatError("irgen", + "不支持的常量初始化值类型: " + std::string(e.what()))); } } } @@ -708,6 +929,7 @@ std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) { } std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { + std::cerr << "[DEBUG IRGEN] visitRelExp: 开始处理关系表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法关系表达式")); } @@ -782,6 +1004,7 @@ std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { } std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { + std::cerr << "[DEBUG IRGEN] visitEqExp: 开始处理相等表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法相等表达式")); } @@ -839,6 +1062,7 @@ std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) { + std::cerr << "[DEBUG IRGEN] EvalAssign: 开始处理赋值语句 " << (ctx ? ctx->getText() : "") << std::endl; std::cout << "[DEBUG IRGEN] visitCond: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx || !ctx->lVal() || !ctx->exp()) { throw std::runtime_error(FormatError("irgen", "非法赋值语句")); @@ -864,15 +1088,29 @@ ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) { ir::Value* base_ptr = it->second; + auto convert_for_store = [&](ir::Value* value, ir::Value* ptr) -> ir::Value* { + if (ptr->GetType()->IsPtrFloat() && value->GetType()->IsInt32()) { + return builder_.CreateSIToFP(value, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } + if (ptr->GetType()->IsPtrInt32() && value->GetType()->IsFloat()) { + return builder_.CreateFPToSI(value, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + return value; + }; + // 检查是否有数组下标 auto exp_list = lval->exp(); if (!exp_list.empty()) { // 这是数组元素赋值,需要生成GEP指令 std::vector indices; - - // 第一个索引是0(假设一维数组) - indices.push_back(builder_.CreateConstInt(0)); - + + // 标量指针参数(T*)不应添加前导0;数组对象需要前导0。 + if (!(base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat())) { + indices.push_back(builder_.CreateConstInt(0)); + } + // 添加用户提供的下标 for (auto* exp : exp_list) { ir::Value* index = EvalExpr(*exp); @@ -884,6 +1122,7 @@ ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) { base_ptr, indices, module_.GetContext().NextTemp()); // 生成store指令 + rhs = convert_for_store(rhs, elem_ptr); builder_.CreateStore(rhs, elem_ptr); } else { // 普通标量赋值 @@ -891,11 +1130,12 @@ ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) { std::cerr << "[DEBUG] base_ptr type: " << base_ptr->GetType() << std::endl; std::cerr << "[DEBUG] rhs type: " << rhs->GetType()<< std::endl; - // 如果 base_ptr 不是指针类型,可能需要特殊处理 - if (!base_ptr->GetType()->IsPtrInt32()) { + // 如果 base_ptr 不是标量指针类型,可能需要特殊处理 + if (!base_ptr->GetType()->IsPtrInt32() && !base_ptr->GetType()->IsPtrFloat()) { std::cerr << "[ERROR] base_ptr is not a pointer type!" << std::endl; throw std::runtime_error("尝试存储到非指针类型"); } + rhs = convert_for_store(rhs, base_ptr); builder_.CreateStore(rhs, base_ptr); } } else { diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index f9b42f2..d5334e0 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -20,18 +20,41 @@ void VerifyFunctionStructure(const ir::Function& func) { } } -} // namespace +bool HasDirectSelfCall(antlr4::ParserRuleContext* node, + const std::string& func_name) { + if (!node) { + return false; + } -IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) - : module_(module), - sema_(sema), - func_(nullptr), - builder_(module.GetContext(), nullptr) { - AddRuntimeFunctions(); + if (auto* unary = dynamic_cast(node)) { + if (unary->Ident() && unary->Ident()->getText() == func_name) { + return true; + } + } + + for (auto* child : node->children) { + if (auto* rule = dynamic_cast(child)) { + if (HasDirectSelfCall(rule, func_name)) { + return true; } + } + } + return false; +} + +} // namespace + +// 实现新的构造函数 +IRGenImpl::IRGenImpl(ir::Module& module, + const SemanticContext& sema, + const SymbolTable& sym_table) + : module_(module), sema_(sema), symbol_table_(sym_table), + builder_(module.GetContext(), nullptr), func_(nullptr) { + AddRuntimeFunctions(); +} void IRGenImpl::AddRuntimeFunctions() { - std::cout << "[DEBUG IRGEN] 添加运行时库函数声明" << std::endl; + std::cerr << "[DEBUG IRGEN] 添加运行时库函数声明" << std::endl; // 输入函数(返回 int) module_.CreateFunction("getint", @@ -43,7 +66,7 @@ void IRGenImpl::AddRuntimeFunctions() { module_.CreateFunction("getarray", ir::Type::GetFunctionType( ir::Type::GetInt32Type(), - {ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()})); + {ir::Type::GetPtrInt32Type()})); // 输出函数(返回 void) module_.CreateFunction("putint", @@ -83,16 +106,22 @@ void IRGenImpl::AddRuntimeFunctions() { module_.CreateFunction("stoptime", ir::Type::GetFunctionType(ir::Type::GetVoidType(), {})); - // 其他可能需要的函数 - module_.CreateFunction("read_map", - ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {})); - - // 浮点数 - module_.CreateFunction("float_eq", + // 浮点 I/O + module_.CreateFunction("getfloat", + ir::Type::GetFunctionType(ir::Type::GetFloatType(), {})); + module_.CreateFunction("putfloat", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetFloatType()})); + module_.CreateFunction("getfarray", ir::Type::GetFunctionType( ir::Type::GetInt32Type(), - {ir::Type::GetFloatType(), ir::Type::GetFloatType()})); - + {ir::Type::GetPtrFloatType()})); + module_.CreateFunction("putfarray", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetInt32Type(), ir::Type::GetPtrFloatType()})); + // 内存操作函数 module_.CreateFunction("memset", ir::Type::GetFunctionType( @@ -100,13 +129,48 @@ void IRGenImpl::AddRuntimeFunctions() { {ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type(), ir::Type::GetInt32Type()})); + + module_.CreateFunction("sysy_alloc_i32", + ir::Type::GetFunctionType( + ir::Type::GetPtrInt32Type(), + {ir::Type::GetInt32Type()})); + module_.CreateFunction("sysy_alloc_f32", + ir::Type::GetFunctionType( + ir::Type::GetPtrFloatType(), + {ir::Type::GetInt32Type()})); + module_.CreateFunction("sysy_free_i32", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrInt32Type()})); + module_.CreateFunction("sysy_free_f32", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrFloatType()})); + module_.CreateFunction("sysy_zero_i32", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()})); + module_.CreateFunction("sysy_zero_f32", + ir::Type::GetFunctionType( + ir::Type::GetVoidType(), + {ir::Type::GetPtrFloatType(), ir::Type::GetInt32Type()})); - std::cout << "[DEBUG IRGEN] 运行时库函数声明完成" << std::endl; + std::cerr << "[DEBUG IRGEN] 运行时库函数声明完成" << std::endl; } // 修正:没有 mainFuncDef,通过函数名找到 main std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { - std::cout << "[DEBUG IRGEN] visitCompUnit" << std::endl; + std::cerr << "[DEBUG IRGEN] visitCompUnit" << std::endl; + std::cerr << "[DEBUG] IRGen: 符号表地址 = " << &symbol_table_ << std::endl; + std::cerr << "[DEBUG] IRGen: 开始生成 IR" << std::endl; + + // 尝试查找 main 函数 + const Symbol* main_sym = symbol_table_.lookup("main"); + if (main_sym) { + std::cerr << "[DEBUG] IRGen: 找到 main 函数符号" << std::endl; + } else { + std::cerr << "[DEBUG] IRGen: 未找到 main 函数符号" << std::endl; + } if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少编译单元")); } @@ -129,7 +193,7 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { } std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { - std::cout << "[DEBUG IRGEN] visitFuncDef: " << (ctx && ctx->Ident() ? ctx->Ident()->getText() : "") << std::endl; + std::cerr << "[DEBUG IRGEN] visitFuncDef: " << (ctx && ctx->Ident() ? ctx->Ident()->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少函数定义")); } @@ -216,6 +280,19 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { builder_.SetInsertPoint(entry_block); storage_map_.clear(); param_map_.clear(); + pointer_param_names_.clear(); + heap_local_array_names_.clear(); + current_function_name_ = funcName; + current_function_is_recursive_ = HasDirectSelfCall(ctx->block(), funcName); + function_cleanup_block_ = nullptr; + function_cleanup_actions_.clear(); + function_return_slot_ = nullptr; + + if (ret_type->IsInt32()) { + function_return_slot_ = CreateEntryAllocaI32(module_.GetContext().NextTemp() + ".retval"); + } else if (ret_type->IsFloat()) { + function_return_slot_ = CreateEntryAllocaFloat(module_.GetContext().NextTemp() + ".retval"); + } // 函数参数处理 if (ctx->funcFParams()) { @@ -271,23 +348,29 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { throw std::runtime_error(FormatError("irgen", "添加参数失败: " + name)); } - // 为参数创建存储槽位 - ir::AllocaInst* slot = nullptr; - if (param_ty->IsInt32() || param_ty->IsPtrInt32()) { - slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); - } else if (param_ty->IsFloat() || param_ty->IsPtrFloat()) { - slot = builder_.CreateAllocaFloat(module_.GetContext().NextTemp()); + // 标量参数:入栈到本地槽位;数组参数(指针)直接作为地址使用。 + if (param_ty->IsPtrInt32() || param_ty->IsPtrFloat()) { + param_map_[name] = added_arg; + pointer_param_names_.insert(name); } else { - throw std::runtime_error(FormatError("irgen", "不支持的参数类型")); - } - - if (!slot) { - throw std::runtime_error(FormatError("irgen", "创建参数存储槽位失败: " + name)); + ir::AllocaInst* slot = nullptr; + if (param_ty->IsInt32()) { + slot = CreateEntryAllocaI32(module_.GetContext().NextTemp()); + } else if (param_ty->IsFloat()) { + slot = CreateEntryAllocaFloat(module_.GetContext().NextTemp()); + } else { + throw std::runtime_error(FormatError("irgen", "不支持的参数类型")); + } + + if (!slot) { + throw std::runtime_error(FormatError("irgen", "创建参数存储槽位失败: " + name)); + } + + builder_.CreateStore(added_arg, slot); + param_map_[name] = slot; + pointer_param_names_.erase(name); } - builder_.CreateStore(added_arg, slot); - param_map_[name] = slot; - std::cerr << "[DEBUG] visitFuncDef: 参数 " << name << " 处理完成" << std::endl; } } @@ -296,11 +379,37 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { std::cerr << "[DEBUG] visitFuncDef: 开始生成函数体" << std::endl; ctx->block()->accept(this); - // 如果函数没有终止指令,添加默认返回 - if (!func_->GetEntry()->HasTerminator()) { + // 如果当前插入块没有终止指令,添加默认返回 + if (auto* cur = builder_.GetInsertBlock(); cur && !cur->HasTerminator()) { std::cerr << "[DEBUG] visitFuncDef: 函数体没有终止指令,添加默认返回" << std::endl; - auto retVal = builder_.CreateConstInt(0); - builder_.CreateRet(retVal); + if (function_cleanup_block_) { + if (ret_type->IsFloat()) { + builder_.CreateStore(builder_.CreateConstFloat(0.0f), function_return_slot_); + } else if (ret_type->IsInt32()) { + builder_.CreateStore(builder_.CreateConstInt(0), function_return_slot_); + } + builder_.CreateBr(function_cleanup_block_); + } else if (ret_type->IsVoid()) { + builder_.CreateRet(nullptr); + } else if (ret_type->IsFloat()) { + builder_.CreateRet(builder_.CreateConstFloat(0.0f)); + } else { + builder_.CreateRet(builder_.CreateConstInt(0)); + } + } + + if (function_cleanup_block_ && !function_cleanup_block_->HasTerminator()) { + builder_.SetInsertPoint(function_cleanup_block_); + for (auto it = function_cleanup_actions_.rbegin(); + it != function_cleanup_actions_.rend(); ++it) { + builder_.CreateCall(it->first, {it->second}, module_.GetContext().NextTemp()); + } + + if (ret_type->IsVoid()) { + builder_.CreateRet(nullptr); + } else { + builder_.CreateRet(builder_.CreateLoad(function_return_slot_, module_.GetContext().NextTemp())); + } } // 验证函数结构 @@ -313,12 +422,52 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { std::cerr << "[DEBUG] visitFuncDef: 函数 " << funcName << " 生成完成" << std::endl; func_ = nullptr; + current_function_name_.clear(); + current_function_is_recursive_ = false; + function_return_slot_ = nullptr; + function_cleanup_block_ = nullptr; + function_cleanup_actions_.clear(); return {}; } +ir::BasicBlock* IRGenImpl::EnsureCleanupBlock() { + if (!function_cleanup_block_) { + std::string name = module_.GetContext().NextTemp(); + if (!name.empty() && name[0] == '%') { + name.erase(0, 1); + } + function_cleanup_block_ = func_->CreateBlock("cleanup." + name); + } + return function_cleanup_block_; +} + +void IRGenImpl::RegisterCleanup(ir::Function* free_func, ir::Value* ptr) { + if (!free_func || !ptr) { + return; + } + EnsureCleanupBlock(); + function_cleanup_actions_.push_back({free_func, ptr}); +} + +ir::AllocaInst* IRGenImpl::CreateEntryAlloca(std::shared_ptr ty, + const std::string& name) { + if (!func_ || !func_->GetEntry()) { + throw std::runtime_error(FormatError("irgen", "缺少函数入口块,无法创建入口栈槽位")); + } + return func_->GetEntry()->InsertBeforeTerminator(ty, name); +} + +ir::AllocaInst* IRGenImpl::CreateEntryAllocaI32(const std::string& name) { + return CreateEntryAlloca(ir::Type::GetPtrInt32Type(), name); +} + +ir::AllocaInst* IRGenImpl::CreateEntryAllocaFloat(const std::string& name) { + return CreateEntryAlloca(ir::Type::GetPtrFloatType(), name); +} + std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { - std::cout << "[DEBUG IRGEN] visitBlock: " << (ctx ? ctx->getText() : "") << std::endl; + std::cerr << "[DEBUG IRGEN] visitBlock: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句块")); } @@ -333,7 +482,7 @@ std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { } auto* cur = builder_.GetInsertBlock(); - std::cout << "[DEBUG] current insert block: " + std::cerr << "[DEBUG] current insert block: " << (cur ? cur->GetName() : "") << std::endl; if (cur && cur->HasTerminator()) { break; @@ -351,7 +500,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( } // 用于遍历块内项,返回是否继续访问后续项(如遇到 return/break/continue 则终止访问) std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { - std::cout << "[DEBUG IRGEN] visitBlockItem: " << (ctx ? ctx->getText() : "") << std::endl; + std::cerr << "[DEBUG IRGEN] visitBlockItem: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少块内项")); } diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 18c74d7..6b4661d 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -16,7 +16,7 @@ // - 空语句、块语句嵌套分发之外的更多语句形态 std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { - std::cout << "[DEBUG IRGEN] visitStmt: " << (ctx ? ctx->getText() : "") << std::endl; + std::cerr << "[DEBUG IRGEN] visitStmt: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句")); } @@ -65,7 +65,7 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { // 修改 HandleReturnStmt 函数 IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) { - std::cout << "[DEBUG IRGEN] HandleReturnStmt: " << (ctx ? ctx->getText() : "") << std::endl; + std::cerr << "[DEBUG IRGEN] HandleReturnStmt: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); } @@ -88,8 +88,12 @@ IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) { // 表达式被忽略(可计算但不使用) EvalExpr(*ctx->exp()); } - // 对于void函数,创建返回指令(不传参数) - builder_.CreateRet(nullptr); + if (function_cleanup_block_) { + builder_.CreateBr(function_cleanup_block_); + } else { + // 对于void函数,创建返回指令(不传参数) + builder_.CreateRet(nullptr); + } } else { ir::Value* retValue = nullptr; if (ctx->exp()) { @@ -115,7 +119,12 @@ IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) { retValue = builder_.CreateConstInt(0); // fallback } } - builder_.CreateRet(retValue); + if (function_cleanup_block_) { + builder_.CreateStore(retValue, function_return_slot_); + builder_.CreateBr(function_cleanup_block_); + } else { + builder_.CreateRet(retValue); + } } return BlockFlow::Terminated; } @@ -123,49 +132,63 @@ IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) { // if语句(待实现) IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) { - std::cout << "[DEBUG IRGEN] HandleIfStmt: " << (ctx ? ctx->getText() : "") << std::endl; + std::cerr << "[DEBUG IRGEN] HandleIfStmt: " << (ctx ? ctx->getText() : "") << std::endl; auto* cond = ctx->cond(); auto* thenStmt = ctx->stmt(0); auto* elseStmt = ctx->stmt(1); - // 创建基本块 - auto* thenBlock = func_->CreateBlock("then"); - auto* elseBlock = (ctx->Else() && elseStmt) ? func_->CreateBlock("else") : nullptr; - auto* mergeBlock = func_->CreateBlock("merge"); - - std::cout << "[DEBUG IF] thenBlock: " << thenBlock->GetName() << std::endl; - if (elseBlock) std::cout << "[DEBUG IF] elseBlock: " << elseBlock->GetName() << std::endl; - std::cout << "[DEBUG IF] mergeBlock: " << mergeBlock->GetName() << std::endl; - std::cout << "[DEBUG IF] current insert block before cond: " + // 创建基本块(使用唯一名称,避免同名标签) + auto uniq = [&](const std::string& prefix) { + std::string t = module_.GetContext().NextTemp(); + if (!t.empty() && t[0] == '%') t.erase(0, 1); + return prefix + "." + t; + }; + auto* thenBlock = func_->CreateBlock(uniq("then")); + auto* elseBlock = (ctx->Else() && elseStmt) ? func_->CreateBlock(uniq("else")) : nullptr; + auto* mergeBlock = func_->CreateBlock(uniq("merge")); + + std::cerr << "[DEBUG IF] thenBlock: " << thenBlock->GetName() << std::endl; + if (elseBlock) std::cerr << "[DEBUG IF] elseBlock: " << elseBlock->GetName() << std::endl; + std::cerr << "[DEBUG IF] mergeBlock: " << mergeBlock->GetName() << std::endl; + std::cerr << "[DEBUG IF] current insert block before cond: " << builder_.GetInsertBlock()->GetName() << std::endl; // 生成条件 auto* condValue = EvalCond(*cond); + if (!condValue->GetType()->IsInt1()) { + if (condValue->GetType()->IsFloat()) { + condValue = builder_.CreateFCmpONE( + condValue, builder_.CreateConstFloat(0.0f), module_.GetContext().NextTemp()); + } else { + condValue = builder_.CreateICmpNE( + condValue, builder_.CreateConstInt(0), module_.GetContext().NextTemp()); + } + } // 创建条件跳转 if (elseBlock) { - std::cout << "[DEBUG IF] Creating condbr: " << condValue->GetName() + std::cerr << "[DEBUG IF] Creating condbr: " << condValue->GetName() << " -> " << thenBlock->GetName() << ", " << elseBlock->GetName() << std::endl; builder_.CreateCondBr(condValue, thenBlock, elseBlock); } else { - std::cout << "[DEBUG IF] Creating condbr: " << condValue->GetName() + std::cerr << "[DEBUG IF] Creating condbr: " << condValue->GetName() << " -> " << thenBlock->GetName() << ", " << mergeBlock->GetName() << std::endl; builder_.CreateCondBr(condValue, thenBlock, mergeBlock); } // 生成 then 分支 - std::cout << "[DEBUG IF] Generating then branch in block: " << thenBlock->GetName() << std::endl; + std::cerr << "[DEBUG IF] Generating then branch in block: " << thenBlock->GetName() << std::endl; builder_.SetInsertPoint(thenBlock); auto thenResult = thenStmt->accept(this); bool thenTerminated = (std::any_cast(thenResult) == BlockFlow::Terminated); - std::cout << "[DEBUG IF] then branch terminated: " << thenTerminated << std::endl; + std::cerr << "[DEBUG IF] then branch terminated: " << thenTerminated << std::endl; if (!thenTerminated) { - std::cout << "[DEBUG IF] Adding br to merge block from then" << std::endl; + std::cerr << "[DEBUG IF] Adding br to merge block from then" << std::endl; builder_.CreateBr(mergeBlock); } - std::cout << "[DEBUG IF] then block has terminator: " << thenBlock->HasTerminator() << std::endl; + std::cerr << "[DEBUG IF] then block has terminator: " << thenBlock->HasTerminator() << std::endl; // 生成 else 分支 bool elseTerminated = false; @@ -188,16 +211,9 @@ IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) { << ", elseTerminated=" << elseTerminated << std::endl; if (elseBlock) { - if (thenTerminated && elseTerminated) { - auto* afterIfBlock = func_->CreateBlock("after.if"); - std::cout << "[DEBUG IF] Both branches terminated, creating new block: " - << afterIfBlock->GetName() << std::endl; - builder_.SetInsertPoint(afterIfBlock); - } else { - std::cout << "[DEBUG IF] Setting insert point to merge block: " - << mergeBlock->GetName() << std::endl; - builder_.SetInsertPoint(mergeBlock); - } + std::cout << "[DEBUG IF] Setting insert point to merge block: " + << mergeBlock->GetName() << std::endl; + builder_.SetInsertPoint(mergeBlock); } else { std::cout << "[DEBUG IF] No else, setting insert point to merge block: " << mergeBlock->GetName() << std::endl; @@ -221,9 +237,14 @@ IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) { std::cout << "[DEBUG WHILE] Current insert block before while: " << builder_.GetInsertBlock()->GetName() << std::endl; - auto* condBlock = func_->CreateBlock("while.cond"); - auto* bodyBlock = func_->CreateBlock("while.body"); - auto* exitBlock = func_->CreateBlock("while.exit"); + auto uniq = [&](const std::string& prefix) { + std::string t = module_.GetContext().NextTemp(); + if (!t.empty() && t[0] == '%') t.erase(0, 1); + return prefix + "." + t; + }; + auto* condBlock = func_->CreateBlock(uniq("while.cond")); + auto* bodyBlock = func_->CreateBlock(uniq("while.body")); + auto* exitBlock = func_->CreateBlock(uniq("while.exit")); std::cout << "[DEBUG WHILE] condBlock: " << condBlock->GetName() << std::endl; std::cout << "[DEBUG WHILE] bodyBlock: " << bodyBlock->GetName() << std::endl; @@ -239,6 +260,15 @@ IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) { std::cout << "[DEBUG WHILE] Generating condition in block: " << condBlock->GetName() << std::endl; builder_.SetInsertPoint(condBlock); auto* condValue = EvalCond(*ctx->cond()); + if (!condValue->GetType()->IsInt1()) { + if (condValue->GetType()->IsFloat()) { + condValue = builder_.CreateFCmpONE( + condValue, builder_.CreateConstFloat(0.0f), module_.GetContext().NextTemp()); + } else { + condValue = builder_.CreateICmpNE( + condValue, builder_.CreateConstInt(0), module_.GetContext().NextTemp()); + } + } builder_.CreateCondBr(condValue, bodyBlock, exitBlock); std::cout << "[DEBUG WHILE] condBlock has terminator: " << condBlock->HasTerminator() << std::endl; @@ -387,17 +417,83 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) { auto exp_list = lval->exp(); if (!exp_list.empty()) { // 数组元素赋值 - std::vector indices; - indices.push_back(builder_.CreateConstInt(0)); - + std::vector idx_vals; for (auto* exp : exp_list) { ir::Value* index = EvalExpr(*exp); - indices.push_back(index); + idx_vals.push_back(index); + } + + ir::Value* elem_ptr = nullptr; + + // 扁平数组/数组参数(T*)的多维访问:先线性化,再单索引 GEP。 + if ((base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat()) && idx_vals.size() > 1) { + const Symbol* var_sym = symbol_table_.lookup(varName); + if (!var_sym && var_decl) { + var_sym = symbol_table_.lookupByVarDef(var_decl); + } + if (!var_sym) { + var_sym = symbol_table_.lookupAll(varName); + } + + std::vector dims; + if (var_sym) { + if (var_sym->is_array_param && !var_sym->array_dims.empty()) { + dims = var_sym->array_dims; + } else if (var_sym->type && var_sym->type->IsArray()) { + auto* at = dynamic_cast(var_sym->type.get()); + if (at) dims = at->GetDimensions(); + } + } + + if (dims.empty() && var_decl) { + for (auto* cexp : var_decl->constExp()) { + dims.push_back(symbol_table_.EvaluateConstExp(cexp)); + } + } + + ir::Value* flat = nullptr; + for (size_t i = 0; i < idx_vals.size(); ++i) { + ir::Value* term = idx_vals[i]; + if (!term) continue; + + int mult = 1; + if (!dims.empty() && i + 1 < dims.size()) { + for (size_t j = i + 1; j < dims.size(); ++j) { + if (dims[j] > 0) mult *= dims[j]; + } + } + + if (mult != 1) { + auto* mval = builder_.CreateConstInt(mult); + term = builder_.CreateMul(term, mval, module_.GetContext().NextTemp()); + } + + if (!flat) flat = term; + else flat = builder_.CreateAdd(flat, term, module_.GetContext().NextTemp()); + } + + if (!flat) flat = builder_.CreateConstInt(0); + + std::vector gep_indices = {flat}; + elem_ptr = builder_.CreateGEP(base_ptr, gep_indices, module_.GetContext().NextTemp()); + } else { + std::vector indices; + if (base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat()) { + for (auto* v : idx_vals) indices.push_back(v); + } else { + indices.push_back(builder_.CreateConstInt(0)); + for (auto* v : idx_vals) indices.push_back(v); + } + elem_ptr = builder_.CreateGEP(base_ptr, indices, module_.GetContext().NextTemp()); } - ir::Value* elem_ptr = builder_.CreateGEP( - base_ptr, indices, module_.GetContext().NextTemp()); - + if (elem_ptr->GetType()->IsPtrFloat() && rhs->GetType()->IsInt32()) { + rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (elem_ptr->GetType()->IsPtrInt32() && rhs->GetType()->IsFloat()) { + rhs = builder_.CreateFPToSI(rhs, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } builder_.CreateStore(rhs, elem_ptr); } else { // 普通标量赋值 @@ -417,7 +513,14 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) { std::cerr << "[DEBUG] Value is int32: " << rhs->GetType()->IsInt32() << std::endl; } - builder_.CreateStore(rhs, base_ptr); + if (base_ptr->GetType()->IsPtrFloat() && rhs->GetType()->IsInt32()) { + rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (base_ptr->GetType()->IsPtrInt32() && rhs->GetType()->IsFloat()) { + rhs = builder_.CreateFPToSI(rhs, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + builder_.CreateStore(rhs, base_ptr); } return BlockFlow::Continue; diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index cbedf1d..8d51715 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -35,7 +35,6 @@ public: if (!ctx) { throw std::runtime_error(FormatError("sema", "缺少编译单元")); } - table_.enterScope(); // 创建全局作用域 for (auto* func : ctx->funcDef()) { // 收集所有函数声明(处理互相调用) CollectFunctionDeclaration(func); } @@ -46,7 +45,6 @@ public: if (func) func->accept(this); } CheckMainFunction(); // 检查 main 函数存在且正确 - table_.exitScope(); // 退出全局作用域 return {}; } @@ -238,6 +236,157 @@ public: << std::endl; } + void CheckConstDef(SysYParser::ConstDefContext* ctx, + std::shared_ptr base_type) { + if (!ctx || !ctx->Ident()) { + throw std::runtime_error(FormatError("sema", "非法常量定义")); + } + std::string name = ctx->Ident()->getText(); + if (table_.lookupCurrent(name)) { + throw std::runtime_error(FormatError("sema", "重复定义常量: " + name)); + } + + // 确定类型 + std::shared_ptr type = base_type; + std::vector dims; + bool is_array = !ctx->constExp().empty(); + std::cout << "[DEBUG] CheckConstDef: " << name + << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown") + << " is_array: " << is_array + << " dim_count: " << ctx->constExp().size() << std::endl; + + if (is_array) { + for (auto* dim_exp : ctx->constExp()) { + int dim = table_.EvaluateConstExp(dim_exp); + if (dim <= 0) { + throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); + } + dims.push_back(dim); + std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl; + } + type = ir::Type::GetArrayType(base_type, dims); + std::cout << "[DEBUG] 创建数组类型完成,IsArray: " << type->IsArray() << std::endl; + } + + // ========== 绑定维度表达式 ========== + for (auto* dim_exp : ctx->constExp()) { + dim_exp->addExp()->accept(this); + } + + // 求值初始化器 + std::vector init_values; + if (ctx->constInitVal()) { + // ========== 绑定初始化表达式 ========== + BindConstInitVal(ctx->constInitVal()); + + init_values = table_.EvaluateConstInitVal(ctx->constInitVal(), dims, base_type); + std::cout << "[DEBUG] 初始化值数量: " << init_values.size() << std::endl; + } + + // 计算期望的元素数量 + size_t expected_count = 1; + if (is_array) { + expected_count = 1; + for (int d : dims) expected_count *= d; + std::cout << "[DEBUG] 期望元素数量: " << expected_count << std::endl; + } + + // 如果初始化值不足,补零 + if (is_array && init_values.size() < expected_count) { + std::cout << "[DEBUG] 初始化值不足,补零" << std::endl; + SymbolTable::ConstValue zero; + if (base_type->IsInt32()) { + zero.kind = SymbolTable::ConstValue::INT; + zero.int_val = 0; + } else { + zero.kind = SymbolTable::ConstValue::FLOAT; + zero.float_val = 0.0f; + } + init_values.resize(expected_count, zero); + } + + // 检查初始化值数量 + if (init_values.size() > expected_count) { + throw std::runtime_error(FormatError("sema", "初始化值过多")); + } + + // 创建符号 + Symbol sym; + sym.name = name; + sym.kind = SymbolKind::Constant; + std::cout << "CheckConstDef: before addSymbol, sym.kind = " << (int)sym.kind << std::endl; + sym.type = type; + sym.scope_level = table_.currentScopeLevel(); + sym.is_initialized = true; + sym.var_def_ctx = nullptr; + sym.const_def_ctx = ctx; + std::cout << "保存常量定义上下文: " << name << ", ctx: " << ctx << std::endl; + + // ========== 存储常量值 ========== + if (is_array) { + // 存储数组常量(扁平化存储) + sym.is_array_const = true; + sym.array_const_values.clear(); + + for (const auto& val : init_values) { + Symbol::ConstantValue cv; + if (val.kind == SymbolTable::ConstValue::INT) { + cv.i32 = val.int_val; + } else { + cv.f32 = val.float_val; + } + sym.array_const_values.push_back(cv); + } + + std::cout << "[DEBUG] 存储数组常量,共 " << sym.array_const_values.size() + << " 个元素" << std::endl; + + } else if (!init_values.empty()) { + // 存储标量常量 + if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::INT) { + sym.is_int_const = true; + sym.const_value.i32 = init_values[0].int_val; + std::cout << "[DEBUG] 存储整型常量: " << init_values[0].int_val << std::endl; + } else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) { + sym.is_int_const = false; + sym.const_value.f32 = init_values[0].float_val; + std::cout << "[DEBUG] 存储浮点常量: " << init_values[0].float_val << std::endl; + } else if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) { + // 整型常量用浮点数初始化(需要检查是否为整数) + float f = init_values[0].float_val; + int i = static_cast(f); + if (std::abs(f - i) > 1e-6) { + throw std::runtime_error(FormatError("sema", + "整型常量不能用非整数值的浮点数初始化: " + std::to_string(f))); + } + sym.is_int_const = true; + sym.const_value.i32 = i; + std::cout << "[DEBUG] 浮点转整型常量: " << f << " -> " << i << std::endl; + } else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::INT) { + // 浮点常量用整型初始化,隐式转换 + sym.is_int_const = false; + sym.const_value.f32 = static_cast(init_values[0].int_val); + std::cout << "[DEBUG] 整型转浮点常量: " << init_values[0].int_val + << " -> " << static_cast(init_values[0].int_val) << std::endl; + } + } else { + // 没有初始化值,对于标量常量这是错误的 + if (!is_array) { + throw std::runtime_error(FormatError("sema", "常量必须有初始化值: " + name)); + } + std::cout << "[DEBUG] 数组常量无初始化器,将全部补零" << std::endl; + } + + table_.addSymbol(sym); + std::cout << "CheckConstDef: after addSymbol, sym.kind = " << (int)sym.kind << std::endl; + auto* stored = table_.lookup(name); + std::cout << "CheckConstDef: after addSymbol, stored const_def_ctx = " << stored->const_def_ctx << std::endl; + + std::cout << "[DEBUG] 常量符号添加完成: " << name + << " is_array_const: " << sym.is_array_const + << " element_count: " << sym.array_const_values.size() << std::endl; +} + // ==================== 常量声明 ==================== std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override { if (!ctx || !ctx->bType()) { @@ -252,91 +401,6 @@ public: return {}; } - void CheckConstDef(SysYParser::ConstDefContext* ctx, - std::shared_ptr base_type) { - if (!ctx || !ctx->Ident()) { - throw std::runtime_error(FormatError("sema", "非法常量定义")); - } - std::string name = ctx->Ident()->getText(); - if (table_.lookupCurrent(name)) { - throw std::runtime_error(FormatError("sema", "重复定义常量: " + name)); - } - // 确定类型 - std::shared_ptr type = base_type; - std::vector dims; - bool is_array = !ctx->constExp().empty(); - std::cout << "[DEBUG] CheckConstDef: " << name - << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown") - << " is_array: " << is_array - << " dim_count: " << ctx->constExp().size() << std::endl; - if (is_array) { - for (auto* dim_exp : ctx->constExp()) { - int dim = table_.EvaluateConstExp(dim_exp); - if (dim <= 0) { - throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); - } - dims.push_back(dim); - std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl; - } - type = ir::Type::GetArrayType(base_type, dims); - std::cout << "[DEBUG] 创建数组类型完成,IsArray: " << type->IsArray() << std::endl; - } - - // ========== 绑定维度表达式 ========== - for (auto* dim_exp : ctx->constExp()) { - dim_exp->addExp()->accept(this); - } - - // 求值初始化器 - std::vector init_values; - if (ctx->constInitVal()) { - // ========== 绑定初始化表达式 ========== - BindConstInitVal(ctx->constInitVal()); - - init_values = table_.EvaluateConstInitVal(ctx->constInitVal(), dims, base_type); - std::cout << "[DEBUG] 初始化值数量: " << init_values.size() << std::endl; - } - // 检查初始化值数量 - size_t expected_count = 1; - if (is_array) { - expected_count = 1; - for (int d : dims) expected_count *= d; - std::cout << "[DEBUG] 期望元素数量: " << expected_count << std::endl; - } - if (init_values.size() > expected_count) { - throw std::runtime_error(FormatError("sema", "初始化值过多")); - } - Symbol sym; - sym.name = name; - sym.kind = SymbolKind::Constant; - std::cout << "CheckConstDef: before addSymbol, sym.kind = " << (int)sym.kind << std::endl; - sym.type = type; - sym.scope_level = table_.currentScopeLevel(); - sym.is_initialized = true; - sym.var_def_ctx = nullptr; - sym.const_def_ctx = ctx; - sym.const_def_ctx = ctx; - std::cout << "保存常量定义上下文: " << name << ", ctx: " << ctx << std::endl; - // 存储常量值(仅对非数组有效) - if (!is_array && !init_values.empty()) { - if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::INT) { - sym.is_int_const = true; - sym.const_value.i32 = init_values[0].int_val; - } else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) { - sym.is_int_const = false; - sym.const_value.f32 = init_values[0].float_val; - } - } else if (is_array) { - std::cout << "[DEBUG] 数组常量,不存储单个常量值" << std::endl; - } - table_.addSymbol(sym); - std::cout << "CheckConstDef: after addSymbol, sym.kind = " << (int)sym.kind << std::endl; - auto* stored = table_.lookup(name); - std::cout << "CheckConstDef: after addSymbol, stored const_def_ctx = " << stored->const_def_ctx << std::endl; - - std::cout << "[DEBUG] 常量符号添加完成" << std::endl; - } - // ==================== 语句语义检查 ==================== // 处理所有语句 - 通过运行时类型判断 @@ -1004,9 +1068,27 @@ public: sema_.SetExprType(ctx, result); return {}; } - - // 获取语义上下文 + // 新增:获取符号表 + SymbolTable TakeSymbolTable() { return std::move(table_); } SemanticContext TakeSemanticContext() { return std::move(sema_); } + + // 新增:同时返回两者 + SemaResult TakeResult() { + std::cerr << "[DEBUG] TakeResult 前: 符号表作用域数量 = " + << table_.getScopeCount() << std::endl; + + // 可选:打印符号表内容 + // table_.dump(); + + SemaResult result; + result.context = std::move(sema_); + result.symbol_table = std::move(table_); + + std::cerr << "[DEBUG] TakeResult 后: 符号表作用域数量 = " + << result.symbol_table.getScopeCount() << std::endl; + return result; + } + private: SymbolTable table_; @@ -1020,7 +1102,6 @@ private: bool current_func_has_return_ = false; // ==================== 辅助函数 ==================== - ExprInfo CheckExp(SysYParser::ExpContext* ctx) { if (!ctx || !ctx->addExp()) { throw std::runtime_error(FormatError("sema", "无效表达式")); @@ -1497,9 +1578,10 @@ private: } // namespace -SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { - SemaVisitor visitor; - comp_unit.accept(&visitor); - SemanticContext ctx = visitor.TakeSemanticContext(); - return ctx; +// 修改 RunSema 函数,使其返回 SemaResult 结构体,包含符号表和语义上下文 +SemaResult RunSema(SysYParser::CompUnitContext& comp_unit) { + SemaVisitor visitor; + comp_unit.accept(&visitor); + // 直接返回 TakeResult(),利用移动语义 + return visitor.TakeResult(); } diff --git a/src/sem/SymbolTable.cpp b/src/sem/SymbolTable.cpp index 2de9fa4..4253825 100644 --- a/src/sem/SymbolTable.cpp +++ b/src/sem/SymbolTable.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #define DEBUG_SYMBOL_TABLE @@ -17,28 +18,33 @@ // ---------- 构造函数 ---------- SymbolTable::SymbolTable() { scopes_.emplace_back(); // 初始化全局作用域 + active_scope_stack_.push_back(0); registerBuiltinFunctions(); // 注册内置库函数 } // ---------- 作用域管理 ---------- void SymbolTable::enterScope() { scopes_.emplace_back(); + active_scope_stack_.push_back(scopes_.size() - 1); } void SymbolTable::exitScope() { - if (scopes_.size() > 1) { - scopes_.pop_back(); + if (active_scope_stack_.size() > 1) { + active_scope_stack_.pop_back(); } // 不能退出全局作用域 } // ---------- 符号添加与查找 ---------- bool SymbolTable::addSymbol(const Symbol& sym) { - auto& current_scope = scopes_.back(); + auto& current_scope = scopes_[active_scope_stack_.back()]; if (current_scope.find(sym.name) != current_scope.end()) { return false; // 重复定义 } - current_scope[sym.name] = sym; + + Symbol stored_sym = sym; + stored_sym.scope_level = currentScopeLevel(); + current_scope[sym.name] = stored_sym; // 立即验证存储的符号 const auto& stored = current_scope[sym.name]; @@ -59,16 +65,15 @@ Symbol* SymbolTable::lookupCurrent(const std::string& name) { } const Symbol* SymbolTable::lookup(const std::string& name) const { - for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { - const auto& scope = *it; + for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) { + const auto& scope = scopes_[*it]; auto found = scope.find(name); if (found != scope.end()) { std::cout << "SymbolTable::lookup: found " << name - << " in scope level " << (scopes_.rend() - it - 1) + << " in active scope index " << *it << ", kind=" << (int)found->second.kind << ", const_def_ctx=" << found->second.const_def_ctx << std::endl; - return &found->second; } } @@ -76,7 +81,7 @@ const Symbol* SymbolTable::lookup(const std::string& name) const { } const Symbol* SymbolTable::lookupCurrent(const std::string& name) const { - const auto& current_scope = scopes_.back(); + const auto& current_scope = scopes_[active_scope_stack_.back()]; auto it = current_scope.find(name); if (it != current_scope.end()) { return &it->second; @@ -84,6 +89,40 @@ const Symbol* SymbolTable::lookupCurrent(const std::string& name) const { return nullptr; } +const Symbol* SymbolTable::lookupAll(const std::string& name) const { + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + auto found = it->find(name); + if (found != it->end()) { + return &found->second; + } + } + return nullptr; +} + +const Symbol* SymbolTable::lookupByVarDef(const SysYParser::VarDefContext* decl) const { + if (!decl) return nullptr; + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + for (const auto& [name, sym] : *it) { + if (sym.var_def_ctx == decl) { + return &sym; + } + } + } + return nullptr; +} + +const Symbol* SymbolTable::lookupByConstDef(const SysYParser::ConstDefContext* decl) const { + if (!decl) return nullptr; + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + for (const auto& [name, sym] : *it) { + if (sym.const_def_ctx == decl) { + return &sym; + } + } + } + return nullptr; +} + // ---------- 兼容原接口 ---------- void SymbolTable::Add(const std::string& name, SysYParser::VarDefContext* decl) { Symbol sym; @@ -96,9 +135,9 @@ void SymbolTable::Add(const std::string& name, SysYParser::VarDefContext* decl) } bool SymbolTable::Contains(const std::string& name) const { - // const 方法不能修改 scopes_,我们模拟查找 - for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { - if (it->find(name) != it->end()) { + for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) { + const auto& scope = scopes_[*it]; + if (scope.find(name) != scope.end()) { return true; } } @@ -106,9 +145,10 @@ bool SymbolTable::Contains(const std::string& name) const { } SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { - for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { - auto found = it->find(name); - if (found != it->end()) { + for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) { + const auto& scope = scopes_[*it]; + auto found = scope.find(name); + if (found != scope.end()) { // 只返回变量定义的上下文(函数等其他符号返回 nullptr) if (found->second.kind == SymbolKind::Variable) { return found->second.var_def_ctx; @@ -638,7 +678,7 @@ std::vector SymbolTable::EvaluateConstInitVal( // 隐式类型转换 if (base_type->IsInt32() && val.kind == ConstValue::FLOAT) { val.kind = ConstValue::INT; - val.float_val = static_cast(val.int_val); + val.int_val = static_cast(val.float_val); } if (base_type->IsFloat() && val.kind == ConstValue::INT) { val.kind = ConstValue::FLOAT; @@ -648,32 +688,88 @@ std::vector SymbolTable::EvaluateConstInitVal( } // ========== 2. 数组常量(dims 非空)========== - // 计算数组总元素个数 size_t total = 1; for (int d : dims) total *= d; - - // 展平初始化列表(递归处理花括号) - std::vector flat; - flattenInit(ctx, flat, base_type); - - // 检查数量是否超过数组容量 - if (flat.size() > total) { - throw std::runtime_error("常量初始化:提供的初始值数量超过数组元素总数"); + + ConstValue zero; + if (base_type->IsInt32()) { + zero.kind = ConstValue::INT; + zero.int_val = 0; + } else { + zero.kind = ConstValue::FLOAT; + zero.float_val = 0.0f; } - - // 不足的部分补零 - if (flat.size() < total) { - ConstValue zero; + + // 先整体补零,再按 C 语言花括号规则覆盖显式初始化项。 + std::vector flat(total, zero); + + auto convert_value = [&](ConstValue v) -> ConstValue { if (base_type->IsInt32()) { - zero.kind = ConstValue::INT; - zero.int_val = 0; - } else { - zero.kind = ConstValue::FLOAT; - zero.float_val = 0.0f; + if (v.kind == ConstValue::FLOAT) { + throw std::runtime_error("常量初始化:整型数组不能使用浮点常量"); + } + return v; } - flat.resize(total, zero); - } - + if (v.kind == ConstValue::INT) { + ConstValue t; + t.kind = ConstValue::FLOAT; + t.float_val = static_cast(v.int_val); + return t; + } + return v; + }; + + auto subarray_span = [&](size_t depth) -> size_t { + size_t span = 1; + for (size_t i = depth + 1; i < dims.size(); ++i) span *= static_cast(dims[i]); + return span; + }; + + std::function fill; + fill = [&](SysYParser::ConstInitValContext* node, + size_t depth, + size_t begin, + size_t end) -> size_t { + if (!node || begin >= end) return begin; + + // 标量初始化项 + if (node->constExp()) { + ConstValue v = convert_value(EvaluateAddExp(node->constExp()->addExp())); + if (begin < flat.size()) flat[begin] = v; + return std::min(begin + 1, end); + } + + size_t cursor = begin; + for (auto* child : node->constInitVal()) { + if (cursor >= end) break; + + if (child->constExp()) { + ConstValue v = convert_value(EvaluateAddExp(child->constExp()->addExp())); + if (cursor < flat.size()) flat[cursor] = v; + ++cursor; + continue; + } + + // 花括号子列表:在非最内层需要按子聚合边界对齐。 + if (depth + 1 < dims.size()) { + const size_t span = subarray_span(depth); + const size_t rel = (cursor - begin) % span; + if (rel != 0) cursor += (span - rel); + if (cursor >= end) break; + + const size_t sub_end = std::min(cursor + span, end); + fill(child, depth + 1, cursor, sub_end); + // 一个带花括号的子初始化器会消费当前层的一个子聚合。 + cursor = sub_end; + } else { + // 最内层(标量数组)遇到额外花括号,按同层顺序展开。 + cursor = fill(child, depth, cursor, end); + } + } + return cursor; + }; + + fill(ctx, 0, 0, total); return flat; } diff --git a/sylib/sylib.c b/sylib/sylib.c index 7f26d0b..7237ef1 100644 --- a/sylib/sylib.c +++ b/sylib/sylib.c @@ -1,4 +1,162 @@ -// SysY 运行库实现: -// - 按实验/评测规范提供 I/O 等函数实现 -// - 与编译器生成的目标代码链接,支撑运行时行为 +#include "sylib.h" + +#include +#include + +extern int scanf(const char* format, ...); +extern int printf(const char* format, ...); +extern int getchar(void); +extern int putchar(int c); + +int getint(void) { + int x = 0; + scanf("%d", &x); + return x; +} + +int getch(void) { + return getchar(); +} + +int getarray(int a[]) { + int n; + scanf("%d", &n); + int i = 0; + for (; i < n; ++i) { + scanf("%d", &a[i]); + } + return n; +} + +float getfloat(void) { + float x = 0.0f; + scanf("%f", &x); + return x; +} + +int getfarray(float a[]) { + int n = 0; + if (scanf("%d", &n) != 1) { + return 0; + } + int i = 0; + for (; i < n; ++i) { + if (scanf("%f", &a[i]) != 1) { + return i; + } + } + return n; +} + +void putint(int x) { + printf("%d", x); +} + +void putch(int x) { + putchar(x); +} + +void putarray(int n, int a[]) { + int i = 0; + printf("%d:", n); + for (; i < n; ++i) { + printf(" %d", a[i]); + } + putchar('\n'); +} + +void putfloat(float x) { + printf("%a", x); +} + +void putfarray(int n, float a[]) { + int i = 0; + printf("%d:", n); + for (; i < n; ++i) { + printf(" %a", a[i]); + } + putchar('\n'); +} + +void puts(int s[]) { + if (!s) return; + while (*s) { + putchar(*s); + ++s; + } +} + +void _sysy_starttime(int lineno) { + (void)lineno; +} + +void _sysy_stoptime(int lineno) { + (void)lineno; +} + +void starttime(void) { + _sysy_starttime(0); +} + +void stoptime(void) { + _sysy_stoptime(0); +} + +int* memset(int* ptr, int value, int count) { + unsigned char* p = (unsigned char*)ptr; + unsigned char byte = (unsigned char)(value & 0xFF); + int i = 0; + for (; i < count; ++i) { + p[i] = byte; + } + return ptr; +} + +int* sysy_alloc_i32(int count) { + if (count <= 0) { + return 0; + } + return (int*)malloc((size_t)count * sizeof(int)); +} + +float* sysy_alloc_f32(int count) { + if (count <= 0) { + return 0; + } + return (float*)malloc((size_t)count * sizeof(float)); +} + +void sysy_free_i32(int* ptr) { + if (!ptr) { + return; + } + free(ptr); +} + +void sysy_free_f32(float* ptr) { + if (!ptr) { + return; + } + free(ptr); +} + +void sysy_zero_i32(int* ptr, int count) { + int i = 0; + if (!ptr || count <= 0) { + return; + } + for (; i < count; ++i) { + ptr[i] = 0; + } +} + +void sysy_zero_f32(float* ptr, int count) { + int i = 0; + if (!ptr || count <= 0) { + return; + } + for (; i < count; ++i) { + ptr[i] = 0.0f; + } +} diff --git a/sylib/sylib.h b/sylib/sylib.h index 502d488..0d81b83 100644 --- a/sylib/sylib.h +++ b/sylib/sylib.h @@ -1,4 +1,29 @@ -// SysY 运行库头文件: -// - 声明运行库函数原型(供编译器生成 call 或链接阶段引用) -// - 与 sylib.c 配套,按规范逐步补齐声明 +#pragma once + +int getint(void); +int getch(void); +int getarray(int a[]); +float getfloat(void); +int getfarray(float a[]); + +void putint(int x); +void putch(int x); +void putarray(int n, int a[]); +void putfloat(float x); +void putfarray(int n, float a[]); +void puts(int s[]); + +void _sysy_starttime(int lineno); +void _sysy_stoptime(int lineno); +void starttime(void); +void stoptime(void); + +int read_map(void); +int* memset(int* ptr, int value, int count); +int* sysy_alloc_i32(int count); +float* sysy_alloc_f32(int count); +void sysy_free_i32(int* ptr); +void sysy_free_f32(float* ptr); +void sysy_zero_i32(int* ptr, int count); +void sysy_zero_f32(float* ptr, int count);