diff --git a/src/include/ir/passes/PassManager.h b/src/include/ir/passes/PassManager.h index 3557e7ad..5e3f3cf6 100644 --- a/src/include/ir/passes/PassManager.h +++ b/src/include/ir/passes/PassManager.h @@ -11,6 +11,7 @@ namespace ir { void RunMem2Reg(Module& module); void RunLICM(Module* module); void RunInline(Module& module); +void RunIfConversion(Module& module); void RunLoopUnroll(Module& module); void RunConstFold(Module& module); void RunConstProp(Module& module); @@ -27,10 +28,13 @@ class PassManager { RunMem2Reg(*module); - RunInline(*module); + RunIfConversion(*module); + RunCFGSimplify(*module); RunLoopUnroll(*module); + RunInline(*module); + RunLICM(module); bool changed = true; diff --git a/src/ir/passes/CMakeLists.txt b/src/ir/passes/CMakeLists.txt index a503d88f..ff16d1a1 100644 --- a/src/ir/passes/CMakeLists.txt +++ b/src/ir/passes/CMakeLists.txt @@ -6,6 +6,7 @@ add_library(ir_passes STATIC DCE.cpp CFGSimplify.cpp Inline.cpp + IfConversion.cpp LoopUnroll.cpp IRVerifier.cpp ) diff --git a/src/ir/passes/IfConversion.cpp b/src/ir/passes/IfConversion.cpp new file mode 100644 index 00000000..c4405d40 --- /dev/null +++ b/src/ir/passes/IfConversion.cpp @@ -0,0 +1,270 @@ +// IfConversion: 将简单 if-else diamond 转换为算术 select +// - 扫描 CondBr→T→Br→M 且 F==M 的 diamond 模式 +// - 安全检查:T 必须只有单一前驱(B),仅允许纯算术指令(禁 Div/Mod/浮点) +// - 将 phi 转换为 fv + (tv-fv)*zext(cond) +// - 配合 CFGSimplify 清理空块,使循环体变为单 BB → 可被 LoopUnroll 展开 + +#include "ir/IR.h" + +#include +#include + +namespace ir { + +namespace { + +static Value* UnwrapCondition(Value* cond) { + for (int pass = 0; pass < 2; ++pass) { + auto* outer = dynamic_cast(cond); + if (!outer || outer->GetOpcode() != Opcode::Ne) break; + auto* rc = dynamic_cast(outer->GetRhs()); + if (!rc || rc->GetValue() != 0) break; + auto* zext = dynamic_cast(outer->GetLhs()); + if (!zext || zext->GetOpcode() != Opcode::ZExt) break; + cond = zext->GetOperandValue(); + } + return cond; +} + +static BasicBlock* GetOnlyBrTarget(BasicBlock* bb) { + const auto& insts = bb->GetInstructions(); + if (insts.empty()) return nullptr; + auto* br = dynamic_cast(insts.back().get()); + return br ? br->GetTarget() : nullptr; +} + +static std::vector ComputePredecessors( + BasicBlock* bb, const std::vector>& all_blocks) { + std::vector preds; + for (const auto& other : all_blocks) { + if (other.get() == bb) continue; + const auto& insts = other->GetInstructions(); + if (insts.empty()) continue; + auto* term = insts.back().get(); + if (auto* br = dynamic_cast(term)) { + if (br->GetTarget() == bb) preds.push_back(other.get()); + } else if (auto* cbr = dynamic_cast(term)) { + if (cbr->GetTrueTarget() == bb || cbr->GetFalseTarget() == bb) + preds.push_back(other.get()); + } + } + return preds; +} + +static bool IsSimpleBlock(BasicBlock* bb) { + for (const auto& inst : bb->GetInstructions()) { + switch (inst->GetOpcode()) { + case Opcode::Add: case Opcode::Sub: case Opcode::Mul: + case Opcode::And: case Opcode::Or: + case Opcode::Eq: case Opcode::Ne: case Opcode::Lt: + case Opcode::Le: case Opcode::Gt: case Opcode::Ge: + case Opcode::ZExt: + case Opcode::Br: + continue; + default: + return false; + } + } + return true; +} + +static Value* GetPhiValueFrom(PhiInst* phi, BasicBlock* bb) { + for (size_t i = 0; i < phi->GetNumOperands(); i += 2) { + if (dynamic_cast(phi->GetOperand(i + 1)) == bb) + return phi->GetOperand(i); + } + return nullptr; +} + +static void RemovePhiEntriesFrom(PhiInst* phi, BasicBlock* bb) { + std::vector> keep; + for (size_t i = 0; i < phi->GetNumOperands(); i += 2) { + auto* pred = dynamic_cast(phi->GetOperand(i + 1)); + if (pred != bb) + keep.push_back({phi->GetOperand(i), phi->GetOperand(i + 1)}); + } + if (keep.size() * 2 != phi->GetNumOperands()) { + phi->ClearOperands(); + for (auto& [val, pred] : keep) { + phi->AddOperand(val); + phi->AddOperand(pred); + } + } +} + +static void SetPhiEntry(PhiInst* phi, BasicBlock* bb, Value* val) { + for (size_t i = 0; i < phi->GetNumOperands(); i += 2) { + if (dynamic_cast(phi->GetOperand(i + 1)) == bb) { + phi->SetOperand(i, val); + return; + } + } + phi->AddOperand(val); + phi->AddOperand(bb); +} + +static bool TryConvertOneDiamond(BasicBlock* B, BasicBlock* T, BasicBlock* M, + Value* cond_i1, Context& ctx, + const std::vector>& all_blocks) { + if (!IsSimpleBlock(T)) return false; + if (GetOnlyBrTarget(T) != M) return false; + auto t_preds = ComputePredecessors(T, all_blocks); + if (t_preds.size() != 1 || t_preds[0] != B) return false; + + struct PhiEntry { PhiInst* phi; Value* val_t; Value* val_f; }; + std::vector to_convert; + for (const auto& inst : M->GetInstructions()) { + auto* phi = dynamic_cast(inst.get()); + if (!phi) break; + Value* val_t = GetPhiValueFrom(phi, T); + if (!val_t) continue; + Value* val_f = GetPhiValueFrom(phi, B); + if (!val_f) { + for (size_t i = 0; i < phi->GetNumOperands(); i += 2) { + auto* pred = dynamic_cast(phi->GetOperand(i + 1)); + if (pred != T) { val_f = phi->GetOperand(i); break; } + } + } + if (!val_f) continue; + to_convert.push_back({phi, val_t, val_f}); + } + if (to_convert.empty()) return false; + + auto* cbr = B->GetInstructions().back().get(); + B->TakeInstruction(cbr); + + auto& t_insts = const_cast>&>(T->GetInstructions()); + std::vector t_to_move; + for (const auto& inst : t_insts) + if (inst->GetOpcode() != Opcode::Br) + t_to_move.push_back(inst.get()); + for (auto* inst : t_to_move) { + auto taken = T->TakeInstruction(inst); + B->InsertInstructionBeforeTerminator(std::move(taken)); + } + if (!T->GetInstructions().empty()) + T->TakeInstruction(T->GetInstructions().back().get()); + + for (auto& [phi, val_t, val_f] : to_convert) { + if (val_t == val_f) { + RemovePhiEntriesFrom(phi, T); + SetPhiEntry(phi, B, val_f); + continue; + } + auto* zext = B->Append(Opcode::ZExt, Type::GetInt32Type(), cond_i1, ctx.NextTemp()); + auto* diff = B->Append(Opcode::Sub, Type::GetInt32Type(), val_t, val_f, ctx.NextTemp()); + auto* masked = B->Append(Opcode::Mul, Type::GetInt32Type(), diff, zext, ctx.NextTemp()); + auto* select_val = B->Append(Opcode::Add, Type::GetInt32Type(), val_f, masked, ctx.NextTemp()); + RemovePhiEntriesFrom(phi, T); + SetPhiEntry(phi, B, select_val); + } + + B->Append(Type::GetVoidType(), M); + return true; +} + +static void IfConvertFunction(Function* func, Context& ctx) { + auto& blocks = const_cast>&>(func->GetBlocks()); + bool changed = true; + while (changed) { + changed = false; + for (const auto& bb : blocks) { + const auto& insts = bb->GetInstructions(); + if (insts.empty()) continue; + auto* cbr = dynamic_cast(insts.back().get()); + if (!cbr) continue; + BasicBlock* T = cbr->GetTrueTarget(); + BasicBlock* F = cbr->GetFalseTarget(); + Value* cond = UnwrapCondition(cbr->GetCond()); + if (TryConvertOneDiamond(bb.get(), T, F, cond, ctx, blocks)) { + changed = true; + break; + } + } + } +} + +static void CleanupRedundantPhis(Function* func) { + for (const auto& bb : func->GetBlocks()) { + auto& insts = const_cast>&>(bb->GetInstructions()); + for (size_t i = 0; i < insts.size(); ) { + auto* phi = dynamic_cast(insts[i].get()); + if (!phi) break; + Value* unique_val = nullptr; + bool all_same = true; + for (size_t j = 0; j < phi->GetNumOperands(); j += 2) { + Value* v = phi->GetOperand(j); + if (!unique_val) unique_val = v; + else if (unique_val != v) { all_same = false; break; } + } + if (all_same && unique_val) { + phi->ReplaceAllUsesWith(unique_val); + phi->ClearOperands(); + insts.erase(insts.begin() + i); + continue; + } + ++i; + } + } +} + +static void MergeSinglePredBlocks(Function* func) { + auto& blocks = const_cast>&>(func->GetBlocks()); + bool changed = true; + while (changed) { + changed = false; + for (auto& bb_ptr : blocks) { + BasicBlock* bb = bb_ptr.get(); + if (bb == func->GetEntry()) continue; + bool has_phi = false; + for (const auto& inst : bb->GetInstructions()) { + if (dynamic_cast(inst.get())) { has_phi = true; break; } + } + if (has_phi) continue; + auto preds = ComputePredecessors(bb, blocks); + if (preds.size() != 1) continue; + BasicBlock* pred = preds[0]; + if (pred == bb) continue; + const auto& pred_insts = pred->GetInstructions(); + if (pred_insts.empty()) continue; + auto* br = dynamic_cast(pred_insts.back().get()); + if (!br || br->GetTarget() != bb) continue; + pred->TakeInstruction(pred_insts.back().get()); + auto& bb_insts = const_cast>&>(bb->GetInstructions()); + std::vector to_move; + for (auto& inst : bb_insts) + to_move.push_back(inst.get()); + for (auto* inst : to_move) { + auto taken = bb->TakeInstruction(inst); + pred->InsertInstructionBeforeTerminator(std::move(taken)); + } + for (auto& other : blocks) { + if (other.get() == bb) continue; + auto& o_insts = const_cast>&>(other->GetInstructions()); + for (auto& inst : o_insts) { + auto* phi = dynamic_cast(inst.get()); + if (!phi) break; + for (size_t i = 0; i < phi->GetNumOperands(); i += 2) { + if (dynamic_cast(phi->GetOperand(i + 1)) == bb) + phi->SetOperand(i + 1, pred); + } + } + } + changed = true; + break; + } + } +} + +} // namespace + +void RunIfConversion(Module& module) { + for (auto& func : module.GetFunctions()) { + if (func->IsExternal()) continue; + IfConvertFunction(func.get(), module.GetContext()); + CleanupRedundantPhis(func.get()); + MergeSinglePredBlocks(func.get()); + } +} + +} // namespace ir diff --git a/src/ir/passes/LoopUnroll.cpp b/src/ir/passes/LoopUnroll.cpp index 8c1fceb4..e6465b8e 100644 --- a/src/ir/passes/LoopUnroll.cpp +++ b/src/ir/passes/LoopUnroll.cpp @@ -167,11 +167,35 @@ static bool UnrollSimple(Function* func, BasicBlock* header, BasicBlock* body, if (preheader) break; } - // 克隆 N 次 - std::vector> new_blocks; + // 收集所有 header phi 的 init/latch 映射(用于跨迭代值追踪) + struct PhiInfo { Value* init_val; Value* latch_val; }; + std::unordered_map phi_info; + for (const auto& inst : header->GetInstructions()) { + auto* hphi = dynamic_cast(inst.get()); + if (!hphi) break; + Value *v0 = hphi->GetOperand(0), *v1 = hphi->GetOperand(2); + BasicBlock *bb0 = dynamic_cast(hphi->GetOperand(1)); + BasicBlock *bb1 = dynamic_cast(hphi->GetOperand(3)); + if (bb0 != body && bb1 == body) + phi_info[hphi] = {v0, v1}; + else if (bb1 != body && bb0 == body) + phi_info[hphi] = {v1, v0}; + } + + // 跨迭代追踪所有 phi 值 + std::unordered_map curr_vals; + for (auto& [hphi, info] : phi_info) + curr_vals[hphi] = info.init_val; + + // 将所有迭代克隆到单个块中(使函数变为单 BB,可被 Inline 内联) + auto unrolled_bb = std::make_unique(ctx.NextTemp() + "_unroll"); for (int iter = 0; iter < trip_count; ++iter) { - auto new_bb = std::make_unique(ctx.NextTemp() + "_unroll"); std::unordered_map vm; + + // 所有 header phi 替换为当前迭代值 + for (auto& [hphi, val] : curr_vals) + vm[hphi] = val; + // len phi 额外用常量覆盖 vm[phi] = ctx.GetConstInt(trip_count - iter); for (auto* inst : body_insts) { @@ -181,43 +205,92 @@ static bool UnrollSimple(Function* func, BasicBlock* header, BasicBlock* body, auto cloned = CloneInstruction(inst, vm, ctx); if (!cloned) continue; vm[inst] = cloned.get(); - new_bb->InsertInstructionBeforeTerminator(std::move(cloned)); + unrolled_bb->InsertInstructionBeforeTerminator(std::move(cloned)); } - // 最后一份 body 后跳到 exit - if (iter == trip_count - 1) { - auto br_exit = std::make_unique(Type::GetVoidType(), exit_bb); - new_bb->InsertInstructionBeforeTerminator(std::move(br_exit)); + // 更新下次迭代的 phi 值 + for (auto& [hphi, info] : phi_info) { + if (hphi == phi) continue; + auto it = vm.find(info.latch_val); + if (it != vm.end()) + curr_vals[hphi] = it->second; } + } - new_blocks.push_back(std::move(new_bb)); + // 将 exit 块的 ret 指令直接放入展开块(使函数变为单 BB) + if (!exit_bb->GetInstructions().empty()) { + auto* exit_ret = exit_bb->GetInstructions().back().get(); + if (dynamic_cast(exit_ret)) { + auto taken = exit_bb->TakeInstruction(exit_ret); + unrolled_bb->InsertInstructionBeforeTerminator(std::move(taken)); + } } - // 修复 preheader 跳转 - if (preheader && !new_blocks.empty()) { + // 用最后迭代的值替换所有 header phi 的剩余引用(如 exit 块中) + for (auto& [hphi, val] : curr_vals) + hphi->ReplaceAllUsesWith(val); + + // 修复 preheader 跳转到展开块 + if (preheader) { auto& pi = const_cast>&>(preheader->GetInstructions()); if (!pi.empty()) { auto* term = pi.back().get(); if (auto* br = dynamic_cast(term)) - br->SetOperand(0, new_blocks[0].get()); + br->SetOperand(0, unrolled_bb.get()); else if (auto* cbr = dynamic_cast(term)) { - if (cbr->GetTrueTarget() == header) cbr->SetOperand(1, new_blocks[0].get()); - if (cbr->GetFalseTarget() == header) cbr->SetOperand(2, new_blocks[0].get()); + if (cbr->GetTrueTarget() == header) cbr->SetOperand(1, unrolled_bb.get()); + if (cbr->GetFalseTarget() == header) cbr->SetOperand(2, unrolled_bb.get()); + } + } + } + + // 若 preheader 仅有 Br 指令,将展开块内容合并到 preheader(使函数单 BB) + if (preheader && preheader->GetInstructions().size() == 1 && + dynamic_cast(preheader->GetInstructions().back().get())) { + // 移除 preheader 的 Br + auto* pre_br = preheader->GetInstructions().back().get(); + preheader->TakeInstruction(pre_br); + // 移动展开块所有指令到 preheader + auto& u_insts = const_cast>&>( + unrolled_bb->GetInstructions()); + std::vector u_to_move; + for (auto& inst : u_insts) + u_to_move.push_back(inst.get()); + for (auto* inst : u_to_move) { + auto taken = unrolled_bb->TakeInstruction(inst); + preheader->InsertInstructionBeforeTerminator(std::move(taken)); + } + // unrolled_bb 现在是空的,后续不插入它 + } else { + // 修复 preheader 跳转到展开块 + if (preheader) { + auto& pi = const_cast>&>(preheader->GetInstructions()); + if (!pi.empty()) { + auto* term = pi.back().get(); + if (auto* br = dynamic_cast(term)) + br->SetOperand(0, unrolled_bb.get()); + else if (auto* cbr = dynamic_cast(term)) { + if (cbr->GetTrueTarget() == header) cbr->SetOperand(1, unrolled_bb.get()); + if (cbr->GetFalseTarget() == header) cbr->SetOperand(2, unrolled_bb.get()); + } } } } - // 删除 header + body,插入新块 + // 删除 header + body + exit auto ipos = fb.begin(); if (preheader) { for (auto it = fb.begin(); it != fb.end(); ++it) if (it->get() == preheader) { ipos = it + 1; break; } } - for (auto& nb : new_blocks) - ipos = fb.insert(ipos, std::move(nb)) + 1; + // 若展开块已空(已合并到 preheader),不插入 + if (!unrolled_bb->GetInstructions().empty()) { + ipos = fb.insert(ipos, std::move(unrolled_bb)) + 1; + } fb.erase(std::remove_if(fb.begin(), fb.end(), [&](const std::unique_ptr& bb) { - return bb.get() == header || bb.get() == body; + return bb.get() == header || bb.get() == body || + bb.get() == exit_bb; }), fb.end()); return true; }