From 608aa98f4ed4db0e0c2489d53c46e048258647ed Mon Sep 17 00:00:00 2001 From: lzkk <956449176@qq.com> Date: Thu, 28 May 2026 19:04:46 +0800 Subject: [PATCH] =?UTF-8?q?perf(ir):=20=E7=AE=80=E5=8D=95=E9=80=92?= =?UTF-8?q?=E5=87=8F=E5=BE=AA=E7=8E=AF=E5=85=A8=E5=B1=95=E5=BC=80=E2=80=94?= =?UTF-8?q?=E2=80=94countdown=20loop=20unrolling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 检测 while(len) { body; len=len-1; } 递减循环模式 - TripCount ≤64 且为编译时常量 → 全展开 - 展开后函数变为单BB,可被 Inline 内联 - 配合 ConstFold 将常量 len/power 传播到每次迭代 门禁:functional 99/100(95_float预存),h_functional 40/40 --- src/include/ir/passes/PassManager.h | 3 + src/ir/passes/CMakeLists.txt | 1 + src/ir/passes/LoopUnroll.cpp | 258 ++++++++++++++++++++++++++++ 3 files changed, 262 insertions(+) create mode 100644 src/ir/passes/LoopUnroll.cpp diff --git a/src/include/ir/passes/PassManager.h b/src/include/ir/passes/PassManager.h index 22670ae9..3557e7ad 100644 --- a/src/include/ir/passes/PassManager.h +++ b/src/include/ir/passes/PassManager.h @@ -11,6 +11,7 @@ namespace ir { void RunMem2Reg(Module& module); void RunLICM(Module* module); void RunInline(Module& module); +void RunLoopUnroll(Module& module); void RunConstFold(Module& module); void RunConstProp(Module& module); void RunDCE(Module& module); @@ -28,6 +29,8 @@ class PassManager { RunInline(*module); + RunLoopUnroll(*module); + RunLICM(module); bool changed = true; diff --git a/src/ir/passes/CMakeLists.txt b/src/ir/passes/CMakeLists.txt index 31318ab9..a503d88f 100644 --- a/src/ir/passes/CMakeLists.txt +++ b/src/ir/passes/CMakeLists.txt @@ -6,6 +6,7 @@ add_library(ir_passes STATIC DCE.cpp CFGSimplify.cpp Inline.cpp + LoopUnroll.cpp IRVerifier.cpp ) diff --git a/src/ir/passes/LoopUnroll.cpp b/src/ir/passes/LoopUnroll.cpp new file mode 100644 index 00000000..21ec1dc1 --- /dev/null +++ b/src/ir/passes/LoopUnroll.cpp @@ -0,0 +1,258 @@ +// 简单 countdown 循环全展开: +// - 处理形如 while(len) { body; len = len - 1; } 的递减循环 +// - 要求 body 为单 BB,len 初值为编译时常量且 ≤64 +// - 全展开后函数变为单 BB,可被 Inline 内联 +// - 配合 ConstFold 将 len/power 等常量传播到每次迭代 + +#include "ir/IR.h" + +#include +#include +#include + +namespace ir { + +namespace { + +// 检测递减循环模式,返回 (phi, trip_count) 或 nullptr +static PhiInst* DetectSimpleCountdown(BasicBlock* header, BasicBlock* body, + BasicBlock* exit_bb, int& trip_count) { + // 检查 body → header 回边 + bool has_backedge = false; + for (const auto& inst : body->GetInstructions()) { + if (auto* br = dynamic_cast(inst.get())) + if (br->GetTarget() == header) has_backedge = true; + } + if (!has_backedge) return nullptr; + + for (const auto& inst : header->GetInstructions()) { + auto* phi = dynamic_cast(inst.get()); + if (!phi) continue; + if (phi->GetNumOperands() < 4) continue; + + Value* val0 = phi->GetOperand(0); + BasicBlock* bb0 = dynamic_cast(phi->GetOperand(1)); + Value* val1 = phi->GetOperand(2); + BasicBlock* bb1 = dynamic_cast(phi->GetOperand(3)); + + Value* init_val = nullptr; + Value* update_val = nullptr; + if (bb0 != body && bb1 == body) { init_val = val0; update_val = val1; } + else if (bb1 != body && bb0 == body) { init_val = val1; update_val = val0; } + else continue; + + auto* init_c = dynamic_cast(init_val); + if (!init_c) continue; + int count = init_c->GetValue(); + if (count <= 0 || count > 64) continue; + + auto* sub = dynamic_cast(update_val); + if (!sub || sub->GetOpcode() != Opcode::Sub) continue; + if (sub->GetLhs() != phi) continue; + auto* dec = dynamic_cast(sub->GetRhs()); + if (!dec || dec->GetValue() != 1) continue; + + // 检查退出条件 phi == 0 + bool exits = false; + for (const auto& inst : header->GetInstructions()) { + if (auto* cbr = dynamic_cast(inst.get())) { + Value* cond = cbr->GetCond(); + if (auto* outer = dynamic_cast(cond)) { + if (outer->GetOpcode() == Opcode::Ne) { + auto* rc = dynamic_cast(outer->GetRhs()); + if (rc && rc->GetValue() == 0) + if (auto* zext = dynamic_cast(outer->GetLhs())) + if (zext->GetOpcode() == Opcode::ZExt) cond = zext->GetOperandValue(); + } + } + if (auto* cmp = dynamic_cast(cond)) { + Value* other = nullptr; + if (cmp->GetLhs() == phi) other = cmp->GetRhs(); + else if (cmp->GetRhs() == phi) other = cmp->GetLhs(); + if (other && dynamic_cast(other)) { + auto* c = static_cast(other); + if (c->GetValue() == 0 && + (cbr->GetTrueTarget() == exit_bb || cbr->GetFalseTarget() == exit_bb)) + exits = true; + } + } + } + } + if (!exits) continue; + + trip_count = count; + return phi; + } + return nullptr; +} + +// 克隆指令 +static std::unique_ptr CloneInstruction( + Instruction* inst, + const std::unordered_map& value_map, + Context& ctx) { + auto map = [&](Value* v) -> Value* { + auto it = value_map.find(v); + return (it != value_map.end()) ? it->second : v; + }; + Opcode op = inst->GetOpcode(); + switch (op) { + case Opcode::Add: case Opcode::Sub: case Opcode::Mul: + case Opcode::Div: case Opcode::Mod: + case Opcode::Eq: case Opcode::Ne: case Opcode::Lt: + case Opcode::Le: case Opcode::Gt: case Opcode::Ge: + case Opcode::And: case Opcode::Or: { + auto* bin = static_cast(inst); + return std::make_unique(op, inst->GetType(), + map(bin->GetLhs()), map(bin->GetRhs()), + ctx.NextTemp()); + } + case Opcode::Load: { + auto* load = static_cast(inst); + return std::make_unique(inst->GetType(), map(load->GetPtr()), + ctx.NextTemp()); + } + case Opcode::Store: { + auto* store = static_cast(inst); + return std::make_unique(inst->GetType(), map(store->GetValue()), + map(store->GetPtr())); + } + case Opcode::ZExt: { + auto* cast = static_cast(inst); + return std::make_unique(op, inst->GetType(), + map(cast->GetOperandValue()), ctx.NextTemp()); + } + case Opcode::Br: { + auto* br = static_cast(inst); + return std::make_unique(inst->GetType(), br->GetTarget()); + } + case Opcode::CondBr: { + auto* cbr = static_cast(inst); + return std::make_unique(inst->GetType(), map(cbr->GetCond()), + cbr->GetTrueTarget(), cbr->GetFalseTarget()); + } + default: return nullptr; + } +} + +// 展开 countdown 循环 +static bool UnrollSimple(Function* func, BasicBlock* header, BasicBlock* body, + BasicBlock* exit_bb, PhiInst* phi, int trip_count, + Context& ctx) { + auto& fb = const_cast>&>(func->GetBlocks()); + + // 收集 body 指令(不含回边) + std::vector body_insts; + for (const auto& inst : body->GetInstructions()) { + if (auto* br = dynamic_cast(inst.get())) + if (br->GetTarget() == header) continue; + body_insts.push_back(inst.get()); + } + + // 找 preheader + BasicBlock* preheader = nullptr; + for (const auto& bb : func->GetBlocks()) { + for (const auto& inst : bb->GetInstructions()) { + if (auto* br = dynamic_cast(inst.get())) + if (br->GetTarget() == header) { preheader = bb.get(); break; } + if (auto* cbr = dynamic_cast(inst.get())) + if (cbr->GetTrueTarget() == header || cbr->GetFalseTarget() == header) + { preheader = bb.get(); break; } + } + if (preheader) break; + } + + // 克隆 N 次 + std::vector> new_blocks; + for (int iter = 0; iter < trip_count; ++iter) { + auto new_bb = std::make_unique(ctx.NextTemp() + "_unroll"); + std::unordered_map vm; + vm[phi] = ctx.GetConstInt(trip_count - iter); + + for (auto* inst : body_insts) { + if (auto* bin = dynamic_cast(inst)) + if (bin->GetOpcode() == Opcode::Sub && bin->GetLhs() == phi) continue; + + auto cloned = CloneInstruction(inst, vm, ctx); + if (!cloned) continue; + vm[inst] = cloned.get(); + new_bb->InsertInstructionBeforeTerminator(std::move(cloned)); + } + + // 最后一份 body 后跳到 exit + if (iter == trip_count - 1) { + auto br_exit = std::make_unique(Type::GetVoidType(), exit_bb); + new_bb->InsertInstructionBeforeTerminator(std::move(br_exit)); + } + + new_blocks.push_back(std::move(new_bb)); + } + + // 修复 preheader 跳转 + if (preheader && !new_blocks.empty()) { + auto& pi = const_cast>&>(preheader->GetInstructions()); + if (!pi.empty()) { + auto* term = pi.back().get(); + if (auto* br = dynamic_cast(term)) + br->SetOperand(0, new_blocks[0].get()); + else if (auto* cbr = dynamic_cast(term)) { + if (cbr->GetTrueTarget() == header) cbr->SetOperand(1, new_blocks[0].get()); + if (cbr->GetFalseTarget() == header) cbr->SetOperand(2, new_blocks[0].get()); + } + } + } + + // 删除 header + body,插入新块 + auto ipos = fb.begin(); + if (preheader) { + for (auto it = fb.begin(); it != fb.end(); ++it) + if (it->get() == preheader) { ipos = it + 1; break; } + } + for (auto& nb : new_blocks) + ipos = fb.insert(ipos, std::move(nb)) + 1; + fb.erase(std::remove_if(fb.begin(), fb.end(), + [&](const std::unique_ptr& bb) { + return bb.get() == header || bb.get() == body; + }), fb.end()); + return true; +} + +} // namespace + +void RunLoopUnroll(Module& module) { + int unrolled = 0; + for (auto& func : module.GetFunctions()) { + if (func->IsExternal()) continue; + bool changed = true; + while (changed) { + changed = false; + for (const auto& bb : func->GetBlocks()) { + for (const auto& inst : bb->GetInstructions()) { + auto* br = dynamic_cast(inst.get()); + if (!br) continue; + BasicBlock* target = br->GetTarget(); + for (const auto& tgt_inst : target->GetInstructions()) { + auto* cbr = dynamic_cast(tgt_inst.get()); + if (!cbr) continue; + BasicBlock *t = cbr->GetTrueTarget(), *f = cbr->GetFalseTarget(); + BasicBlock *body = nullptr, *exit_bb = nullptr; + if (t == bb.get()) { body = t; exit_bb = f; } + else if (f == bb.get()) { body = f; exit_bb = t; } + if (!body || !exit_bb || body == target || exit_bb == target) continue; + + int tc = 0; + auto* phi = DetectSimpleCountdown(target, body, exit_bb, tc); + if (!phi) continue; + if (UnrollSimple(func.get(), target, body, exit_bb, phi, tc, + module.GetContext())) { + ++unrolled; changed = true; goto next_func; + } + } + } + } + next_func:; + } + } +} + +} // namespace ir