diff --git a/include/mir/MIR.h b/include/mir/MIR.h index 56b7547..39c15e2 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -71,6 +71,7 @@ enum class Opcode { Bl, B, // 无条件跳转 Bcond, // 条件跳转(基于之前的 cmp) + FBcond, // 浮点条件跳转(基于之前的 fcmp,使用 IEEE 754 兼容的条件码) Cbnz, // 非零跳转 Cbz, // 零跳转 Ret, diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 89e9b73..424d85f 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -90,6 +90,25 @@ const char* CondSuffix(ir::CmpOp cmp_op) { return "eq"; } +// 浮点比较使用 IEEE 754 兼容的条件码(正确处理 NaN) +const char* FloatCondSuffix(ir::CmpOp cmp_op) { + switch (cmp_op) { + case ir::CmpOp::Eq: + return "eq"; // Z==1 + case ir::CmpOp::Ne: + return "ne"; // Z==0 + case ir::CmpOp::Lt: + return "mi"; // N==1 (minus, 正确处理 NaN) + case ir::CmpOp::Le: + return "ls"; // !(C==1 && Z==0) (lower or same, 正确处理 NaN) + case ir::CmpOp::Gt: + return "gt"; // Z==0 && N==V (已正确处理 NaN) + case ir::CmpOp::Ge: + return "ge"; // N==V (已正确处理 NaN) + } + return "eq"; +} + } // namespace void PrintAsm(const MachineModule& module, std::ostream& os) { @@ -350,7 +369,7 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { os << " fcmp " << PhysRegName(ops.at(1).GetReg()) << ", " << PhysRegName(ops.at(2).GetReg()) << "\n"; os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", " - << CondSuffix(cmp_op) << "\n"; + << FloatCondSuffix(cmp_op) << "\n"; break; } case Opcode::Bl: @@ -372,6 +391,11 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { os << " b." << CondSuffix(static_cast(ops.at(1).GetImm())) << " ." << ops.at(0).GetSymbol() << "\n"; break; + case Opcode::FBcond: + // ops: symbol, cmpop(imm) - 浮点条件分支 + os << " b." << FloatCondSuffix(static_cast(ops.at(1).GetImm())) + << " ." << ops.at(0).GetSymbol() << "\n"; + break; case Opcode::Ret: os << " ret\n"; break; diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index f85fc44..75b1171 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -966,7 +966,8 @@ std::unique_ptr LowerToMIR(const ir::Module& module) { auto* true_mbb = block_map[next_cbr->GetTrueBlock()]; auto* false_mbb = block_map[next_cbr->GetFalseBlock()]; - if (cmp_inst->GetLhs()->GetType()->IsFloat32()) { + bool is_float_cmp = cmp_inst->GetLhs()->GetType()->IsFloat32(); + if (is_float_cmp) { EmitValueToReg(cmp_inst->GetLhs(), PhysReg::S0, slots, *current_mbb); EmitValueToReg(cmp_inst->GetRhs(), PhysReg::S1, slots, *current_mbb); current_mbb->Append( @@ -980,8 +981,9 @@ std::unique_ptr LowerToMIR(const ir::Module& module) { {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); } + // 浮点比较使用 FBcond(IEEE 754 兼容),整数比较使用 Bcond current_mbb->Append( - Opcode::Bcond, + is_float_cmp ? Opcode::FBcond : Opcode::Bcond, {Operand::Symbol(true_mbb->GetName()), Operand::Imm(static_cast(cmp_inst->GetCmpOp()))}); current_mbb->Append(Opcode::B,