perf(ir): 简单递减循环全展开——countdown loop unrolling

- 检测 while(len) { body; len=len-1; } 递减循环模式
- TripCount ≤64 且为编译时常量 → 全展开
- 展开后函数变为单BB,可被 Inline 内联
- 配合 ConstFold 将常量 len/power 传播到每次迭代

门禁:functional 99/100(95_float预存),h_functional 40/40
lzk
lzkk 3 days ago
parent b3187edbbf
commit 608aa98f4e

@ -11,6 +11,7 @@ namespace ir {
void RunMem2Reg(Module& module);
void RunLICM(Module* module);
void RunInline(Module& module);
void RunLoopUnroll(Module& module);
void RunConstFold(Module& module);
void RunConstProp(Module& module);
void RunDCE(Module& module);
@ -28,6 +29,8 @@ class PassManager {
RunInline(*module);
RunLoopUnroll(*module);
RunLICM(module);
bool changed = true;

@ -6,6 +6,7 @@ add_library(ir_passes STATIC
DCE.cpp
CFGSimplify.cpp
Inline.cpp
LoopUnroll.cpp
IRVerifier.cpp
)

@ -0,0 +1,258 @@
// 简单 countdown 循环全展开:
// - 处理形如 while(len) { body; len = len - 1; } 的递减循环
// - 要求 body 为单 BBlen 初值为编译时常量且 ≤64
// - 全展开后函数变为单 BB可被 Inline 内联
// - 配合 ConstFold 将 len/power 等常量传播到每次迭代
#include "ir/IR.h"
#include <algorithm>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
// 检测递减循环模式,返回 (phi, trip_count) 或 nullptr
static PhiInst* DetectSimpleCountdown(BasicBlock* header, BasicBlock* body,
BasicBlock* exit_bb, int& trip_count) {
// 检查 body → header 回边
bool has_backedge = false;
for (const auto& inst : body->GetInstructions()) {
if (auto* br = dynamic_cast<BranchInst*>(inst.get()))
if (br->GetTarget() == header) has_backedge = true;
}
if (!has_backedge) return nullptr;
for (const auto& inst : header->GetInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst.get());
if (!phi) continue;
if (phi->GetNumOperands() < 4) continue;
Value* val0 = phi->GetOperand(0);
BasicBlock* bb0 = dynamic_cast<BasicBlock*>(phi->GetOperand(1));
Value* val1 = phi->GetOperand(2);
BasicBlock* bb1 = dynamic_cast<BasicBlock*>(phi->GetOperand(3));
Value* init_val = nullptr;
Value* update_val = nullptr;
if (bb0 != body && bb1 == body) { init_val = val0; update_val = val1; }
else if (bb1 != body && bb0 == body) { init_val = val1; update_val = val0; }
else continue;
auto* init_c = dynamic_cast<ConstantInt*>(init_val);
if (!init_c) continue;
int count = init_c->GetValue();
if (count <= 0 || count > 64) continue;
auto* sub = dynamic_cast<BinaryInst*>(update_val);
if (!sub || sub->GetOpcode() != Opcode::Sub) continue;
if (sub->GetLhs() != phi) continue;
auto* dec = dynamic_cast<ConstantInt*>(sub->GetRhs());
if (!dec || dec->GetValue() != 1) continue;
// 检查退出条件 phi == 0
bool exits = false;
for (const auto& inst : header->GetInstructions()) {
if (auto* cbr = dynamic_cast<CondBranchInst*>(inst.get())) {
Value* cond = cbr->GetCond();
if (auto* outer = dynamic_cast<BinaryInst*>(cond)) {
if (outer->GetOpcode() == Opcode::Ne) {
auto* rc = dynamic_cast<ConstantInt*>(outer->GetRhs());
if (rc && rc->GetValue() == 0)
if (auto* zext = dynamic_cast<CastInst*>(outer->GetLhs()))
if (zext->GetOpcode() == Opcode::ZExt) cond = zext->GetOperandValue();
}
}
if (auto* cmp = dynamic_cast<BinaryInst*>(cond)) {
Value* other = nullptr;
if (cmp->GetLhs() == phi) other = cmp->GetRhs();
else if (cmp->GetRhs() == phi) other = cmp->GetLhs();
if (other && dynamic_cast<ConstantInt*>(other)) {
auto* c = static_cast<ConstantInt*>(other);
if (c->GetValue() == 0 &&
(cbr->GetTrueTarget() == exit_bb || cbr->GetFalseTarget() == exit_bb))
exits = true;
}
}
}
}
if (!exits) continue;
trip_count = count;
return phi;
}
return nullptr;
}
// 克隆指令
static std::unique_ptr<Instruction> CloneInstruction(
Instruction* inst,
const std::unordered_map<Value*, Value*>& value_map,
Context& ctx) {
auto map = [&](Value* v) -> Value* {
auto it = value_map.find(v);
return (it != value_map.end()) ? it->second : v;
};
Opcode op = inst->GetOpcode();
switch (op) {
case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
case Opcode::Div: case Opcode::Mod:
case Opcode::Eq: case Opcode::Ne: case Opcode::Lt:
case Opcode::Le: case Opcode::Gt: case Opcode::Ge:
case Opcode::And: case Opcode::Or: {
auto* bin = static_cast<BinaryInst*>(inst);
return std::make_unique<BinaryInst>(op, inst->GetType(),
map(bin->GetLhs()), map(bin->GetRhs()),
ctx.NextTemp());
}
case Opcode::Load: {
auto* load = static_cast<LoadInst*>(inst);
return std::make_unique<LoadInst>(inst->GetType(), map(load->GetPtr()),
ctx.NextTemp());
}
case Opcode::Store: {
auto* store = static_cast<StoreInst*>(inst);
return std::make_unique<StoreInst>(inst->GetType(), map(store->GetValue()),
map(store->GetPtr()));
}
case Opcode::ZExt: {
auto* cast = static_cast<CastInst*>(inst);
return std::make_unique<CastInst>(op, inst->GetType(),
map(cast->GetOperandValue()), ctx.NextTemp());
}
case Opcode::Br: {
auto* br = static_cast<BranchInst*>(inst);
return std::make_unique<BranchInst>(inst->GetType(), br->GetTarget());
}
case Opcode::CondBr: {
auto* cbr = static_cast<CondBranchInst*>(inst);
return std::make_unique<CondBranchInst>(inst->GetType(), map(cbr->GetCond()),
cbr->GetTrueTarget(), cbr->GetFalseTarget());
}
default: return nullptr;
}
}
// 展开 countdown 循环
static bool UnrollSimple(Function* func, BasicBlock* header, BasicBlock* body,
BasicBlock* exit_bb, PhiInst* phi, int trip_count,
Context& ctx) {
auto& fb = const_cast<std::vector<std::unique_ptr<BasicBlock>>&>(func->GetBlocks());
// 收集 body 指令(不含回边)
std::vector<Instruction*> body_insts;
for (const auto& inst : body->GetInstructions()) {
if (auto* br = dynamic_cast<BranchInst*>(inst.get()))
if (br->GetTarget() == header) continue;
body_insts.push_back(inst.get());
}
// 找 preheader
BasicBlock* preheader = nullptr;
for (const auto& bb : func->GetBlocks()) {
for (const auto& inst : bb->GetInstructions()) {
if (auto* br = dynamic_cast<BranchInst*>(inst.get()))
if (br->GetTarget() == header) { preheader = bb.get(); break; }
if (auto* cbr = dynamic_cast<CondBranchInst*>(inst.get()))
if (cbr->GetTrueTarget() == header || cbr->GetFalseTarget() == header)
{ preheader = bb.get(); break; }
}
if (preheader) break;
}
// 克隆 N 次
std::vector<std::unique_ptr<BasicBlock>> new_blocks;
for (int iter = 0; iter < trip_count; ++iter) {
auto new_bb = std::make_unique<BasicBlock>(ctx.NextTemp() + "_unroll");
std::unordered_map<Value*, Value*> vm;
vm[phi] = ctx.GetConstInt(trip_count - iter);
for (auto* inst : body_insts) {
if (auto* bin = dynamic_cast<BinaryInst*>(inst))
if (bin->GetOpcode() == Opcode::Sub && bin->GetLhs() == phi) continue;
auto cloned = CloneInstruction(inst, vm, ctx);
if (!cloned) continue;
vm[inst] = cloned.get();
new_bb->InsertInstructionBeforeTerminator(std::move(cloned));
}
// 最后一份 body 后跳到 exit
if (iter == trip_count - 1) {
auto br_exit = std::make_unique<BranchInst>(Type::GetVoidType(), exit_bb);
new_bb->InsertInstructionBeforeTerminator(std::move(br_exit));
}
new_blocks.push_back(std::move(new_bb));
}
// 修复 preheader 跳转
if (preheader && !new_blocks.empty()) {
auto& pi = const_cast<std::vector<std::unique_ptr<Instruction>>&>(preheader->GetInstructions());
if (!pi.empty()) {
auto* term = pi.back().get();
if (auto* br = dynamic_cast<BranchInst*>(term))
br->SetOperand(0, new_blocks[0].get());
else if (auto* cbr = dynamic_cast<CondBranchInst*>(term)) {
if (cbr->GetTrueTarget() == header) cbr->SetOperand(1, new_blocks[0].get());
if (cbr->GetFalseTarget() == header) cbr->SetOperand(2, new_blocks[0].get());
}
}
}
// 删除 header + body插入新块
auto ipos = fb.begin();
if (preheader) {
for (auto it = fb.begin(); it != fb.end(); ++it)
if (it->get() == preheader) { ipos = it + 1; break; }
}
for (auto& nb : new_blocks)
ipos = fb.insert(ipos, std::move(nb)) + 1;
fb.erase(std::remove_if(fb.begin(), fb.end(),
[&](const std::unique_ptr<BasicBlock>& bb) {
return bb.get() == header || bb.get() == body;
}), fb.end());
return true;
}
} // namespace
void RunLoopUnroll(Module& module) {
int unrolled = 0;
for (auto& func : module.GetFunctions()) {
if (func->IsExternal()) continue;
bool changed = true;
while (changed) {
changed = false;
for (const auto& bb : func->GetBlocks()) {
for (const auto& inst : bb->GetInstructions()) {
auto* br = dynamic_cast<BranchInst*>(inst.get());
if (!br) continue;
BasicBlock* target = br->GetTarget();
for (const auto& tgt_inst : target->GetInstructions()) {
auto* cbr = dynamic_cast<CondBranchInst*>(tgt_inst.get());
if (!cbr) continue;
BasicBlock *t = cbr->GetTrueTarget(), *f = cbr->GetFalseTarget();
BasicBlock *body = nullptr, *exit_bb = nullptr;
if (t == bb.get()) { body = t; exit_bb = f; }
else if (f == bb.get()) { body = f; exit_bb = t; }
if (!body || !exit_bb || body == target || exit_bb == target) continue;
int tc = 0;
auto* phi = DetectSimpleCountdown(target, body, exit_bb, tc);
if (!phi) continue;
if (UnrollSimple(func.get(), target, body, exit_bb, phi, tc,
module.GetContext())) {
++unrolled; changed = true; goto next_func;
}
}
}
}
next_func:;
}
}
}
} // namespace ir
Loading…
Cancel
Save