// IfConversion: 将简单 if-else diamond 转换为算术 select // - 扫描 CondBr→T→Br→M 且 F==M 的 diamond 模式 // - 安全检查:T 必须只有单一前驱(B),仅允许纯算术指令(禁 Div/Mod/浮点) // - 将 phi 转换为 fv + (tv-fv)*zext(cond) // - 配合 CFGSimplify 清理空块,使循环体变为单 BB → 可被 LoopUnroll 展开 #include "ir/IR.h" #include #include namespace ir { namespace { static Value* UnwrapCondition(Value* cond) { for (int pass = 0; pass < 2; ++pass) { auto* outer = dynamic_cast(cond); if (!outer || outer->GetOpcode() != Opcode::Ne) break; auto* rc = dynamic_cast(outer->GetRhs()); if (!rc || rc->GetValue() != 0) break; auto* zext = dynamic_cast(outer->GetLhs()); if (!zext || zext->GetOpcode() != Opcode::ZExt) break; cond = zext->GetOperandValue(); } return cond; } static BasicBlock* GetOnlyBrTarget(BasicBlock* bb) { const auto& insts = bb->GetInstructions(); if (insts.empty()) return nullptr; auto* br = dynamic_cast(insts.back().get()); return br ? br->GetTarget() : nullptr; } static std::vector ComputePredecessors( BasicBlock* bb, const std::vector>& all_blocks) { std::vector preds; for (const auto& other : all_blocks) { if (other.get() == bb) continue; const auto& insts = other->GetInstructions(); if (insts.empty()) continue; auto* term = insts.back().get(); if (auto* br = dynamic_cast(term)) { if (br->GetTarget() == bb) preds.push_back(other.get()); } else if (auto* cbr = dynamic_cast(term)) { if (cbr->GetTrueTarget() == bb || cbr->GetFalseTarget() == bb) preds.push_back(other.get()); } } return preds; } static bool IsSimpleBlock(BasicBlock* bb) { for (const auto& inst : bb->GetInstructions()) { switch (inst->GetOpcode()) { case Opcode::Add: case Opcode::Sub: case Opcode::Mul: case Opcode::And: case Opcode::Or: case Opcode::Eq: case Opcode::Ne: case Opcode::Lt: case Opcode::Le: case Opcode::Gt: case Opcode::Ge: case Opcode::ZExt: case Opcode::Br: continue; default: return false; } } return true; } static Value* GetPhiValueFrom(PhiInst* phi, BasicBlock* bb) { for (size_t i = 0; i < phi->GetNumOperands(); i += 2) { if (dynamic_cast(phi->GetOperand(i + 1)) == bb) return phi->GetOperand(i); } return nullptr; } static void RemovePhiEntriesFrom(PhiInst* phi, BasicBlock* bb) { std::vector> keep; for (size_t i = 0; i < phi->GetNumOperands(); i += 2) { auto* pred = dynamic_cast(phi->GetOperand(i + 1)); if (pred != bb) keep.push_back({phi->GetOperand(i), phi->GetOperand(i + 1)}); } if (keep.size() * 2 != phi->GetNumOperands()) { phi->ClearOperands(); for (auto& [val, pred] : keep) { phi->AddOperand(val); phi->AddOperand(pred); } } } static void SetPhiEntry(PhiInst* phi, BasicBlock* bb, Value* val) { for (size_t i = 0; i < phi->GetNumOperands(); i += 2) { if (dynamic_cast(phi->GetOperand(i + 1)) == bb) { phi->SetOperand(i, val); return; } } phi->AddOperand(val); phi->AddOperand(bb); } static bool TryConvertOneDiamond(BasicBlock* B, BasicBlock* T, BasicBlock* M, Value* cond_i1, Context& ctx, const std::vector>& all_blocks) { if (!IsSimpleBlock(T)) return false; if (GetOnlyBrTarget(T) != M) return false; auto t_preds = ComputePredecessors(T, all_blocks); if (t_preds.size() != 1 || t_preds[0] != B) return false; struct PhiEntry { PhiInst* phi; Value* val_t; Value* val_f; }; std::vector to_convert; for (const auto& inst : M->GetInstructions()) { auto* phi = dynamic_cast(inst.get()); if (!phi) break; Value* val_t = GetPhiValueFrom(phi, T); if (!val_t) continue; Value* val_f = GetPhiValueFrom(phi, B); if (!val_f) { for (size_t i = 0; i < phi->GetNumOperands(); i += 2) { auto* pred = dynamic_cast(phi->GetOperand(i + 1)); if (pred != T) { val_f = phi->GetOperand(i); break; } } } if (!val_f) continue; to_convert.push_back({phi, val_t, val_f}); } if (to_convert.empty()) return false; auto* cbr = B->GetInstructions().back().get(); B->TakeInstruction(cbr); auto& t_insts = const_cast>&>(T->GetInstructions()); std::vector t_to_move; for (const auto& inst : t_insts) if (inst->GetOpcode() != Opcode::Br) t_to_move.push_back(inst.get()); for (auto* inst : t_to_move) { auto taken = T->TakeInstruction(inst); B->InsertInstructionBeforeTerminator(std::move(taken)); } if (!T->GetInstructions().empty()) T->TakeInstruction(T->GetInstructions().back().get()); for (auto& [phi, val_t, val_f] : to_convert) { if (val_t == val_f) { RemovePhiEntriesFrom(phi, T); SetPhiEntry(phi, B, val_f); continue; } auto* zext = B->Append(Opcode::ZExt, Type::GetInt32Type(), cond_i1, ctx.NextTemp()); auto* diff = B->Append(Opcode::Sub, Type::GetInt32Type(), val_t, val_f, ctx.NextTemp()); auto* masked = B->Append(Opcode::Mul, Type::GetInt32Type(), diff, zext, ctx.NextTemp()); auto* select_val = B->Append(Opcode::Add, Type::GetInt32Type(), val_f, masked, ctx.NextTemp()); RemovePhiEntriesFrom(phi, T); SetPhiEntry(phi, B, select_val); } B->Append(Type::GetVoidType(), M); return true; } static void IfConvertFunction(Function* func, Context& ctx) { auto& blocks = const_cast>&>(func->GetBlocks()); bool changed = true; while (changed) { changed = false; for (const auto& bb : blocks) { const auto& insts = bb->GetInstructions(); if (insts.empty()) continue; auto* cbr = dynamic_cast(insts.back().get()); if (!cbr) continue; BasicBlock* T = cbr->GetTrueTarget(); BasicBlock* F = cbr->GetFalseTarget(); Value* cond = UnwrapCondition(cbr->GetCond()); if (TryConvertOneDiamond(bb.get(), T, F, cond, ctx, blocks)) { changed = true; break; } } } } static void CleanupRedundantPhis(Function* func) { for (const auto& bb : func->GetBlocks()) { auto& insts = const_cast>&>(bb->GetInstructions()); for (size_t i = 0; i < insts.size(); ) { auto* phi = dynamic_cast(insts[i].get()); if (!phi) break; Value* unique_val = nullptr; bool all_same = true; for (size_t j = 0; j < phi->GetNumOperands(); j += 2) { Value* v = phi->GetOperand(j); if (!unique_val) unique_val = v; else if (unique_val != v) { all_same = false; break; } } if (all_same && unique_val) { phi->ReplaceAllUsesWith(unique_val); phi->ClearOperands(); insts.erase(insts.begin() + i); continue; } ++i; } } } static void MergeSinglePredBlocks(Function* func) { auto& blocks = const_cast>&>(func->GetBlocks()); bool changed = true; while (changed) { changed = false; for (auto& bb_ptr : blocks) { BasicBlock* bb = bb_ptr.get(); if (bb == func->GetEntry()) continue; bool has_phi = false; for (const auto& inst : bb->GetInstructions()) { if (dynamic_cast(inst.get())) { has_phi = true; break; } } if (has_phi) continue; auto preds = ComputePredecessors(bb, blocks); if (preds.size() != 1) continue; BasicBlock* pred = preds[0]; if (pred == bb) continue; const auto& pred_insts = pred->GetInstructions(); if (pred_insts.empty()) continue; auto* br = dynamic_cast(pred_insts.back().get()); if (!br || br->GetTarget() != bb) continue; pred->TakeInstruction(pred_insts.back().get()); auto& bb_insts = const_cast>&>(bb->GetInstructions()); std::vector to_move; for (auto& inst : bb_insts) to_move.push_back(inst.get()); for (auto* inst : to_move) { auto taken = bb->TakeInstruction(inst); pred->InsertInstructionBeforeTerminator(std::move(taken)); } for (auto& other : blocks) { if (other.get() == bb) continue; auto& o_insts = const_cast>&>(other->GetInstructions()); for (auto& inst : o_insts) { auto* phi = dynamic_cast(inst.get()); if (!phi) break; for (size_t i = 0; i < phi->GetNumOperands(); i += 2) { if (dynamic_cast(phi->GetOperand(i + 1)) == bb) phi->SetOperand(i + 1, pred); } } } changed = true; break; } } } } // namespace void RunIfConversion(Module& module) { for (auto& func : module.GetFunctions()) { if (func->IsExternal()) continue; IfConvertFunction(func.get(), module.GetContext()); CleanupRedundantPhis(func.get()); MergeSinglePredBlocks(func.get()); } } } // namespace ir