From 346a9c40990f21ede4758e3542a3f982f9c4e58f Mon Sep 17 00:00:00 2001 From: Shrink <1569629152@qq.com> Date: Fri, 24 Apr 2026 01:24:01 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=B5=AE=E7=82=B9?= =?UTF-8?q?=E6=AF=94=E8=BE=83=E5=AF=B9=20NaN=20=E7=9A=84=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E5=A4=84=E7=90=86=EF=BC=88IEEE=20754=20=E5=90=88=E8=A7=84?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 问题描述: ARM64 的浮点比较指令 fcmp 后,使用标准条件码(如 b.lt)对 NaN 操作数会产生错误结果。例如 NaN < -1e-6 错误地返回 true,导致 my_sqrt(NaN) 陷入死循环,vector_mul3.sy 测试无法退出。 根本原因: - fcmp NaN, x 设置标志位 NZCV = 0011 (N=0, Z=0, C=1, V=1) - b.lt 条件为 N!=V,对 NaN 为 0!=1 = true ✗ 错误 - IEEE 754 要求 NaN 与任何数比较(除!=)都应返回 false 修复方案: 1. 添加 FBcond 指令(浮点条件分支) 2. 新增 FloatCondSuffix() 函数返回 IEEE 754 兼容的条件码: - Lt: lt → mi (N==1,对 NaN 返回 false ✓) - Le: le → ls (!(C==1 && Z==0),对 NaN 返回 false ✓) - Gt/Ge/Eq/Ne: 保持不变(已正确处理 NaN) 3. 在浮点比较后使用 FBcond 而不是 Bcond 4. FCmpRR (cset) 也使用 FloatCondSuffix 测试结果: ✓ NaN < -1e-6 正确返回 false(之前错误返回 true) ✓ vector_mul3.sy 正常退出(之前死循环) ✓ my_sqrt(NaN) 不再陷入无限循环 符合标准: 此修复使编译器生成的浮点比较代码完全符合 IEEE 754 标准。 Co-Authored-By: Claude Sonnet 4.5 --- include/mir/MIR.h | 1 + src/mir/AsmPrinter.cpp | 26 +++++++++++++++++++++++++- src/mir/Lowering.cpp | 6 ++++-- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/include/mir/MIR.h b/include/mir/MIR.h index 56b7547..39c15e2 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -71,6 +71,7 @@ enum class Opcode { Bl, B, // 无条件跳转 Bcond, // 条件跳转(基于之前的 cmp) + FBcond, // 浮点条件跳转(基于之前的 fcmp,使用 IEEE 754 兼容的条件码) Cbnz, // 非零跳转 Cbz, // 零跳转 Ret, diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 89e9b73..424d85f 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -90,6 +90,25 @@ const char* CondSuffix(ir::CmpOp cmp_op) { return "eq"; } +// 浮点比较使用 IEEE 754 兼容的条件码(正确处理 NaN) +const char* FloatCondSuffix(ir::CmpOp cmp_op) { + switch (cmp_op) { + case ir::CmpOp::Eq: + return "eq"; // Z==1 + case ir::CmpOp::Ne: + return "ne"; // Z==0 + case ir::CmpOp::Lt: + return "mi"; // N==1 (minus, 正确处理 NaN) + case ir::CmpOp::Le: + return "ls"; // !(C==1 && Z==0) (lower or same, 正确处理 NaN) + case ir::CmpOp::Gt: + return "gt"; // Z==0 && N==V (已正确处理 NaN) + case ir::CmpOp::Ge: + return "ge"; // N==V (已正确处理 NaN) + } + return "eq"; +} + } // namespace void PrintAsm(const MachineModule& module, std::ostream& os) { @@ -350,7 +369,7 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { os << " fcmp " << PhysRegName(ops.at(1).GetReg()) << ", " << PhysRegName(ops.at(2).GetReg()) << "\n"; os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", " - << CondSuffix(cmp_op) << "\n"; + << FloatCondSuffix(cmp_op) << "\n"; break; } case Opcode::Bl: @@ -372,6 +391,11 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { os << " b." << CondSuffix(static_cast(ops.at(1).GetImm())) << " ." << ops.at(0).GetSymbol() << "\n"; break; + case Opcode::FBcond: + // ops: symbol, cmpop(imm) - 浮点条件分支 + os << " b." << FloatCondSuffix(static_cast(ops.at(1).GetImm())) + << " ." << ops.at(0).GetSymbol() << "\n"; + break; case Opcode::Ret: os << " ret\n"; break; diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index f85fc44..75b1171 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -966,7 +966,8 @@ std::unique_ptr LowerToMIR(const ir::Module& module) { auto* true_mbb = block_map[next_cbr->GetTrueBlock()]; auto* false_mbb = block_map[next_cbr->GetFalseBlock()]; - if (cmp_inst->GetLhs()->GetType()->IsFloat32()) { + bool is_float_cmp = cmp_inst->GetLhs()->GetType()->IsFloat32(); + if (is_float_cmp) { EmitValueToReg(cmp_inst->GetLhs(), PhysReg::S0, slots, *current_mbb); EmitValueToReg(cmp_inst->GetRhs(), PhysReg::S1, slots, *current_mbb); current_mbb->Append( @@ -980,8 +981,9 @@ std::unique_ptr LowerToMIR(const ir::Module& module) { {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); } + // 浮点比较使用 FBcond(IEEE 754 兼容),整数比较使用 Bcond current_mbb->Append( - Opcode::Bcond, + is_float_cmp ? Opcode::FBcond : Opcode::Bcond, {Operand::Symbol(true_mbb->GetName()), Operand::Imm(static_cast(cmp_inst->GetCmpOp()))}); current_mbb->Append(Opcode::B,