diff --git a/TEST_RESULTS.md b/TEST_RESULTS.md index 00b3980..054bf34 100644 --- a/TEST_RESULTS.md +++ b/TEST_RESULTS.md @@ -10,10 +10,10 @@ 5. 15_graph_coloring - 图着色算法 (使用2D数组和指针参数) 6. 22_matrix_multiply - 矩阵乘法 (2D数组) 7. 25_scope3 - 作用域测试 + 8. 29_break - break语句 9. 36_op_priority2 - 运算符优先级 10. simple_add - 简单加法 - ### ✗ 失败的测试 (1个): - 95_float - **需要浮点数常量支持** (当前仅支持int) diff --git a/doc/Lab3-指令选择与汇编生成.md b/doc/Lab3-指令选择与汇编生成.md index fa66dcb..6baa6f5 100644 --- a/doc/Lab3-指令选择与汇编生成.md +++ b/doc/Lab3-指令选择与汇编生成.md @@ -58,3 +58,4 @@ cmake --build build -j "$(nproc)" 若最终输出 `输出匹配: test/test_case/simple_add.out`,说明当前示例用例 `return a + b` 的完整后端链路已经跑通。 但最终不能只检查 `simple_add`。完成 Lab3 后,应对 `test/test_case` 下全部测试用例逐个回归,确认代码生成结果能够通过统一验证;如有需要,也可以自行编写批量测试脚本统一执行。 + diff --git a/include/ir/IR.h b/include/ir/IR.h index fca3bb1..6ff2fa7 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -188,6 +188,7 @@ enum class Opcode { Div, Mod, Cmp, + Cast, Br, CondBr, Call, @@ -199,6 +200,7 @@ enum class Opcode { }; enum class CmpOp { Eq, Ne, Lt, Le, Gt, Ge }; +enum class CastOp { IntToFloat, FloatToInt }; // User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 // 当前实现中只有 Instruction 继承自 User。 @@ -229,14 +231,19 @@ class GlobalValue : public User { // 数组:打印为 @name = global [count x i32] zeroinitializer。 class GlobalVariable : public GlobalValue { public: - GlobalVariable(std::string name, int init_val = 0, int count = 1); + GlobalVariable(std::string name, std::shared_ptr ptr_ty, + int init_val = 0, int count = 1, + std::vector init_elems = {}); int GetInitValue() const { return init_val_; } int GetCount() const { return count_; } bool IsArray() const { return count_ > 1; } + bool IsFloat() const { return GetType() && GetType()->IsPtrFloat32(); } + const std::vector& GetInitElements() const { return init_elems_; } private: int init_val_; int count_; + std::vector init_elems_; }; class Instruction : public User { @@ -272,6 +279,16 @@ class CmpInst : public Instruction { CmpOp cmp_op_; }; +class CastInst : public Instruction { + public: + CastInst(CastOp op, std::shared_ptr ty, Value* val, std::string name); + CastOp GetCastOp() const; + Value* GetValue() const; + + private: + CastOp cast_op_; +}; + class ReturnInst : public Instruction { public: ReturnInst(std::shared_ptr void_ty, Value* val); @@ -410,7 +427,10 @@ class Module { Function* FindFunction(const std::string& name) const; const std::vector>& GetFunctions() const; - GlobalVariable* CreateGlobalVar(const std::string& name, int init_val = 0, int count = 1); + GlobalVariable* CreateGlobalVar(const std::string& name, int init_val = 0, + int count = 1, + std::shared_ptr ptr_ty = Type::GetPtrInt32Type(), + std::vector init_elems = {}); GlobalVariable* FindGlobalVar(const std::string& name) const; const std::vector>& GetGlobalVars() const; @@ -436,6 +456,8 @@ class IRBuilder { BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateMod(Value* lhs, Value* rhs, const std::string& name); CmpInst* CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name); + CastInst* CreateSIToFP(Value* v, const std::string& name); + CastInst* CreateFPToSI(Value* v, const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name); AllocaInst* CreateAllocaArray(int count, const std::string& name); AllocaInst* CreateAllocaF32(const std::string& name); diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index eae6c3c..4e13fab 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -25,6 +25,8 @@ class IRGenImpl final : public SysYBaseVisitor { public: // const 变量名 -> 编译期整数值,用于数组维度折叠。 using ConstEnv = std::unordered_map; + // const 变量名 -> 编译期浮点值,用于 float const 折叠。 + using ConstFloatEnv = std::unordered_map; IRGenImpl(ir::Module& module, const SemanticContext& sema); @@ -81,8 +83,12 @@ class IRGenImpl final : public SysYBaseVisitor { // 编译期常量整数求值(用于数组维度)。 int EvalConstExpr(SysYParser::ConstExpContext* ctx) const; + // 编译期常量浮点求值(用于 float const)。 + float EvalConstExprAsFloat(SysYParser::ConstExpContext* ctx) const; // 将 ExpContext(即 addExp)按编译期常量求值(用于 funcFParam 维度)。 int EvalExpAsConst(SysYParser::ExpContext* ctx) const; + // 将 ExpContext 按编译期常量浮点求值(用于 float 全局初始化等)。 + float EvalExpAsConstFloat(SysYParser::ExpContext* ctx) const; // 查找变量的数组维度(先查局部,再查全局)。 const std::vector* FindArrayDims(const std::string& name) const; @@ -91,15 +97,28 @@ class IRGenImpl final : public SysYBaseVisitor { ir::Value* ComputeLinearIndex(const std::vector& dims, const std::vector& subs); + // 简单隐式类型转换:i32 <-> float。 + ir::Value* CastToFloat(ir::Value* v); + ir::Value* CastToInt(ir::Value* v); + // 扁平化 constInitValue 到整数数组(供 const 数组初始化使用)。 void FlattenConstInit(SysYParser::ConstInitValueContext* ctx, const std::vector& dims, int dim_idx, std::vector& out, int& pos); + void FlattenConstInitFloat(SysYParser::ConstInitValueContext* ctx, + const std::vector& dims, int dim_idx, + std::vector& out, int& pos); // 扁平化 initValue 到 ir::Value* 数组(供普通数组初始化使用)。 void FlattenInit(SysYParser::InitValueContext* ctx, const std::vector& dims, int dim_idx, std::vector& out, int& pos); + void FlattenGlobalInitInt(SysYParser::InitValueContext* ctx, + const std::vector& dims, int dim_idx, + std::vector& out, int& pos); + void FlattenGlobalInitFloat(SysYParser::InitValueContext* ctx, + const std::vector& dims, int dim_idx, + std::vector& out, int& pos); ir::AllocaInst* CreateEntryAllocaI32(const std::string& name); ir::AllocaInst* CreateEntryAllocaArray(int count, const std::string& name); @@ -121,6 +140,8 @@ class IRGenImpl final : public SysYBaseVisitor { std::unordered_map global_storage_; // 编译期 const 整数环境(全局 + 当前函数)。 ConstEnv const_env_; + // 编译期 const 浮点环境(全局 + 当前函数)。 + ConstFloatEnv const_float_env_; // 数组维度信息:全局数组(跨函数持久)。 std::unordered_map> global_array_dims_; // 数组维度信息:局部数组/参数(每函数清空)。 diff --git a/include/mir/MIR.h b/include/mir/MIR.h index 408e7f8..6d8a7c8 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -57,6 +57,8 @@ enum class Opcode { FSubRR, // 浮点减法 FMulRR, // 浮点乘法 FDivRR, // 浮点除法 + SIToFP, // 有符号整型转浮点 + FPToSI, // 浮点转有符号整型 CmpRR, FCmpRR, // 浮点比较 Bl, @@ -162,14 +164,17 @@ class MachineModule { return functions_; } - void AddGlobalVar(std::string name, int init_val, int count); - const std::vector>& GetGlobalVars() const { + void AddGlobalVar(std::string name, int init_val, int count, bool is_float, + std::vector init_elems = {}); + const std::vector>>& + GetGlobalVars() const { return global_vars_; } private: std::vector> functions_; - std::vector> global_vars_; // (name, init, count) + std::vector>> + global_vars_; // (name, init, count, is_float, init_elements) }; std::unique_ptr LowerToMIR(const ir::Module& module); diff --git a/scripts/verify_asm.sh b/scripts/verify_asm.sh index 656b42b..e529839 100755 --- a/scripts/verify_asm.sh +++ b/scripts/verify_asm.sh @@ -83,7 +83,8 @@ if [[ "$run_exec" == true ]]; then } > "$actual_file" if [[ -f "$expected_file" ]]; then - if diff -u "$expected_file" "$actual_file"; then + if diff -u <(perl -0pe 's/\n\z//' "$expected_file") \ + <(perl -0pe 's/\n\z//' "$actual_file"); then echo "输出匹配: $expected_file" else echo "输出不匹配: $expected_file" >&2 diff --git a/src/ir/GlobalValue.cpp b/src/ir/GlobalValue.cpp index a492d26..b9d89c6 100644 --- a/src/ir/GlobalValue.cpp +++ b/src/ir/GlobalValue.cpp @@ -7,9 +7,12 @@ namespace ir { GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) : User(std::move(ty), std::move(name)) {} -GlobalVariable::GlobalVariable(std::string name, int init_val, int count) - : GlobalValue(Type::GetPtrInt32Type(), std::move(name)), +GlobalVariable::GlobalVariable(std::string name, std::shared_ptr ptr_ty, + int init_val, int count, + std::vector init_elems) + : GlobalValue(std::move(ptr_ty), std::move(name)), init_val_(init_val), - count_(count) {} + count_(count), + init_elems_(std::move(init_elems)) {} } // namespace ir diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 1c8d084..adb88e7 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -75,6 +75,28 @@ CmpInst* IRBuilder::CreateCmp(CmpOp op, Value* lhs, Value* rhs, name); } +CastInst* IRBuilder::CreateSIToFP(Value* v, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!v) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateSIToFP 缺少操作数")); + } + return insert_block_->Append(CastOp::IntToFloat, Type::GetFloat32Type(), + v, name); +} + +CastInst* IRBuilder::CreateFPToSI(Value* v, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!v) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateFPToSI 缺少操作数")); + } + return insert_block_->Append(CastOp::FloatToInt, Type::GetInt32Type(), + v, name); +} + AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); @@ -116,7 +138,11 @@ GepInst* IRBuilder::CreateGep(Value* base, Value* index, const std::string& name if (!base || !index) { throw std::runtime_error(FormatError("ir", "IRBuilder::CreateGep 缺少操作数")); } - return insert_block_->Append(Type::GetPtrInt32Type(), base, index, name); + std::shared_ptr ptr_ty = Type::GetPtrInt32Type(); + if (base->GetType() && base->GetType()->IsPtrFloat32()) { + ptr_ty = Type::GetPtrFloat32Type(); + } + return insert_block_->Append(ptr_ty, base, index, name); } LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 52f93e4..6d9256c 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -4,6 +4,8 @@ #include "ir/IR.h" +#include +#include #include #include #include @@ -42,6 +44,8 @@ static const char* OpcodeToString(Opcode op) { return "srem"; case Opcode::Cmp: return "icmp"; + case Opcode::Cast: + return "cast"; case Opcode::Br: return "br"; case Opcode::CondBr: @@ -100,11 +104,20 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { // 先打印全局变量 for (const auto& gv : module.GetGlobalVars()) { if (!gv) continue; + const char* elem_ty = gv->IsFloat() ? "float" : "i32"; if (gv->IsArray()) { os << "@" << gv->GetName() << " = global [" << gv->GetCount() - << " x i32] zeroinitializer\n"; + << " x " << elem_ty << "] zeroinitializer\n"; } else { - os << "@" << gv->GetName() << " = global i32 " << gv->GetInitValue() << "\n"; + if (gv->IsFloat()) { + std::int32_t bits = static_cast(gv->GetInitValue()); + float fval = 0.0f; + std::memcpy(&fval, &bits, sizeof(fval)); + os << "@" << gv->GetName() << " = global float " << fval << "\n"; + } else { + os << "@" << gv->GetName() << " = global i32 " << gv->GetInitValue() + << "\n"; + } } } if (!module.GetGlobalVars().empty()) os << "\n"; @@ -163,26 +176,41 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { << ValueToString(cmp->GetRhs()) << "\n"; break; } + case Opcode::Cast: { + auto* cast = static_cast(inst); + const char* cast_name = + (cast->GetCastOp() == CastOp::IntToFloat) ? "sitofp" : "fptosi"; + os << " " << cast->GetName() << " = " << cast_name << " " + << TypeToString(*cast->GetValue()->GetType()) << " " + << ValueToString(cast->GetValue()) << " to " + << TypeToString(*cast->GetType()) << "\n"; + break; + } case Opcode::Alloca: { auto* alloca = static_cast(inst); + const char* elem_ty = alloca->GetType()->IsPtrFloat32() ? "float" : "i32"; if (alloca->IsArray()) { - os << " " << alloca->GetName() << " = alloca i32, i32 " + os << " " << alloca->GetName() << " = alloca " << elem_ty << ", i32 " << alloca->GetCount() << "\n"; } else { - os << " " << alloca->GetName() << " = alloca i32\n"; + os << " " << alloca->GetName() << " = alloca " << elem_ty << "\n"; } break; } case Opcode::Load: { auto* load = static_cast(inst); - os << " " << load->GetName() << " = load i32, i32* " + os << " " << load->GetName() << " = load " + << TypeToString(*load->GetType()) << ", " + << TypeToString(*load->GetPtr()->GetType()) << " " << ValueToString(load->GetPtr()) << "\n"; break; } case Opcode::Store: { auto* store = static_cast(inst); - os << " store i32 " << ValueToString(store->GetValue()) - << ", i32* " << ValueToString(store->GetPtr()) << "\n"; + os << " store " << TypeToString(*store->GetValue()->GetType()) + << " " << ValueToString(store->GetValue()) + << ", " << TypeToString(*store->GetPtr()->GetType()) + << " " << ValueToString(store->GetPtr()) << "\n"; break; } case Opcode::Br: { @@ -219,18 +247,20 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { case Opcode::Gep: { auto* gep = static_cast(inst); auto* base = gep->GetBase(); + const char* elem_ty = base->GetType()->IsPtrFloat32() ? "float" : "i32"; // 全局数组用双下标 GEP,局部 alloca 用平坦 GEP。 if (auto* gv = dynamic_cast(base)) { if (gv->IsArray()) { os << " " << gep->GetName() - << " = getelementptr [" << gv->GetCount() << " x i32], [" - << gv->GetCount() << " x i32]* @" << gv->GetName() + << " = getelementptr [" << gv->GetCount() << " x " << elem_ty << "], [" + << gv->GetCount() << " x " << elem_ty << "]* @" << gv->GetName() << ", i32 0, i32 " << ValueToString(gep->GetIndex()) << "\n"; break; } } os << " " << gep->GetName() - << " = getelementptr i32, i32* " << ValueToString(base) + << " = getelementptr " << elem_ty << ", " + << TypeToString(*base->GetType()) << " " << ValueToString(base) << ", i32 " << ValueToString(gep->GetIndex()) << "\n"; break; } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index ff7eac3..bc7c45c 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -124,8 +124,13 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, type_->GetKind() != lhs->GetType()->GetKind()) { throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); } - if (!type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32")); + const bool is_i32 = type_->IsInt32(); + const bool is_f32 = type_->IsFloat32(); + if (!is_i32 && !is_f32) { + throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32/float")); + } + if (op == Opcode::Mod && !is_i32) { + throw std::runtime_error(FormatError("ir", "BinaryInst 的 mod 仅支持 i32")); } AddOperand(lhs); AddOperand(rhs); @@ -147,9 +152,11 @@ CmpInst::CmpInst(CmpOp op, std::shared_ptr ty, Value* lhs, Value* rhs, if (!type_->IsInt32()) { throw std::runtime_error(FormatError("ir", "CmpInst 结果类型必须为 i32")); } - if (!lhs->GetType()->IsInt32() || !rhs->GetType()->IsInt32()) { + const bool is_int_cmp = lhs->GetType()->IsInt32() && rhs->GetType()->IsInt32(); + const bool is_float_cmp = lhs->GetType()->IsFloat32() && rhs->GetType()->IsFloat32(); + if (!is_int_cmp && !is_float_cmp) { throw std::runtime_error(FormatError( - "ir", "CmpInst 当前只支持 i32 比较,实际为 " + + "ir", "CmpInst 当前只支持 i32/float 同类型比较,实际为 " + std::string(TypeKindToString(lhs->GetType()->GetKind())) + " 与 " + std::string(TypeKindToString(rhs->GetType()->GetKind())))); @@ -164,6 +171,28 @@ Value* CmpInst::GetLhs() const { return GetOperand(0); } Value* CmpInst::GetRhs() const { return GetOperand(1); } +CastInst::CastInst(CastOp op, std::shared_ptr ty, Value* val, + std::string name) + : Instruction(Opcode::Cast, std::move(ty), std::move(name)), cast_op_(op) { + if (!val || !val->GetType() || !type_) { + throw std::runtime_error(FormatError("ir", "CastInst 缺少类型信息或操作数")); + } + if (cast_op_ == CastOp::IntToFloat) { + if (!val->GetType()->IsInt32() || !type_->IsFloat32()) { + throw std::runtime_error(FormatError("ir", "IntToFloat 需要 i32 -> float")); + } + } else { + if (!val->GetType()->IsFloat32() || !type_->IsInt32()) { + throw std::runtime_error(FormatError("ir", "FloatToInt 需要 float -> i32")); + } + } + AddOperand(val); +} + +CastOp CastInst::GetCastOp() const { return cast_op_; } + +Value* CastInst::GetValue() const { return GetOperand(0); } + ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) : Instruction(Opcode::Ret, std::move(void_ty), "") { if (!type_ || !type_->IsVoid()) { @@ -327,8 +356,9 @@ GepInst::GepInst(std::shared_ptr ptr_ty, Value* base, Value* index, if (!base || !index) { throw std::runtime_error(FormatError("ir", "GepInst 缺少操作数")); } - if (!base->GetType() || !base->GetType()->IsPtrInt32()) { - throw std::runtime_error(FormatError("ir", "GepInst base 必须为 i32*")); + if (!base->GetType() || + (!base->GetType()->IsPtrInt32() && !base->GetType()->IsPtrFloat32())) { + throw std::runtime_error(FormatError("ir", "GepInst base 必须为 i32*/float*")); } if (!index->GetType() || !index->GetType()->IsInt32()) { throw std::runtime_error(FormatError("ir", "GepInst index 必须为 i32")); diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index e281a49..bac59a0 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -27,8 +27,12 @@ Function* Module::FindFunction(const std::string& name) const { return nullptr; } -GlobalVariable* Module::CreateGlobalVar(const std::string& name, int init_val, int count) { - global_vars_.push_back(std::make_unique(name, init_val, count)); +GlobalVariable* Module::CreateGlobalVar(const std::string& name, int init_val, + int count, std::shared_ptr ptr_ty, + std::vector init_elems) { + global_vars_.push_back( + std::make_unique(name, std::move(ptr_ty), init_val, count, + std::move(init_elems))); return global_vars_.back().get(); } diff --git a/src/irgen/IRGenConstEval.cpp b/src/irgen/IRGenConstEval.cpp index f50f6a5..eaa38c6 100644 --- a/src/irgen/IRGenConstEval.cpp +++ b/src/irgen/IRGenConstEval.cpp @@ -1,5 +1,7 @@ #include "irgen/IRGen.h" +#include +#include #include #include @@ -9,75 +11,103 @@ // 内部辅助:不依赖类成员,只需 ConstEnv。 namespace { -int EvalAddExp(SysYParser::AddExpContext* ctx, - const IRGenImpl::ConstEnv& env); -int EvalMulExp(SysYParser::MulExpContext* ctx, - const IRGenImpl::ConstEnv& env); -int EvalUnaryExp(SysYParser::UnaryExpContext* ctx, - const IRGenImpl::ConstEnv& env); +double EvalAddExp(SysYParser::AddExpContext* ctx, + const IRGenImpl::ConstEnv& int_env, + const IRGenImpl::ConstFloatEnv& float_env); +double EvalMulExp(SysYParser::MulExpContext* ctx, + const IRGenImpl::ConstEnv& int_env, + const IRGenImpl::ConstFloatEnv& float_env); +double EvalUnaryExp(SysYParser::UnaryExpContext* ctx, + const IRGenImpl::ConstEnv& int_env, + const IRGenImpl::ConstFloatEnv& float_env); -int EvalPrimary(SysYParser::PrimaryExpContext* ctx, - const IRGenImpl::ConstEnv& env) { +int ParseIntLiteral(const std::string& text) { + if (text.size() >= 2 && text[0] == '0' && + (text[1] == 'x' || text[1] == 'X')) { + return std::stoi(text, nullptr, 16); + } + if (text.size() > 1 && text[0] == '0') { + return std::stoi(text, nullptr, 8); + } + return std::stoi(text); +} + +double EvalPrimary(SysYParser::PrimaryExpContext* ctx, + const IRGenImpl::ConstEnv& int_env, + const IRGenImpl::ConstFloatEnv& float_env) { if (!ctx) throw std::runtime_error(FormatError("consteval", "空主表达式")); if (ctx->number()) { - if (!ctx->number()->ILITERAL()) - throw std::runtime_error( - FormatError("consteval", "constExp 不支持浮点字面量")); - return std::stoi(ctx->number()->getText()); + if (ctx->number()->ILITERAL()) { + return static_cast(ParseIntLiteral(ctx->number()->getText())); + } + if (ctx->number()->FLITERAL()) { + return static_cast(std::strtof(ctx->number()->getText().c_str(), nullptr)); + } + throw std::runtime_error(FormatError("consteval", "非法数字字面量")); } - if (ctx->exp()) return EvalAddExp(ctx->exp()->addExp(), env); + if (ctx->exp()) return EvalAddExp(ctx->exp()->addExp(), int_env, float_env); if (ctx->lValue()) { if (!ctx->lValue()->ID()) throw std::runtime_error(FormatError("consteval", "非法 lValue")); const std::string name = ctx->lValue()->ID()->getText(); - auto it = env.find(name); - if (it == env.end()) - throw std::runtime_error( - FormatError("consteval", "constExp 引用非 const 变量: " + name)); - return it->second; + auto it_int = int_env.find(name); + if (it_int != int_env.end()) return static_cast(it_int->second); + auto it_float = float_env.find(name); + if (it_float != float_env.end()) return static_cast(it_float->second); + throw std::runtime_error( + FormatError("consteval", "constExp 引用非 const 变量: " + name)); } throw std::runtime_error(FormatError("consteval", "不支持的主表达式形式")); } -int EvalUnaryExp(SysYParser::UnaryExpContext* ctx, - const IRGenImpl::ConstEnv& env) { +double EvalUnaryExp(SysYParser::UnaryExpContext* ctx, + const IRGenImpl::ConstEnv& int_env, + const IRGenImpl::ConstFloatEnv& float_env) { if (!ctx) throw std::runtime_error(FormatError("consteval", "空一元表达式")); - if (ctx->primaryExp()) return EvalPrimary(ctx->primaryExp(), env); + if (ctx->primaryExp()) return EvalPrimary(ctx->primaryExp(), int_env, float_env); if (ctx->unaryOp() && ctx->unaryExp()) { - int v = EvalUnaryExp(ctx->unaryExp(), env); + double v = EvalUnaryExp(ctx->unaryExp(), int_env, float_env); if (ctx->unaryOp()->SUB()) return -v; if (ctx->unaryOp()->ADD()) return v; - if (ctx->unaryOp()->NOT()) return (v == 0) ? 1 : 0; + if (ctx->unaryOp()->NOT()) return (v == 0.0) ? 1.0 : 0.0; } throw std::runtime_error( FormatError("consteval", "函数调用不能出现在 constExp 中")); } -int EvalMulExp(SysYParser::MulExpContext* ctx, - const IRGenImpl::ConstEnv& env) { +double EvalMulExp(SysYParser::MulExpContext* ctx, + const IRGenImpl::ConstEnv& int_env, + const IRGenImpl::ConstFloatEnv& float_env) { if (!ctx) throw std::runtime_error(FormatError("consteval", "空乘法表达式")); if (ctx->mulExp()) { - int lhs = EvalMulExp(ctx->mulExp(), env); - int rhs = EvalUnaryExp(ctx->unaryExp(), env); + double lhs = EvalMulExp(ctx->mulExp(), int_env, float_env); + double rhs = EvalUnaryExp(ctx->unaryExp(), int_env, float_env); if (ctx->MUL()) return lhs * rhs; - if (ctx->DIV()) { if (!rhs) throw std::runtime_error("除以零"); return lhs / rhs; } - if (ctx->MOD()) { if (!rhs) throw std::runtime_error("模零"); return lhs % rhs; } + if (ctx->DIV()) { + if (rhs == 0.0) throw std::runtime_error("除以零"); + return lhs / rhs; + } + if (ctx->MOD()) { + if (rhs == 0.0) throw std::runtime_error("模零"); + return std::fmod(lhs, rhs); + } throw std::runtime_error(FormatError("consteval", "未知乘法运算符")); } - return EvalUnaryExp(ctx->unaryExp(), env); + return EvalUnaryExp(ctx->unaryExp(), int_env, float_env); } -int EvalAddExp(SysYParser::AddExpContext* ctx, - const IRGenImpl::ConstEnv& env) { +double EvalAddExp(SysYParser::AddExpContext* ctx, + const IRGenImpl::ConstEnv& int_env, + const IRGenImpl::ConstFloatEnv& float_env) { if (!ctx) throw std::runtime_error(FormatError("consteval", "空加法表达式")); if (ctx->addExp()) { - int lhs = EvalAddExp(ctx->addExp(), env); - int rhs = EvalMulExp(ctx->mulExp(), env); + double lhs = EvalAddExp(ctx->addExp(), int_env, float_env); + double rhs = EvalMulExp(ctx->mulExp(), int_env, float_env); if (ctx->ADD()) return lhs + rhs; if (ctx->SUB()) return lhs - rhs; throw std::runtime_error(FormatError("consteval", "未知加法运算符")); } - return EvalMulExp(ctx->mulExp(), env); + return EvalMulExp(ctx->mulExp(), int_env, float_env); } } // namespace @@ -85,11 +115,23 @@ int EvalAddExp(SysYParser::AddExpContext* ctx, int IRGenImpl::EvalConstExpr(SysYParser::ConstExpContext* ctx) const { if (!ctx || !ctx->addExp()) throw std::runtime_error(FormatError("consteval", "空 constExp")); - return EvalAddExp(ctx->addExp(), const_env_); + return static_cast(EvalAddExp(ctx->addExp(), const_env_, const_float_env_)); +} + +float IRGenImpl::EvalConstExprAsFloat(SysYParser::ConstExpContext* ctx) const { + if (!ctx || !ctx->addExp()) + throw std::runtime_error(FormatError("consteval", "空 constExp")); + return static_cast(EvalAddExp(ctx->addExp(), const_env_, const_float_env_)); } int IRGenImpl::EvalExpAsConst(SysYParser::ExpContext* ctx) const { if (!ctx || !ctx->addExp()) throw std::runtime_error(FormatError("consteval", "空 exp")); - return EvalAddExp(ctx->addExp(), const_env_); + return static_cast(EvalAddExp(ctx->addExp(), const_env_, const_float_env_)); +} + +float IRGenImpl::EvalExpAsConstFloat(SysYParser::ExpContext* ctx) const { + if (!ctx || !ctx->addExp()) + throw std::runtime_error(FormatError("consteval", "空 exp")); + return static_cast(EvalAddExp(ctx->addExp(), const_env_, const_float_env_)); } diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index fe0772e..fe31973 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -1,5 +1,7 @@ #include "irgen/IRGen.h" +#include +#include #include #include "SysYParser.h" @@ -10,6 +12,8 @@ std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句块")); } + const auto saved_const_env = const_env_; + const auto saved_const_float_env = const_float_env_; for (auto* item : ctx->blockItem()) { if (item) { if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { @@ -17,6 +21,8 @@ std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { } } } + const_env_ = saved_const_env; + const_float_env_ = saved_const_float_env; return {}; } @@ -98,6 +104,40 @@ void IRGenImpl::FlattenConstInit(SysYParser::ConstInitValueContext* ctx, while (pos < start + agg_size) out[pos++] = 0; } +void IRGenImpl::FlattenConstInitFloat(SysYParser::ConstInitValueContext* ctx, + const std::vector& dims, int dim_idx, + std::vector& out, int& pos) { + if (!ctx) return; + + if (ctx->constExp()) { + out[pos++] = EvalConstExprAsFloat(ctx->constExp()); + return; + } + + int sub_size = 1; + for (int i = dim_idx + 1; i < (int)dims.size(); i++) sub_size *= dims[i]; + int agg_size = (dim_idx < (int)dims.size()) ? dims[dim_idx] * sub_size : 1; + int start = pos; + + for (auto* item : ctx->constInitValue()) { + if (!item || pos >= start + agg_size) break; + if (item->constExp()) { + out[pos++] = EvalConstExprAsFloat(item->constExp()); + } else { + if (sub_size > 1) { + int offset = pos - start; + int rem = offset % sub_size; + if (rem != 0) pos += sub_size - rem; + } + int sub_start = pos; + FlattenConstInitFloat(item, dims, dim_idx + 1, out, pos); + int sub_end = sub_start + sub_size; + while (pos < sub_end && pos < start + agg_size) out[pos++] = 0.0f; + } + } + while (pos < start + agg_size) out[pos++] = 0.0f; +} + // ─── 工具:扁平化 initValue ─────────────────────────────────────────────── void IRGenImpl::FlattenInit(SysYParser::InitValueContext* ctx, const std::vector& dims, int dim_idx, @@ -133,6 +173,75 @@ void IRGenImpl::FlattenInit(SysYParser::InitValueContext* ctx, while (pos < start + agg_size) pos++; // zeros } +void IRGenImpl::FlattenGlobalInitInt(SysYParser::InitValueContext* ctx, + const std::vector& dims, int dim_idx, + std::vector& out, int& pos) { + if (!ctx) return; + + if (ctx->exp()) { + out[pos++] = EvalExpAsConst(ctx->exp()); + return; + } + + int sub_size = 1; + for (int i = dim_idx + 1; i < (int)dims.size(); i++) sub_size *= dims[i]; + int agg_size = (dim_idx < (int)dims.size()) ? dims[dim_idx] * sub_size : 1; + int start = pos; + + for (auto* item : ctx->initValue()) { + if (!item || pos >= start + agg_size) break; + if (item->exp()) { + out[pos++] = EvalExpAsConst(item->exp()); + } else { + if (sub_size > 1) { + int offset = pos - start; + int rem = offset % sub_size; + if (rem != 0) pos += sub_size - rem; + } + int sub_start = pos; + FlattenGlobalInitInt(item, dims, dim_idx + 1, out, pos); + int sub_end = sub_start + sub_size; + while (pos < sub_end && pos < start + agg_size) out[pos++] = 0; + } + } + while (pos < start + agg_size) out[pos++] = 0; +} + +void IRGenImpl::FlattenGlobalInitFloat(SysYParser::InitValueContext* ctx, + const std::vector& dims, + int dim_idx, std::vector& out, + int& pos) { + if (!ctx) return; + + if (ctx->exp()) { + out[pos++] = EvalExpAsConstFloat(ctx->exp()); + return; + } + + int sub_size = 1; + for (int i = dim_idx + 1; i < (int)dims.size(); i++) sub_size *= dims[i]; + int agg_size = (dim_idx < (int)dims.size()) ? dims[dim_idx] * sub_size : 1; + int start = pos; + + for (auto* item : ctx->initValue()) { + if (!item || pos >= start + agg_size) break; + if (item->exp()) { + out[pos++] = EvalExpAsConstFloat(item->exp()); + } else { + if (sub_size > 1) { + int offset = pos - start; + int rem = offset % sub_size; + if (rem != 0) pos += sub_size - rem; + } + int sub_start = pos; + FlattenGlobalInitFloat(item, dims, dim_idx + 1, out, pos); + int sub_end = sub_start + sub_size; + while (pos < sub_end && pos < start + agg_size) out[pos++] = 0.0f; + } + } + while (pos < start + agg_size) out[pos++] = 0.0f; +} + // ─── const 声明 ─────────────────────────────────────────────────────────── std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) { @@ -140,16 +249,17 @@ std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) { if (!ctx->btype()) { throw std::runtime_error(FormatError("irgen", "缺少类型声明")); } - // 暂时只处理int const,float const留待后续实现 - if (ctx->btype()->FLOAT()) { - throw std::runtime_error(FormatError("irgen", "暂不支持 float const 声明")); - } - if (!ctx->btype()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持 int const 声明")); + if (ctx->btype()->INT()) { + current_decl_type_ = ir::Type::GetInt32Type(); + } else if (ctx->btype()->FLOAT()) { + current_decl_type_ = ir::Type::GetFloat32Type(); + } else { + throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float const 声明")); } for (auto* def : ctx->constDef()) { if (def) def->accept(this); } + current_decl_type_ = nullptr; return {}; } @@ -162,16 +272,34 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { if (!ctx->constInitValue() || !ctx->constInitValue()->constExp()) { throw std::runtime_error(FormatError("irgen", "const 标量声明缺少初始值")); } - int ival = EvalConstExpr(ctx->constInitValue()->constExp()); - const_env_[name] = ival; // 存入编译期环境 - - if (IsGlobalScope()) { - auto* gv = module_.CreateGlobalVar(name, ival); - global_storage_[name] = gv; + const bool is_float_const = current_decl_type_ && current_decl_type_->IsFloat32(); + if (is_float_const) { + float fval = EvalConstExprAsFloat(ctx->constInitValue()->constExp()); + const_float_env_[name] = fval; + + if (IsGlobalScope()) { + std::int32_t bits = 0; + std::memcpy(&bits, &fval, sizeof(bits)); + auto* gv = module_.CreateGlobalVar( + name, static_cast(bits), 1, ir::Type::GetPtrFloat32Type()); + global_storage_[name] = gv; + } else { + auto* slot = CreateEntryAllocaF32(module_.GetContext().NextTemp()); + named_storage_[name] = slot; + builder_.CreateStore(module_.GetContext().GetConstFloat(fval), slot); + } } else { - auto* slot = CreateEntryAllocaI32(module_.GetContext().NextTemp()); - named_storage_[name] = slot; - builder_.CreateStore(builder_.CreateConstInt(ival), slot); + int ival = EvalConstExpr(ctx->constInitValue()->constExp()); + const_env_[name] = ival; // 存入编译期环境 + + if (IsGlobalScope()) { + auto* gv = module_.CreateGlobalVar(name, ival); + global_storage_[name] = gv; + } else { + auto* slot = CreateEntryAllocaI32(module_.GetContext().NextTemp()); + named_storage_[name] = slot; + builder_.CreateStore(builder_.CreateConstInt(ival), slot); + } } return {}; } @@ -184,6 +312,40 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { int total = 1; for (int d : dims) total *= d; + const bool is_float_const = current_decl_type_ && current_decl_type_->IsFloat32(); + + if (is_float_const) { + std::vector flat(total, 0.0f); + if (ctx->constInitValue()) { + int pos = 0; + FlattenConstInitFloat(ctx->constInitValue(), dims, 0, flat, pos); + } + std::vector init_bits; + init_bits.reserve(flat.size()); + for (float v : flat) { + std::int32_t bits = 0; + std::memcpy(&bits, &v, sizeof(bits)); + init_bits.push_back(static_cast(bits)); + } + + if (IsGlobalScope()) { + auto* gv = module_.CreateGlobalVar( + name, 0, total, ir::Type::GetPtrFloat32Type(), std::move(init_bits)); + global_storage_[name] = gv; + global_array_dims_[name] = dims; + } else { + auto* slot = CreateEntryAllocaF32Array(total, module_.GetContext().NextTemp()); + named_storage_[name] = slot; + local_array_dims_[name] = dims; + for (int i = 0; i < total; i++) { + auto* idx = builder_.CreateConstInt(i); + auto* ptr = builder_.CreateGep(slot, idx, module_.GetContext().NextTemp()); + builder_.CreateStore(module_.GetContext().GetConstFloat(flat[i]), ptr); + } + } + return {}; + } + // 扁平化初始化值 std::vector flat(total, 0); if (ctx->constInitValue()) { @@ -192,9 +354,9 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { } if (IsGlobalScope()) { - // 全局 const 数组:创建全局数组变量(仅支持零初始化;非零初始化暂用零) - // TODO: 支持全局 const 数组的非零初始化 - auto* gv = module_.CreateGlobalVar(name, 0, total); + auto* gv = module_.CreateGlobalVar(name, 0, total, + ir::Type::GetPtrInt32Type(), + std::move(flat)); global_storage_[name] = gv; global_array_dims_[name] = dims; } else { @@ -255,11 +417,32 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { for (int d : dims) total *= d; if (IsGlobalScope()) { - auto* gv = module_.CreateGlobalVar(name, 0, total); + std::vector init_elems; + if (auto* init_val = ctx->initValue()) { + if (current_decl_type_->IsFloat32()) { + std::vector flat(total, 0.0f); + int pos = 0; + FlattenGlobalInitFloat(init_val, dims, 0, flat, pos); + init_elems.reserve(flat.size()); + for (float v : flat) { + std::int32_t bits = 0; + std::memcpy(&bits, &v, sizeof(bits)); + init_elems.push_back(static_cast(bits)); + } + } else { + init_elems.assign(total, 0); + int pos = 0; + FlattenGlobalInitInt(init_val, dims, 0, init_elems, pos); + } + } + auto* gv = module_.CreateGlobalVar( + name, 0, total, + current_decl_type_->IsFloat32() ? ir::Type::GetPtrFloat32Type() + : ir::Type::GetPtrInt32Type(), + std::move(init_elems)); storage_map_[ctx] = gv; global_storage_[name] = gv; global_array_dims_[name] = dims; - // 全局数组:不支持运行时初始化(全零已足够) } else { // 根据当前声明类型创建数组alloca ir::AllocaInst* slot; @@ -291,7 +474,13 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { if (flat[i] != nullptr) { auto* idx = builder_.CreateConstInt(i); auto* ptr = builder_.CreateGep(slot, idx, module_.GetContext().NextTemp()); - builder_.CreateStore(flat[i], ptr); + ir::Value* val = flat[i]; + if (ptr->GetType()->IsPtrFloat32()) { + val = CastToFloat(val); + } else { + val = CastToInt(val); + } + builder_.CreateStore(val, ptr); } } } @@ -301,15 +490,32 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { // ── 标量变量 ────────────────────────────────────────────────────────── if (IsGlobalScope()) { - int ival = 0; - if (auto* init_value = ctx->initValue()) { - if (!init_value->exp()) { - throw std::runtime_error( - FormatError("irgen", "全局标量变量仅支持表达式初始化")); + int init_bits_or_int = 0; + if (current_decl_type_->IsFloat32()) { + float fval = 0.0f; + if (auto* init_value = ctx->initValue()) { + if (!init_value->exp()) { + throw std::runtime_error( + FormatError("irgen", "全局标量变量仅支持表达式初始化")); + } + fval = EvalExpAsConstFloat(init_value->exp()); + } + std::int32_t bits = 0; + std::memcpy(&bits, &fval, sizeof(bits)); + init_bits_or_int = static_cast(bits); + } else { + if (auto* init_value = ctx->initValue()) { + if (!init_value->exp()) { + throw std::runtime_error( + FormatError("irgen", "全局标量变量仅支持表达式初始化")); + } + init_bits_or_int = EvalExpAsConst(init_value->exp()); } - ival = EvalExpAsConst(init_value->exp()); } - auto* gv = module_.CreateGlobalVar(name, ival); + auto* gv = module_.CreateGlobalVar( + name, init_bits_or_int, 1, + current_decl_type_->IsFloat32() ? ir::Type::GetPtrFloat32Type() + : ir::Type::GetPtrInt32Type()); storage_map_[ctx] = gv; global_storage_[name] = gv; return {}; @@ -343,6 +549,11 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { init = builder_.CreateConstInt(0); } } + if (current_decl_type_->IsFloat32()) { + init = CastToFloat(init); + } else { + init = CastToInt(init); + } builder_.CreateStore(init, slot); return {}; } diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index a1d13ae..f75e9dd 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -14,18 +14,42 @@ ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) { return std::any_cast(cond.accept(this)); } +ir::Value* IRGenImpl::CastToFloat(ir::Value* v) { + if (!v || !v->GetType()) { + throw std::runtime_error(FormatError("irgen", "CastToFloat 输入为空")); + } + if (v->GetType()->IsFloat32()) return v; + if (v->GetType()->IsInt32()) { + return builder_.CreateSIToFP(v, module_.GetContext().NextTemp()); + } + throw std::runtime_error(FormatError("irgen", "不支持转换到 float 的类型")); +} + +ir::Value* IRGenImpl::CastToInt(ir::Value* v) { + if (!v || !v->GetType()) { + throw std::runtime_error(FormatError("irgen", "CastToInt 输入为空")); + } + if (v->GetType()->IsInt32()) return v; + if (v->GetType()->IsFloat32()) { + return builder_.CreateFPToSI(v, module_.GetContext().NextTemp()); + } + throw std::runtime_error(FormatError("irgen", "不支持转换到 i32 的类型")); +} + ir::Value* IRGenImpl::ToBoolValue(ir::Value* v) { if (!v) { throw std::runtime_error(FormatError("irgen", "条件值为空")); } - if (v->GetType() && v->GetType()->IsPtrInt32()) { + if (v->GetType() && (v->GetType()->IsPtrInt32() || v->GetType()->IsPtrFloat32())) { // SysY 中数组名退化得到的指针在当前实现里总是非空。 return builder_.CreateConstInt(1); } if (dynamic_cast(v) != nullptr) { return v; } - auto* zero = builder_.CreateConstInt(0); + ir::Value* zero = v->GetType()->IsFloat32() + ? static_cast(module_.GetContext().GetConstFloat(0.0f)) + : static_cast(builder_.CreateConstInt(0)); return builder_.CreateCmp(ir::CmpOp::Ne, v, zero, module_.GetContext().NextTemp()); } @@ -60,7 +84,7 @@ ir::Value* IRGenImpl::ComputeLinearIndex( int stride = 1; for (int j = k + 1; j < (int)dims.size(); j++) stride *= dims[j]; - ir::Value* idx = EvalExpr(*subs[k]); + ir::Value* idx = CastToInt(EvalExpr(*subs[k])); if (stride != 1) { auto* sv = builder_.CreateConstInt(stride); idx = builder_.CreateMul(idx, sv, module_.GetContext().NextTemp()); @@ -184,6 +208,15 @@ std::any IRGenImpl::visitLValue(SysYParser::LValueContext* ctx) { const std::string name = ctx->ID()->getText(); if (ctx->exp().empty()) { + auto itf = const_float_env_.find(name); + if (itf != const_float_env_.end()) { + return static_cast(module_.GetContext().GetConstFloat(itf->second)); + } + auto iti = const_env_.find(name); + if (iti != const_env_.end()) { + return static_cast(builder_.CreateConstInt(iti->second)); + } + // 无下标:标量读取 或 数组基址引用 ir::Value* slot = ResolveStorage(ctx); if (!slot) { @@ -230,7 +263,9 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { if (ctx->unaryOp() && ctx->unaryExp()) { ir::Value* v = std::any_cast(ctx->unaryExp()->accept(this)); if (ctx->unaryOp()->SUB()) { - auto* zero = builder_.CreateConstInt(0); + ir::Value* zero = v->GetType()->IsFloat32() + ? static_cast(module_.GetContext().GetConstFloat(0.0f)) + : static_cast(builder_.CreateConstInt(0)); return static_cast(builder_.CreateSub( zero, v, module_.GetContext().NextTemp())); } @@ -239,7 +274,9 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { } if (ctx->unaryOp()->NOT()) { // !v ≡ (v == 0) - auto* zero = builder_.CreateConstInt(0); + ir::Value* zero = v->GetType()->IsFloat32() + ? static_cast(module_.GetContext().GetConstFloat(0.0f)) + : static_cast(builder_.CreateConstInt(0)); return static_cast(builder_.CreateCmp( ir::CmpOp::Eq, v, zero, module_.GetContext().NextTemp())); } @@ -255,8 +292,19 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { } std::vector args; if (auto* rparams = ctx->funcRParams()) { + const auto& param_types = callee->GetParamTypes(); + size_t i = 0; for (auto* ep : rparams->exp()) { - args.push_back(EvalExpr(*ep)); + ir::Value* arg = EvalExpr(*ep); + if (i < param_types.size()) { + if (param_types[i]->IsFloat32()) { + arg = CastToFloat(arg); + } else if (param_types[i]->IsInt32()) { + arg = CastToInt(arg); + } + } + args.push_back(arg); + ++i; } } const std::string name = @@ -277,6 +325,11 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { } ir::Value* lhs = std::any_cast(ctx->mulExp()->accept(this)); ir::Value* rhs = std::any_cast(ctx->unaryExp()->accept(this)); + const bool has_float = lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32(); + if (has_float) { + lhs = CastToFloat(lhs); + rhs = CastToFloat(rhs); + } if (ctx->MUL()) { return static_cast( builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp())); @@ -286,6 +339,8 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { builder_.CreateDiv(lhs, rhs, module_.GetContext().NextTemp())); } if (ctx->MOD()) { + lhs = CastToInt(lhs); + rhs = CastToInt(rhs); return static_cast( builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp())); } @@ -307,6 +362,10 @@ std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { } ir::Value* lhs = std::any_cast(ctx->addExp()->accept(this)); ir::Value* rhs = std::any_cast(ctx->mulExp()->accept(this)); + if (lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32()) { + lhs = CastToFloat(lhs); + rhs = CastToFloat(rhs); + } if (ctx->ADD()) { return static_cast( builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp())); @@ -333,6 +392,10 @@ std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { } ir::Value* lhs = std::any_cast(ctx->relExp()->accept(this)); ir::Value* rhs = std::any_cast(ctx->addExp()->accept(this)); + if (lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32()) { + lhs = CastToFloat(lhs); + rhs = CastToFloat(rhs); + } if (ctx->LT()) { return static_cast(builder_.CreateCmp( ir::CmpOp::Lt, lhs, rhs, module_.GetContext().NextTemp())); @@ -367,6 +430,10 @@ std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { } ir::Value* lhs = std::any_cast(ctx->eqExp()->accept(this)); ir::Value* rhs = std::any_cast(ctx->relExp()->accept(this)); + if (lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32()) { + lhs = CastToFloat(lhs); + rhs = CastToFloat(rhs); + } if (ctx->EQ()) { return static_cast(builder_.CreateCmp( ir::CmpOp::Eq, lhs, rhs, module_.GetContext().NextTemp())); diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index c6ae277..4ec640a 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -93,6 +93,11 @@ void IRGenImpl::DeclareRuntimeFunctions() { // 数组 I/O decl("getarray", i32, {ir::Type::GetPtrInt32Type()}); decl("putarray", void_, {i32, ir::Type::GetPtrInt32Type()}); + // 浮点 I/O + decl("getfloat", ir::Type::GetFloat32Type(), {}); + decl("getfarray", i32, {ir::Type::GetPtrFloat32Type()}); + decl("putfloat", void_, {ir::Type::GetFloat32Type()}); + decl("putfarray", void_, {i32, ir::Type::GetPtrFloat32Type()}); // 时间 decl("starttime", void_, {}); decl("stoptime", void_, {}); @@ -216,7 +221,12 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { } } else { // 标量参数:alloca + store - auto* slot = CreateEntryAllocaI32(module_.GetContext().NextTemp()); + ir::AllocaInst* slot = nullptr; + if (arg->GetType()->IsFloat32()) { + slot = CreateEntryAllocaF32(module_.GetContext().NextTemp()); + } else { + slot = CreateEntryAllocaI32(module_.GetContext().NextTemp()); + } builder_.CreateStore(arg, slot); if (!param_names[i].empty()) { named_storage_[param_names[i]] = slot; diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index f555726..a8357be 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -29,6 +29,11 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { ? ctx->lValue()->ID()->getText() : "?"))); } + if (slot->GetType() && slot->GetType()->IsPtrFloat32()) { + rhs = CastToFloat(rhs); + } else if (slot->GetType() && slot->GetType()->IsPtrInt32()) { + rhs = CastToInt(rhs); + } builder_.CreateStore(rhs, slot); return BlockFlow::Continue; } @@ -138,6 +143,13 @@ std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { return BlockFlow::Terminated; } ir::Value* v = EvalExpr(*ctx->exp()); + if (func_ && func_->GetType()) { + if (func_->GetType()->IsFloat32()) { + v = CastToFloat(v); + } else if (func_->GetType()->IsInt32()) { + v = CastToInt(v); + } + } builder_.CreateRet(v); return BlockFlow::Terminated; } diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 171a4d6..8ff3968 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -1,5 +1,6 @@ #include "mir/MIR.h" +#include #include #include @@ -17,6 +18,43 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function, return function.GetFrameSlot(operand.GetFrameIndex()); } +void PrintMoveImm32(std::ostream& os, PhysReg reg, int imm) { + std::uint32_t u = static_cast(imm); + std::uint32_t lo = u & 0xFFFFu; + std::uint32_t hi = (u >> 16) & 0xFFFFu; + os << " movz " << PhysRegName(reg) << ", #" << lo << "\n"; + if (hi != 0) { + os << " movk " << PhysRegName(reg) << ", #" << hi << ", lsl #16\n"; + } +} + +void PrintStackAdjust(std::ostream& os, const char* mnemonic, int size) { + if (size >= 0 && size <= 4095) { + os << " " << mnemonic << " sp, sp, #" << size << "\n"; + return; + } + PrintMoveImm32(os, PhysReg::X10, size); + os << " " << mnemonic << " sp, sp, x10\n"; +} + +void PrintAddrFromX29(std::ostream& os, PhysReg dst, int offset) { + if (offset >= -4095 && offset <= 4095) { + if (offset >= 0) { + os << " add " << PhysRegName(dst) << ", x29, #" << offset << "\n"; + } else { + os << " sub " << PhysRegName(dst) << ", x29, #" << (-offset) << "\n"; + } + return; + } + + PrintMoveImm32(os, PhysReg::X10, offset < 0 ? -offset : offset); + if (offset >= 0) { + os << " add " << PhysRegName(dst) << ", x29, x10\n"; + } else { + os << " sub " << PhysRegName(dst) << ", x29, x10\n"; + } +} + void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg, int offset) { // AArch64 ldur/stur 只支持 -256..255 的立即数偏移 @@ -25,13 +63,10 @@ void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg, << "]\n"; } else { // 大偏移:使用 x10 作为临时寄存器 - // sub x10, x29, #abs(offset) - // ldr/str reg, [x10] - int abs_offset = -offset; // offset 是负数 bool is_load = (mnemonic[0] == 'l'); // ldur -> ldr const char* base_mnemonic = is_load ? "ldr" : "str"; - os << " sub x10, x29, #" << abs_offset << "\n"; + PrintAddrFromX29(os, PhysReg::X10, offset); os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x10]\n"; } } @@ -42,7 +77,9 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { // 输出全局变量定义 if (!module.GetGlobalVars().empty()) { os << ".data\n"; - for (const auto& [name, init_val, count] : module.GetGlobalVars()) { + for (const auto& [name, init_val, count, is_float, init_elems] : + module.GetGlobalVars()) { + (void)is_float; os << ".global " << name << "\n"; os << ".type " << name << ", %object\n"; os << name << ":\n"; @@ -50,8 +87,20 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { // 标量全局变量 os << " .word " << init_val << "\n"; } else { - // 数组全局变量(全零初始化) - os << " .zero " << (count * 4) << "\n"; + // 数组全局变量:优先输出显式初始化元素,剩余部分补零。 + int emitted = 0; + for (int elem : init_elems) { + if (emitted >= count) { + break; + } + os << " .word " << elem << "\n"; + ++emitted; + } + if (emitted == 0) { + os << " .zero " << (count * 4) << "\n"; + } else if (emitted < count) { + os << " .zero " << ((count - emitted) * 4) << "\n"; + } } } os << "\n"; @@ -80,23 +129,31 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { os << " stp x29, x30, [sp, #-16]!\n"; os << " mov x29, sp\n"; if (function.GetFrameSize() > 0) { - os << " sub sp, sp, #" << function.GetFrameSize() << "\n"; + PrintStackAdjust(os, "sub", function.GetFrameSize()); } break; case Opcode::Epilogue: if (function.GetFrameSize() > 0) { - os << " add sp, sp, #" << function.GetFrameSize() << "\n"; + PrintStackAdjust(os, "add", function.GetFrameSize()); } os << " ldp x29, x30, [sp], #16\n"; break; case Opcode::MovImm: - os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #" - << ops.at(1).GetImm() << "\n"; + PrintMoveImm32(os, ops.at(0).GetReg(), ops.at(1).GetImm()); break; case Opcode::MovReg: os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << "\n"; break; + case Opcode::FMovImm: + // 通用浮点立即数:先装载 bit pattern,再位级移动到 s 寄存器。 + PrintMoveImm32(os, PhysReg::W10, ops.at(1).GetImm()); + os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", w10\n"; + break; + case Opcode::FMovReg: + os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << "\n"; + break; case Opcode::LoadStack: { const auto& slot = GetFrameSlot(function, ops.at(1)); PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset); @@ -125,12 +182,7 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { // ops: xN, frame_index // add xN, x29, #offset const auto& slot = GetFrameSlot(function, ops.at(1)); - int offset = slot.offset; - if (offset >= 0) { - os << " add " << PhysRegName(ops.at(0).GetReg()) << ", x29, #" << offset << "\n"; - } else { - os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", x29, #" << (-offset) << "\n"; - } + PrintAddrFromX29(os, ops.at(0).GetReg(), slot.offset); break; } case Opcode::LoadIndirect: { @@ -196,6 +248,34 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { << PhysRegName(ops.at(1).GetReg()) << ", " << PhysRegName(ops.at(2).GetReg()) << "\n"; break; + case Opcode::FAddRR: + os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::FSubRR: + os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::FMulRR: + os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::FDivRR: + os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::SIToFP: + os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << "\n"; + break; + case Opcode::FPToSI: + os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << "\n"; + break; case Opcode::ModRR: // 不应该出现(Mod 在 lowering 时已展开为 div+mul+sub) throw std::runtime_error(FormatError("mir", "ModRR 不应被打印")); @@ -222,6 +302,24 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { << cond_suffix << "\n"; break; } + case Opcode::FCmpRR: { + // ops: dst(wN), lhs(sN), rhs(sN), cmpop(imm) + auto cmp_op = static_cast(ops.at(3).GetImm()); + const char* cond_suffix = ""; + switch (cmp_op) { + case ir::CmpOp::Eq: cond_suffix = "eq"; break; + case ir::CmpOp::Ne: cond_suffix = "ne"; break; + case ir::CmpOp::Lt: cond_suffix = "lt"; break; + case ir::CmpOp::Le: cond_suffix = "le"; break; + case ir::CmpOp::Gt: cond_suffix = "gt"; break; + case ir::CmpOp::Ge: cond_suffix = "ge"; break; + } + os << " fcmp " << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", " + << cond_suffix << "\n"; + break; + } case Opcode::Bl: os << " bl " << ops.at(0).GetSymbol() << "\n"; break; diff --git a/src/mir/FrameLowering.cpp b/src/mir/FrameLowering.cpp index 4b110bf..242b5a9 100644 --- a/src/mir/FrameLowering.cpp +++ b/src/mir/FrameLowering.cpp @@ -18,12 +18,6 @@ void RunFrameLowering(MachineFunction& function) { int cursor = 0; for (const auto& slot : function.GetFrameSlots()) { cursor += slot.size; - // AArch64 ldur/stur 支持 -256 到 +255 的立即数偏移 - // 如果超出范围,需要使用多条指令 - // 这里暂时放宽限制到 4096(单页大小) - if (-cursor < -4096) { - throw std::runtime_error(FormatError("mir", "栈帧超过 4KB,需要更复杂的栈帧处理")); - } } cursor = 0; diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 14c24ab..3c5fc2c 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -1,5 +1,7 @@ #include "mir/MIR.h" +#include +#include #include #include @@ -23,8 +25,12 @@ struct GepInfo { }; using GepMap = std::unordered_map; -void EmitValueToReg(const ir::Value* value, PhysReg target, - const ValueSlotMap& slots, MachineBasicBlock& block) { +bool IsPointerType(const std::shared_ptr& type) { + return type && (type->IsPtrInt32() || type->IsPtrFloat32()); +} + +void EmitIntValueToReg(const ir::Value* value, PhysReg target, + const ValueSlotMap& slots, MachineBasicBlock& block) { if (auto* constant = dynamic_cast(value)) { block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Imm(constant->GetValue())}); @@ -48,6 +54,36 @@ void EmitValueToReg(const ir::Value* value, PhysReg target, {Operand::Reg(target), Operand::FrameIndex(it->second)}); } +void EmitFloatValueToReg(const ir::Value* value, PhysReg target, + const ValueSlotMap& slots, MachineBasicBlock& block) { + if (auto* constant = dynamic_cast(value)) { + std::int32_t bits = 0; + float fv = constant->GetValue(); + std::memcpy(&bits, &fv, sizeof(bits)); + block.Append(Opcode::FMovImm, + {Operand::Reg(target), Operand::Imm(static_cast(bits))}); + return; + } + + auto it = slots.find(value); + if (it == slots.end()) { + throw std::runtime_error( + FormatError("mir", "找不到浮点值对应的栈槽: " + value->GetName())); + } + + block.Append(Opcode::LoadStack, + {Operand::Reg(target), Operand::FrameIndex(it->second)}); +} + +void EmitValueToReg(const ir::Value* value, PhysReg target, + const ValueSlotMap& slots, MachineBasicBlock& block) { + if (value->GetType() && value->GetType()->IsFloat32()) { + EmitFloatValueToReg(value, target, slots, block); + return; + } + EmitIntValueToReg(value, target, slots, block); +} + void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, MachineBasicBlock& block, ValueSlotMap& slots, GepMap& geps) { @@ -120,7 +156,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, } // 检查 base 是否是指针参数:如果是 Argument 且类型是指针 - if (dynamic_cast(base) && base->GetType()->IsPtrInt32()) { + if (dynamic_cast(base) && IsPointerType(base->GetType())) { // 指针参数:从栈加载指针值,然后加上索引 if (auto* const_index = dynamic_cast(index)) { // 常量索引 @@ -212,12 +248,15 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, case ir::Opcode::Store: { auto& store = static_cast(inst); auto* ptr = store.GetPtr(); + const bool is_float_value = + store.GetValue()->GetType() && store.GetValue()->GetType()->IsFloat32(); + const PhysReg src_reg = is_float_value ? PhysReg::S0 : PhysReg::W8; // 检查是否是 GEP 结果(数组元素) auto gep_it = geps.find(ptr); if (gep_it != geps.end()) { const auto& gep_info = gep_it->second; - EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block); + EmitValueToReg(store.GetValue(), src_reg, slots, block); if (gep_info.base_slot == -1) { // 全局数组 @@ -233,7 +272,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, Operand::Reg(PhysReg::X10)}); } block.Append(Opcode::StoreIndirect, - {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)}); + {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); } else { // 变量索引:global_array[var_idx] int index_slot = -1 - gep_info.byte_offset; @@ -254,12 +293,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, Operand::Reg(PhysReg::X10)}); // 5. 存储 block.Append(Opcode::StoreIndirect, - {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)}); + {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); } } else if (gep_info.byte_offset >= 0) { // 本地数组,常量索引 block.Append(Opcode::StoreStackOffset, - {Operand::Reg(PhysReg::W8), + {Operand::Reg(src_reg), Operand::FrameIndex(gep_info.base_slot), Operand::Imm(gep_info.byte_offset)}); } else { @@ -278,16 +317,16 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::X10)}); block.Append(Opcode::StoreIndirect, - {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)}); + {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); } return; } // 检查是否是全局变量 if (auto* gv = dynamic_cast(ptr)) { - EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block); + EmitValueToReg(store.GetValue(), src_reg, slots, block); block.Append(Opcode::StoreGlobal, - {Operand::Reg(PhysReg::W8), Operand::Symbol(gv->GetName())}); + {Operand::Reg(src_reg), Operand::Symbol(gv->GetName())}); return; } @@ -298,26 +337,28 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, FormatError("mir", "暂不支持对非栈/全局变量地址进行写入")); } - EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block); + EmitValueToReg(store.GetValue(), src_reg, slots, block); // 检查是否是GEP结果:如果ptr的类型是指针且slot大小是8字节,说明存储的是地址 const auto& dst_slot = function.GetFrameSlot(dst->second); - if (ptr->GetType()->IsPtrInt32() && dst_slot.size == 8) { + if (IsPointerType(ptr->GetType()) && dst_slot.size == 8) { // GEP结果:先加载指针地址,再通过指针存储值 block.Append(Opcode::LoadStack, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(dst->second)}); block.Append(Opcode::StoreIndirect, - {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)}); + {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); } else { // 普通栈变量:直接存储 block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)}); + {Operand::Reg(src_reg), Operand::FrameIndex(dst->second)}); } return; } case ir::Opcode::Load: { auto& load = static_cast(inst); auto* ptr = load.GetPtr(); + const bool is_float_load = load.GetType() && load.GetType()->IsFloat32(); + const PhysReg value_reg = is_float_load ? PhysReg::S0 : PhysReg::W8; // 检查是否是 GEP 结果(数组元素) auto gep_it = geps.find(ptr); @@ -338,7 +379,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, Operand::Reg(PhysReg::X10)}); } block.Append(Opcode::LoadIndirect, - {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)}); + {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); } else { // 变量索引 int index_slot = -1 - gep_info.byte_offset; @@ -354,12 +395,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::X10)}); block.Append(Opcode::LoadIndirect, - {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)}); + {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); } } else if (gep_info.byte_offset >= 0) { // 本地数组,常量索引 block.Append(Opcode::LoadStackOffset, - {Operand::Reg(PhysReg::W8), + {Operand::Reg(value_reg), Operand::FrameIndex(gep_info.base_slot), Operand::Imm(gep_info.byte_offset)}); } else { @@ -378,11 +419,11 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::X10)}); block.Append(Opcode::LoadIndirect, - {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)}); + {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); } block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + {Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); return; } @@ -391,9 +432,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, if (auto* gv = dynamic_cast(ptr)) { int dst_slot = function.CreateFrameIndex(); block.Append(Opcode::LoadGlobal, - {Operand::Reg(PhysReg::W8), Operand::Symbol(gv->GetName())}); + {Operand::Reg(value_reg), Operand::Symbol(gv->GetName())}); block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + {Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); return; } @@ -409,72 +450,112 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, // 检查是否是GEP结果:如果ptr的类型是指针且slot大小是8字节,说明存储的是地址 const auto& src_slot = function.GetFrameSlot(src->second); - if (ptr->GetType()->IsPtrInt32() && src_slot.size == 8) { + if (IsPointerType(ptr->GetType()) && src_slot.size == 8) { // GEP结果:先加载指针地址,再通过指针加载值 block.Append(Opcode::LoadStack, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(src->second)}); block.Append(Opcode::LoadIndirect, - {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)}); + {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); } else { // 普通栈变量:直接加载 block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)}); + {Operand::Reg(value_reg), Operand::FrameIndex(src->second)}); } block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + {Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); return; } case ir::Opcode::Add: { auto& bin = static_cast(inst); int dst_slot = function.CreateFrameIndex(); - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + if (bin.GetType()->IsFloat32()) { + EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block); + block.Append(Opcode::FAddRR, {Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S1)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } slots.emplace(&inst, dst_slot); return; } case ir::Opcode::Sub: { auto& bin = static_cast(inst); int dst_slot = function.CreateFrameIndex(); - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); - block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + if (bin.GetType()->IsFloat32()) { + EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block); + block.Append(Opcode::FSubRR, {Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S1)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } slots.emplace(&inst, dst_slot); return; } case ir::Opcode::Mul: { auto& bin = static_cast(inst); int dst_slot = function.CreateFrameIndex(); - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); - block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + if (bin.GetType()->IsFloat32()) { + EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block); + block.Append(Opcode::FMulRR, {Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S1)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } slots.emplace(&inst, dst_slot); return; } case ir::Opcode::Div: { auto& bin = static_cast(inst); int dst_slot = function.CreateFrameIndex(); - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); - block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + if (bin.GetType()->IsFloat32()) { + EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block); + block.Append(Opcode::FDivRR, {Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S1)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } slots.emplace(&inst, dst_slot); return; } @@ -502,23 +583,53 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, case ir::Opcode::Cmp: { auto& cmp = static_cast(inst); int dst_slot = function.CreateFrameIndex(); - EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block); - EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block); - // cmp 操作符通过 operand 传递 - block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9), - Operand::Imm(static_cast(cmp.GetCmpOp()))}); + if (cmp.GetLhs()->GetType()->IsFloat32()) { + EmitValueToReg(cmp.GetLhs(), PhysReg::S0, slots, block); + EmitValueToReg(cmp.GetRhs(), PhysReg::S1, slots, block); + block.Append(Opcode::FCmpRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S1), + Operand::Imm(static_cast(cmp.GetCmpOp()))}); + } else { + EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block); + // cmp 操作符通过 operand 传递 + block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9), + Operand::Imm(static_cast(cmp.GetCmpOp()))}); + } block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); return; } + case ir::Opcode::Cast: { + auto& cast = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + if (cast.GetCastOp() == ir::CastOp::IntToFloat) { + EmitValueToReg(cast.GetValue(), PhysReg::W8, slots, block); + block.Append(Opcode::SIToFP, + {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::W8)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + EmitValueToReg(cast.GetValue(), PhysReg::S0, slots, block); + block.Append(Opcode::FPToSI, + {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S0)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } + slots.emplace(&inst, dst_slot); + return; + } case ir::Opcode::Ret: { auto& ret = static_cast(inst); if (ret.GetValue()) { // int/float 返回值 - EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block); + PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat32() ? PhysReg::S0 + : PhysReg::W0; + EmitValueToReg(ret.GetValue(), ret_reg, slots, block); } // void 返回:不设置 w0 block.Append(Opcode::Ret); @@ -531,7 +642,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, throw std::runtime_error(FormatError("mir", "Call 指令缺少被调用函数")); } - // 参数传递:根据类型使用 w0-w7(整数)或 x0-x7(指针) + // 参数传递:根据类型使用 w0-w7(整数)、s0-s7(浮点)或 x0-x7(指针) size_t num_args = call.GetNumArgs(); if (num_args > 8) { throw std::runtime_error(FormatError("mir", "暂不支持超过 8 个参数的函数调用")); @@ -540,8 +651,10 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, const auto& param_types = callee->GetParamTypes(); for (size_t i = 0; i < num_args; i++) { auto* arg_value = call.GetArg(i); - // 检查参数类型是否是指针 - bool is_ptr = (i < param_types.size() && param_types[i]->IsPtrInt32()); + bool is_ptr = + (i < param_types.size() && + (param_types[i]->IsPtrInt32() || param_types[i]->IsPtrFloat32())); + bool is_float = (i < param_types.size() && param_types[i]->IsFloat32()); if (is_ptr) { // 指针参数:加载到 x 寄存器 @@ -564,8 +677,10 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, FormatError("mir", "找不到指针参数的值: " + arg_value->GetName())); } } else { - // 整数参数:加载到 w 寄存器 - PhysReg arg_reg = static_cast(static_cast(PhysReg::W0) + i); + // 标量参数:整数用 w,浮点用 s + PhysReg arg_reg = is_float + ? static_cast(static_cast(PhysReg::S0) + i) + : static_cast(static_cast(PhysReg::W0) + i); EmitValueToReg(arg_value, arg_reg, slots, block); } } @@ -576,8 +691,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, // 处理返回值 if (!call.GetType()->IsVoid()) { int dst_slot = function.CreateFrameIndex(); + PhysReg ret_reg = call.GetType()->IsFloat32() ? PhysReg::S0 : PhysReg::W0; block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W0), Operand::FrameIndex(dst_slot)}); + {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); } return; @@ -597,7 +713,8 @@ std::unique_ptr LowerToMIR(const ir::Module& module) { // 复制全局变量信息 for (const auto& gv_ptr : module.GetGlobalVars()) { const auto& gv = *gv_ptr; - machine_module->AddGlobalVar(gv.GetName(), gv.GetInitValue(), gv.GetCount()); + machine_module->AddGlobalVar(gv.GetName(), gv.GetInitValue(), gv.GetCount(), + gv.IsFloat(), gv.GetInitElements()); } for (const auto& func_ptr : module.GetFunctions()) { @@ -632,15 +749,18 @@ std::unique_ptr LowerToMIR(const ir::Module& module) { auto& entry_block = machine_func->GetEntry(); for (size_t i = 0; i < num_params; i++) { auto* arg = func.GetArgument(i); - bool is_ptr = arg->GetType()->IsPtrInt32(); + bool is_ptr = arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat32(); + bool is_float = arg->GetType()->IsFloat32(); int slot_size = is_ptr ? 8 : 4; // 指针 8 字节,整数 4 字节 int slot = machine_func->CreateFrameIndex(slot_size); slots.emplace(arg, slot); - // 根据参数类型选择寄存器:指针用 x0-x7,整数用 w0-w7 + // 根据参数类型选择寄存器:指针用 x0-x7,整数用 w0-w7,浮点用 s0-s7 PhysReg param_reg; if (is_ptr) { param_reg = static_cast(static_cast(PhysReg::X0) + i); + } else if (is_float) { + param_reg = static_cast(static_cast(PhysReg::S0) + i); } else { param_reg = static_cast(static_cast(PhysReg::W0) + i); } diff --git a/src/mir/MIRFunction.cpp b/src/mir/MIRFunction.cpp index 4ac1036..c4f6f34 100644 --- a/src/mir/MIRFunction.cpp +++ b/src/mir/MIRFunction.cpp @@ -52,8 +52,10 @@ MachineFunction* MachineModule::CreateFunction(std::string name) { return functions_.back().get(); } -void MachineModule::AddGlobalVar(std::string name, int init_val, int count) { - global_vars_.emplace_back(std::move(name), init_val, count); +void MachineModule::AddGlobalVar(std::string name, int init_val, int count, + bool is_float, std::vector init_elems) { + global_vars_.emplace_back(std::move(name), init_val, count, is_float, + std::move(init_elems)); } } // namespace mir