|
|
|
|
@ -0,0 +1,258 @@
|
|
|
|
|
// 简单 countdown 循环全展开:
|
|
|
|
|
// - 处理形如 while(len) { body; len = len - 1; } 的递减循环
|
|
|
|
|
// - 要求 body 为单 BB,len 初值为编译时常量且 ≤64
|
|
|
|
|
// - 全展开后函数变为单 BB,可被 Inline 内联
|
|
|
|
|
// - 配合 ConstFold 将 len/power 等常量传播到每次迭代
|
|
|
|
|
|
|
|
|
|
#include "ir/IR.h"
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
namespace ir {
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
|
|
// 检测递减循环模式,返回 (phi, trip_count) 或 nullptr
|
|
|
|
|
static PhiInst* DetectSimpleCountdown(BasicBlock* header, BasicBlock* body,
|
|
|
|
|
BasicBlock* exit_bb, int& trip_count) {
|
|
|
|
|
// 检查 body → header 回边
|
|
|
|
|
bool has_backedge = false;
|
|
|
|
|
for (const auto& inst : body->GetInstructions()) {
|
|
|
|
|
if (auto* br = dynamic_cast<BranchInst*>(inst.get()))
|
|
|
|
|
if (br->GetTarget() == header) has_backedge = true;
|
|
|
|
|
}
|
|
|
|
|
if (!has_backedge) return nullptr;
|
|
|
|
|
|
|
|
|
|
for (const auto& inst : header->GetInstructions()) {
|
|
|
|
|
auto* phi = dynamic_cast<PhiInst*>(inst.get());
|
|
|
|
|
if (!phi) continue;
|
|
|
|
|
if (phi->GetNumOperands() < 4) continue;
|
|
|
|
|
|
|
|
|
|
Value* val0 = phi->GetOperand(0);
|
|
|
|
|
BasicBlock* bb0 = dynamic_cast<BasicBlock*>(phi->GetOperand(1));
|
|
|
|
|
Value* val1 = phi->GetOperand(2);
|
|
|
|
|
BasicBlock* bb1 = dynamic_cast<BasicBlock*>(phi->GetOperand(3));
|
|
|
|
|
|
|
|
|
|
Value* init_val = nullptr;
|
|
|
|
|
Value* update_val = nullptr;
|
|
|
|
|
if (bb0 != body && bb1 == body) { init_val = val0; update_val = val1; }
|
|
|
|
|
else if (bb1 != body && bb0 == body) { init_val = val1; update_val = val0; }
|
|
|
|
|
else continue;
|
|
|
|
|
|
|
|
|
|
auto* init_c = dynamic_cast<ConstantInt*>(init_val);
|
|
|
|
|
if (!init_c) continue;
|
|
|
|
|
int count = init_c->GetValue();
|
|
|
|
|
if (count <= 0 || count > 64) continue;
|
|
|
|
|
|
|
|
|
|
auto* sub = dynamic_cast<BinaryInst*>(update_val);
|
|
|
|
|
if (!sub || sub->GetOpcode() != Opcode::Sub) continue;
|
|
|
|
|
if (sub->GetLhs() != phi) continue;
|
|
|
|
|
auto* dec = dynamic_cast<ConstantInt*>(sub->GetRhs());
|
|
|
|
|
if (!dec || dec->GetValue() != 1) continue;
|
|
|
|
|
|
|
|
|
|
// 检查退出条件 phi == 0
|
|
|
|
|
bool exits = false;
|
|
|
|
|
for (const auto& inst : header->GetInstructions()) {
|
|
|
|
|
if (auto* cbr = dynamic_cast<CondBranchInst*>(inst.get())) {
|
|
|
|
|
Value* cond = cbr->GetCond();
|
|
|
|
|
if (auto* outer = dynamic_cast<BinaryInst*>(cond)) {
|
|
|
|
|
if (outer->GetOpcode() == Opcode::Ne) {
|
|
|
|
|
auto* rc = dynamic_cast<ConstantInt*>(outer->GetRhs());
|
|
|
|
|
if (rc && rc->GetValue() == 0)
|
|
|
|
|
if (auto* zext = dynamic_cast<CastInst*>(outer->GetLhs()))
|
|
|
|
|
if (zext->GetOpcode() == Opcode::ZExt) cond = zext->GetOperandValue();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (auto* cmp = dynamic_cast<BinaryInst*>(cond)) {
|
|
|
|
|
Value* other = nullptr;
|
|
|
|
|
if (cmp->GetLhs() == phi) other = cmp->GetRhs();
|
|
|
|
|
else if (cmp->GetRhs() == phi) other = cmp->GetLhs();
|
|
|
|
|
if (other && dynamic_cast<ConstantInt*>(other)) {
|
|
|
|
|
auto* c = static_cast<ConstantInt*>(other);
|
|
|
|
|
if (c->GetValue() == 0 &&
|
|
|
|
|
(cbr->GetTrueTarget() == exit_bb || cbr->GetFalseTarget() == exit_bb))
|
|
|
|
|
exits = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!exits) continue;
|
|
|
|
|
|
|
|
|
|
trip_count = count;
|
|
|
|
|
return phi;
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 克隆指令
|
|
|
|
|
static std::unique_ptr<Instruction> CloneInstruction(
|
|
|
|
|
Instruction* inst,
|
|
|
|
|
const std::unordered_map<Value*, Value*>& value_map,
|
|
|
|
|
Context& ctx) {
|
|
|
|
|
auto map = [&](Value* v) -> Value* {
|
|
|
|
|
auto it = value_map.find(v);
|
|
|
|
|
return (it != value_map.end()) ? it->second : v;
|
|
|
|
|
};
|
|
|
|
|
Opcode op = inst->GetOpcode();
|
|
|
|
|
switch (op) {
|
|
|
|
|
case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
|
|
|
|
|
case Opcode::Div: case Opcode::Mod:
|
|
|
|
|
case Opcode::Eq: case Opcode::Ne: case Opcode::Lt:
|
|
|
|
|
case Opcode::Le: case Opcode::Gt: case Opcode::Ge:
|
|
|
|
|
case Opcode::And: case Opcode::Or: {
|
|
|
|
|
auto* bin = static_cast<BinaryInst*>(inst);
|
|
|
|
|
return std::make_unique<BinaryInst>(op, inst->GetType(),
|
|
|
|
|
map(bin->GetLhs()), map(bin->GetRhs()),
|
|
|
|
|
ctx.NextTemp());
|
|
|
|
|
}
|
|
|
|
|
case Opcode::Load: {
|
|
|
|
|
auto* load = static_cast<LoadInst*>(inst);
|
|
|
|
|
return std::make_unique<LoadInst>(inst->GetType(), map(load->GetPtr()),
|
|
|
|
|
ctx.NextTemp());
|
|
|
|
|
}
|
|
|
|
|
case Opcode::Store: {
|
|
|
|
|
auto* store = static_cast<StoreInst*>(inst);
|
|
|
|
|
return std::make_unique<StoreInst>(inst->GetType(), map(store->GetValue()),
|
|
|
|
|
map(store->GetPtr()));
|
|
|
|
|
}
|
|
|
|
|
case Opcode::ZExt: {
|
|
|
|
|
auto* cast = static_cast<CastInst*>(inst);
|
|
|
|
|
return std::make_unique<CastInst>(op, inst->GetType(),
|
|
|
|
|
map(cast->GetOperandValue()), ctx.NextTemp());
|
|
|
|
|
}
|
|
|
|
|
case Opcode::Br: {
|
|
|
|
|
auto* br = static_cast<BranchInst*>(inst);
|
|
|
|
|
return std::make_unique<BranchInst>(inst->GetType(), br->GetTarget());
|
|
|
|
|
}
|
|
|
|
|
case Opcode::CondBr: {
|
|
|
|
|
auto* cbr = static_cast<CondBranchInst*>(inst);
|
|
|
|
|
return std::make_unique<CondBranchInst>(inst->GetType(), map(cbr->GetCond()),
|
|
|
|
|
cbr->GetTrueTarget(), cbr->GetFalseTarget());
|
|
|
|
|
}
|
|
|
|
|
default: return nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 展开 countdown 循环
|
|
|
|
|
static bool UnrollSimple(Function* func, BasicBlock* header, BasicBlock* body,
|
|
|
|
|
BasicBlock* exit_bb, PhiInst* phi, int trip_count,
|
|
|
|
|
Context& ctx) {
|
|
|
|
|
auto& fb = const_cast<std::vector<std::unique_ptr<BasicBlock>>&>(func->GetBlocks());
|
|
|
|
|
|
|
|
|
|
// 收集 body 指令(不含回边)
|
|
|
|
|
std::vector<Instruction*> body_insts;
|
|
|
|
|
for (const auto& inst : body->GetInstructions()) {
|
|
|
|
|
if (auto* br = dynamic_cast<BranchInst*>(inst.get()))
|
|
|
|
|
if (br->GetTarget() == header) continue;
|
|
|
|
|
body_insts.push_back(inst.get());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 找 preheader
|
|
|
|
|
BasicBlock* preheader = nullptr;
|
|
|
|
|
for (const auto& bb : func->GetBlocks()) {
|
|
|
|
|
for (const auto& inst : bb->GetInstructions()) {
|
|
|
|
|
if (auto* br = dynamic_cast<BranchInst*>(inst.get()))
|
|
|
|
|
if (br->GetTarget() == header) { preheader = bb.get(); break; }
|
|
|
|
|
if (auto* cbr = dynamic_cast<CondBranchInst*>(inst.get()))
|
|
|
|
|
if (cbr->GetTrueTarget() == header || cbr->GetFalseTarget() == header)
|
|
|
|
|
{ preheader = bb.get(); break; }
|
|
|
|
|
}
|
|
|
|
|
if (preheader) break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 克隆 N 次
|
|
|
|
|
std::vector<std::unique_ptr<BasicBlock>> new_blocks;
|
|
|
|
|
for (int iter = 0; iter < trip_count; ++iter) {
|
|
|
|
|
auto new_bb = std::make_unique<BasicBlock>(ctx.NextTemp() + "_unroll");
|
|
|
|
|
std::unordered_map<Value*, Value*> vm;
|
|
|
|
|
vm[phi] = ctx.GetConstInt(trip_count - iter);
|
|
|
|
|
|
|
|
|
|
for (auto* inst : body_insts) {
|
|
|
|
|
if (auto* bin = dynamic_cast<BinaryInst*>(inst))
|
|
|
|
|
if (bin->GetOpcode() == Opcode::Sub && bin->GetLhs() == phi) continue;
|
|
|
|
|
|
|
|
|
|
auto cloned = CloneInstruction(inst, vm, ctx);
|
|
|
|
|
if (!cloned) continue;
|
|
|
|
|
vm[inst] = cloned.get();
|
|
|
|
|
new_bb->InsertInstructionBeforeTerminator(std::move(cloned));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 最后一份 body 后跳到 exit
|
|
|
|
|
if (iter == trip_count - 1) {
|
|
|
|
|
auto br_exit = std::make_unique<BranchInst>(Type::GetVoidType(), exit_bb);
|
|
|
|
|
new_bb->InsertInstructionBeforeTerminator(std::move(br_exit));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
new_blocks.push_back(std::move(new_bb));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 修复 preheader 跳转
|
|
|
|
|
if (preheader && !new_blocks.empty()) {
|
|
|
|
|
auto& pi = const_cast<std::vector<std::unique_ptr<Instruction>>&>(preheader->GetInstructions());
|
|
|
|
|
if (!pi.empty()) {
|
|
|
|
|
auto* term = pi.back().get();
|
|
|
|
|
if (auto* br = dynamic_cast<BranchInst*>(term))
|
|
|
|
|
br->SetOperand(0, new_blocks[0].get());
|
|
|
|
|
else if (auto* cbr = dynamic_cast<CondBranchInst*>(term)) {
|
|
|
|
|
if (cbr->GetTrueTarget() == header) cbr->SetOperand(1, new_blocks[0].get());
|
|
|
|
|
if (cbr->GetFalseTarget() == header) cbr->SetOperand(2, new_blocks[0].get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 删除 header + body,插入新块
|
|
|
|
|
auto ipos = fb.begin();
|
|
|
|
|
if (preheader) {
|
|
|
|
|
for (auto it = fb.begin(); it != fb.end(); ++it)
|
|
|
|
|
if (it->get() == preheader) { ipos = it + 1; break; }
|
|
|
|
|
}
|
|
|
|
|
for (auto& nb : new_blocks)
|
|
|
|
|
ipos = fb.insert(ipos, std::move(nb)) + 1;
|
|
|
|
|
fb.erase(std::remove_if(fb.begin(), fb.end(),
|
|
|
|
|
[&](const std::unique_ptr<BasicBlock>& bb) {
|
|
|
|
|
return bb.get() == header || bb.get() == body;
|
|
|
|
|
}), fb.end());
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void RunLoopUnroll(Module& module) {
|
|
|
|
|
int unrolled = 0;
|
|
|
|
|
for (auto& func : module.GetFunctions()) {
|
|
|
|
|
if (func->IsExternal()) continue;
|
|
|
|
|
bool changed = true;
|
|
|
|
|
while (changed) {
|
|
|
|
|
changed = false;
|
|
|
|
|
for (const auto& bb : func->GetBlocks()) {
|
|
|
|
|
for (const auto& inst : bb->GetInstructions()) {
|
|
|
|
|
auto* br = dynamic_cast<BranchInst*>(inst.get());
|
|
|
|
|
if (!br) continue;
|
|
|
|
|
BasicBlock* target = br->GetTarget();
|
|
|
|
|
for (const auto& tgt_inst : target->GetInstructions()) {
|
|
|
|
|
auto* cbr = dynamic_cast<CondBranchInst*>(tgt_inst.get());
|
|
|
|
|
if (!cbr) continue;
|
|
|
|
|
BasicBlock *t = cbr->GetTrueTarget(), *f = cbr->GetFalseTarget();
|
|
|
|
|
BasicBlock *body = nullptr, *exit_bb = nullptr;
|
|
|
|
|
if (t == bb.get()) { body = t; exit_bb = f; }
|
|
|
|
|
else if (f == bb.get()) { body = f; exit_bb = t; }
|
|
|
|
|
if (!body || !exit_bb || body == target || exit_bb == target) continue;
|
|
|
|
|
|
|
|
|
|
int tc = 0;
|
|
|
|
|
auto* phi = DetectSimpleCountdown(target, body, exit_bb, tc);
|
|
|
|
|
if (!phi) continue;
|
|
|
|
|
if (UnrollSimple(func.get(), target, body, exit_bb, phi, tc,
|
|
|
|
|
module.GetContext())) {
|
|
|
|
|
++unrolled; changed = true; goto next_func;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
next_func:;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|