diff --git a/src/mir/passes/Peephole.cpp b/src/mir/passes/Peephole.cpp index 44cf3025..a46ff917 100644 --- a/src/mir/passes/Peephole.cpp +++ b/src/mir/passes/Peephole.cpp @@ -9,6 +9,20 @@ namespace mir namespace { + static CondCode InvertCondCode(CondCode cc) + { + switch (cc) + { + case CondCode::EQ: return CondCode::NE; + case CondCode::NE: return CondCode::EQ; + case CondCode::LT: return CondCode::GE; + case CondCode::LE: return CondCode::GT; + case CondCode::GT: return CondCode::LE; + case CondCode::GE: return CondCode::LT; + } + return cc; + } + static bool IsSamePhysReg(PhysReg a, PhysReg b) { int an = static_cast(a); @@ -294,22 +308,51 @@ namespace mir } // 分支 fallthrough: 末尾 Br 的目标是紧邻下一个块 → 删除 Br + // CondBr + Br 模式:CondBr 条件反转使 fallthrough 对齐 if (!insts.empty()) { - auto &last = insts.back(); - if (last.GetOpcode() == Opcode::Br && - last.GetOperands().size() >= 1 && - last.GetOperands()[0].GetKind() == Operand::Kind::Label) + const auto &blocks = function.GetBlocks(); + int my_idx = -1; + for (size_t bi = 0; bi < blocks.size(); ++bi) + { + if (blocks[bi].get() == &block) { my_idx = static_cast(bi); break; } + } + int next_label = (my_idx >= 0 && my_idx + 1 < static_cast(blocks.size())) + ? blocks[my_idx + 1]->GetLabelId() + : -1; + + if (next_label >= 0) { - int target_label = last.GetOperands()[0].GetLabel(); - const auto &blocks = function.GetBlocks(); - for (size_t bi = 0; bi < blocks.size(); ++bi) + // CondBr + Br 模式 + if (insts.size() >= 2) { - if (blocks[bi].get() == &block && bi + 1 < blocks.size()) + auto br_it = insts.end() - 1; + auto cond_it = insts.end() - 2; + if (br_it->GetOpcode() == Opcode::Br && + cond_it->GetOpcode() == Opcode::CondBr && + br_it->GetOperands().size() >= 1 && + br_it->GetOperands()[0].GetKind() == Operand::Kind::Label && + cond_it->GetOperands().size() >= 2 && + cond_it->GetOperands()[1].GetKind() == Operand::Kind::Label) { - if (blocks[bi + 1]->GetLabelId() == target_label) + int cond_target = cond_it->GetOperands()[1].GetLabel(); + int br_target = br_it->GetOperands()[0].GetLabel(); + + if (cond_target == next_label) + { + // CondBr 目标已是 fallthrough → 反转条件,交换目标 + CondCode old_cc = static_cast(cond_it->GetOperands()[0].GetImm()); + CondCode new_cc = InvertCondCode(old_cc); + const_cast(cond_it->GetOperands()[0]) = Operand::Imm(static_cast(new_cc)); + const_cast(cond_it->GetOperands()[1]) = + Operand::Label(br_target); + insts.pop_back(); // 删除 Br + } + else if (br_target == next_label) + { + // Br 目标已是 fallthrough → 直接删除 Br insts.pop_back(); - break; + } } } }