基础版本

lzkk 4 days ago
parent 39d1f8119f
commit 120140abd1

11
.gitignore vendored

@ -116,3 +116,14 @@ Error.txt
time_opt.txt
settings.json
指令数基线.md
.claude/
.claudeignore
.codegraph/
.count_tmp.s
results.csv
results.json
Error.txt
time_opt.txt
settings.json
指令数基线.md

@ -1,350 +1,30 @@
// 简单 countdown 循环全展开:
// - 处理形如 while(len) { body; len = len - 1; } 的递减循环
// - 支持多块体(含 if-else自动 if-convert 为算术 select 后展开
// - 要求 len 初值为编译时常量且 ≤64
// - 全展开后函数可能变为单 BB可被 Inline 内联
// - 要求 body 为单 BBlen 初值为编译时常量且 ≤64
// - 全展开后函数变为单 BB可被 Inline 内联
// - 配合 ConstFold 将 len/power 等常量传播到每次迭代
#include "ir/IR.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// 展开 icmp ne (zext(X), 0) → X
static Value* UnwrapCondition(Value* cond) {
for (int pass = 0; pass < 2; ++pass) {
auto* outer = dynamic_cast<BinaryInst*>(cond);
if (!outer || outer->GetOpcode() != Opcode::Ne) break;
auto* rc = dynamic_cast<ConstantInt*>(outer->GetRhs());
if (!rc || rc->GetValue() != 0) break;
auto* zext = dynamic_cast<CastInst*>(outer->GetLhs());
if (!zext || zext->GetOpcode() != Opcode::ZExt) break;
cond = zext->GetOperandValue();
}
return cond;
}
// 获取块的唯一无条件跳转目标,否则 nullptr
static BasicBlock* GetOnlySuccessor(BasicBlock* bb) {
const auto& insts = bb->GetInstructions();
if (insts.empty()) return nullptr;
auto* br = dynamic_cast<BranchInst*>(insts.back().get());
return br ? br->GetTarget() : nullptr;
}
// 查找从 start 到 header 路径上的所有块(循环体块),不包含 header 本身
static void FindLoopBodyBlocks(BasicBlock* start, BasicBlock* header,
std::unordered_set<BasicBlock*>& loop_blocks) {
std::vector<BasicBlock*> worklist;
std::unordered_set<BasicBlock*> visited;
worklist.push_back(start);
while (!worklist.empty()) {
auto* bb = worklist.back();
worklist.pop_back();
if (bb == header || !visited.insert(bb).second) continue;
loop_blocks.insert(bb);
for (const auto& inst : bb->GetInstructions()) {
if (auto* br = dynamic_cast<BranchInst*>(inst.get()))
worklist.push_back(br->GetTarget());
else if (auto* cbr = dynamic_cast<CondBranchInst*>(inst.get())) {
worklist.push_back(cbr->GetTrueTarget());
worklist.push_back(cbr->GetFalseTarget());
}
if (inst->IsTerminator()) break;
}
}
}
// 判断块是否可安全 if-convert无副作用
static bool IsSimpleBlock(BasicBlock* bb) {
for (const auto& inst : bb->GetInstructions()) {
switch (inst->GetOpcode()) {
case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
case Opcode::Div: case Opcode::Mod:
case Opcode::And: case Opcode::Or:
case Opcode::Eq: case Opcode::Ne: case Opcode::Lt:
case Opcode::Le: case Opcode::Gt: case Opcode::Ge:
case Opcode::ZExt: case Opcode::SIToFP: case Opcode::FPToSI:
case Opcode::Br:
continue;
default:
return false;
}
}
return true;
}
// phi 辅助函数
static Value* GetPhiValueFrom(PhiInst* phi, BasicBlock* bb) {
for (size_t i = 0; i < phi->GetNumOperands(); i += 2) {
auto* pred = dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1));
if (pred == bb) return phi->GetOperand(i);
}
return nullptr;
}
static void RemovePhiEntriesFrom(PhiInst* phi, BasicBlock* bb) {
std::vector<std::pair<Value*, Value*>> keep;
for (size_t i = 0; i < phi->GetNumOperands(); i += 2) {
auto* pred = dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1));
if (pred != bb)
keep.push_back({phi->GetOperand(i), phi->GetOperand(i + 1)});
}
if (keep.size() * 2 != phi->GetNumOperands()) {
phi->ClearOperands();
for (auto& [val, pred] : keep) {
phi->AddOperand(val);
phi->AddOperand(pred);
}
}
}
static void SetPhiEntry(PhiInst* phi, BasicBlock* bb, Value* val) {
for (size_t i = 0; i < phi->GetNumOperands(); i += 2) {
if (dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1)) == bb) {
phi->SetOperand(i, val);
return;
}
}
phi->AddOperand(val);
phi->AddOperand(bb);
}
// 尝试对单个 diamond 做 if-conversion: B → CondBr → T → Br → M, F == M
// 返回 true 表示转换成功
static bool TryIfConvertDiamond(BasicBlock* B, BasicBlock* T, BasicBlock* F,
Value* cond_i1, Context& ctx) {
if (!IsSimpleBlock(T)) return false;
BasicBlock* M = F;
if (GetOnlySuccessor(T) != M) return false;
// 收集 M 中需要转换的 phi 节点
struct PhiEntry { PhiInst* phi; Value* val_t; Value* val_f; };
std::vector<PhiEntry> to_convert;
for (const auto& inst : M->GetInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst.get());
if (!phi) break;
Value* val_t = GetPhiValueFrom(phi, T);
if (!val_t) continue;
Value* val_f = GetPhiValueFrom(phi, B);
if (!val_f) {
for (size_t i = 0; i < phi->GetNumOperands(); i += 2) {
auto* pred = dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1));
if (pred != T) { val_f = phi->GetOperand(i); break; }
}
}
if (!val_f) continue;
to_convert.push_back({phi, val_t, val_f});
}
if (to_convert.empty()) return false;
// 移除 B 的 CondBr
auto* cbr = B->GetInstructions().back().get();
B->TakeInstruction(cbr);
// 移动 T 的非终止指令到 B
std::vector<Instruction*> t_to_move;
for (const auto& inst : T->GetInstructions())
if (inst->GetOpcode() != Opcode::Br)
t_to_move.push_back(inst.get());
for (auto* inst : t_to_move) {
auto taken = T->TakeInstruction(inst);
B->InsertInstructionBeforeTerminator(std::move(taken));
}
// 移除 T 的 Br
if (!T->GetInstructions().empty())
T->TakeInstruction(T->GetInstructions().back().get());
// 生成算术 select: fv + (tv - fv) * zext(cond)
for (auto& [phi, val_t, val_f] : to_convert) {
if (val_t == val_f) {
RemovePhiEntriesFrom(phi, T);
SetPhiEntry(phi, B, val_f);
continue;
}
auto* zext = B->Append<CastInst>(Opcode::ZExt,
Type::GetInt32Type(), cond_i1, ctx.NextTemp());
auto* diff = B->Append<BinaryInst>(Opcode::Sub, Type::GetInt32Type(),
val_t, val_f, ctx.NextTemp());
auto* masked = B->Append<BinaryInst>(Opcode::Mul, Type::GetInt32Type(),
diff, zext, ctx.NextTemp());
auto* select_val = B->Append<BinaryInst>(Opcode::Add, Type::GetInt32Type(),
val_f, masked, ctx.NextTemp());
RemovePhiEntriesFrom(phi, T);
SetPhiEntry(phi, B, select_val);
}
B->Append<BranchInst>(Type::GetVoidType(), M);
return true;
}
// 对循环体内的 diamond 做迭代 if-conversion从内向外
static void IfConvertLoopBody(const std::unordered_set<BasicBlock*>& loop_blocks,
Context& ctx) {
bool changed = true;
while (changed) {
changed = false;
for (auto* bb : loop_blocks) {
const auto& insts = bb->GetInstructions();
if (insts.empty()) continue;
auto* cbr = dynamic_cast<CondBranchInst*>(insts.back().get());
if (!cbr) continue;
BasicBlock* T = cbr->GetTrueTarget();
BasicBlock* F = cbr->GetFalseTarget();
// 只转换 T 在循环体内且 F 在循环体内(或 F 是 latch 合并且仍在体内)
if (!loop_blocks.count(T)) continue;
Value* cond = UnwrapCondition(cbr->GetCond());
if (TryIfConvertDiamond(bb, T, F, cond, ctx)) {
changed = true;
break;
}
}
}
// 清理仅剩单一条目的 phi
for (auto* bb : loop_blocks) {
auto& insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
bb->GetInstructions());
for (size_t i = 0; i < insts.size(); ) {
auto* phi = dynamic_cast<PhiInst*>(insts[i].get());
if (!phi) break;
Value* unique_val = nullptr;
bool all_same = true;
for (size_t j = 0; j < phi->GetNumOperands(); j += 2) {
Value* v = phi->GetOperand(j);
if (!unique_val) unique_val = v;
else if (unique_val != v) { all_same = false; break; }
}
if (all_same && unique_val) {
phi->ReplaceAllUsesWith(unique_val);
phi->ClearOperands();
insts.erase(insts.begin() + i);
continue;
}
++i;
}
}
}
// 合并循环体内单前驱无 phi 块(将链式 BB 压平)
static void MergeLoopBodyBlocks(const std::unordered_set<BasicBlock*>& loop_blocks,
Function* func) {
// 构建使用 def-use 信息计算前驱
auto compute_preds = [&](BasicBlock* bb) -> std::vector<BasicBlock*> {
std::vector<BasicBlock*> preds;
for (auto* other : loop_blocks) {
if (other == bb) continue;
const auto& insts = other->GetInstructions();
if (insts.empty()) continue;
auto* term = insts.back().get();
if (auto* br = dynamic_cast<BranchInst*>(term)) {
if (br->GetTarget() == bb) preds.push_back(other);
} else if (auto* cbr = dynamic_cast<CondBranchInst*>(term)) {
if (cbr->GetTrueTarget() == bb || cbr->GetFalseTarget() == bb)
preds.push_back(other);
}
}
return preds;
};
bool changed = true;
while (changed) {
changed = false;
for (auto* bb : loop_blocks) {
// 检查是否有 phi
bool has_phi = false;
for (const auto& inst : bb->GetInstructions()) {
if (dynamic_cast<PhiInst*>(inst.get())) { has_phi = true; break; }
}
if (has_phi) continue;
auto preds = compute_preds(bb);
if (preds.size() != 1) continue;
BasicBlock* pred = preds[0];
if (pred == bb) continue;
// pred 必须以无条件 Br 指向 bb
auto* pred_term = pred->GetInstructions().back().get();
auto* br = dynamic_cast<BranchInst*>(pred_term);
if (!br || br->GetTarget() != bb) continue;
// 合并:移除 pred 的 terminator
pred->TakeInstruction(pred_term);
// 移动 bb 的所有指令到 pred
auto& bb_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
bb->GetInstructions());
std::vector<Instruction*> to_move;
for (auto& inst : bb_insts)
to_move.push_back(inst.get());
for (auto* inst : to_move) {
auto taken = bb->TakeInstruction(inst);
pred->InsertInstructionBeforeTerminator(std::move(taken));
}
// 更新后继块的 phi将对 bb 的引用改为 pred
for (auto* succ_bb : loop_blocks) {
if (succ_bb == bb) continue;
auto& succ_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
succ_bb->GetInstructions());
for (auto& inst : succ_insts) {
auto* phi = dynamic_cast<PhiInst*>(inst.get());
if (!phi) break;
for (size_t i = 0; i < phi->GetNumOperands(); i += 2) {
if (dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1)) == bb)
phi->SetOperand(i + 1, pred);
}
}
}
// 也要更新 header 的 phi
auto* header = func->GetEntry(); // will be found by scanning
// 实际上 header 不在 loop_blocks 中,需要单独处理
for (const auto& hdr_bb : func->GetBlocks()) {
for (const auto& inst : hdr_bb->GetInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst.get());
if (!phi) break;
for (size_t i = 0; i < phi->GetNumOperands(); i += 2) {
if (dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1)) == bb)
phi->SetOperand(i + 1, pred);
}
}
}
changed = true;
break;
}
}
}
// 检测递减循环模式,返回 (phi, trip_count) 或 nullptr
// 支持多块体header → body_entry → ... → latch → header
static PhiInst* DetectSimpleCountdown(BasicBlock* header, BasicBlock* body_entry,
static PhiInst* DetectSimpleCountdown(BasicBlock* header, BasicBlock* body,
BasicBlock* exit_bb, int& trip_count) {
// 找到 latch循环体内分支回 header 的块)
BasicBlock* latch = nullptr;
std::unordered_set<BasicBlock*> loop_blocks;
FindLoopBodyBlocks(body_entry, header, loop_blocks);
for (auto* bb : loop_blocks) {
for (const auto& inst : bb->GetInstructions()) {
if (auto* br = dynamic_cast<BranchInst*>(inst.get()))
if (br->GetTarget() == header) { latch = bb; break; }
}
if (latch) break;
// 检查 body → header 回边
bool has_backedge = false;
for (const auto& inst : body->GetInstructions()) {
if (auto* br = dynamic_cast<BranchInst*>(inst.get()))
if (br->GetTarget() == header) has_backedge = true;
}
if (!latch) return nullptr;
if (!has_backedge) return nullptr;
// 查找 header 中的 countdown phi
for (const auto& inst : header->GetInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst.get());
if (!phi) continue;
@ -357,8 +37,8 @@ static PhiInst* DetectSimpleCountdown(BasicBlock* header, BasicBlock* body_entry
Value* init_val = nullptr;
Value* update_val = nullptr;
if (bb0 != latch && bb1 == latch) { init_val = val0; update_val = val1; }
else if (bb1 != latch && bb0 == latch) { init_val = val1; update_val = val0; }
if (bb0 != body && bb1 == body) { init_val = val0; update_val = val1; }
else if (bb1 != body && bb0 == body) { init_val = val1; update_val = val0; }
else continue;
auto* init_c = dynamic_cast<ConstantInt*>(init_val);
@ -377,7 +57,14 @@ static PhiInst* DetectSimpleCountdown(BasicBlock* header, BasicBlock* body_entry
for (const auto& inst : header->GetInstructions()) {
if (auto* cbr = dynamic_cast<CondBranchInst*>(inst.get())) {
Value* cond = cbr->GetCond();
cond = UnwrapCondition(cond);
if (auto* outer = dynamic_cast<BinaryInst*>(cond)) {
if (outer->GetOpcode() == Opcode::Ne) {
auto* rc = dynamic_cast<ConstantInt*>(outer->GetRhs());
if (rc && rc->GetValue() == 0)
if (auto* zext = dynamic_cast<CastInst*>(outer->GetLhs()))
if (zext->GetOpcode() == Opcode::ZExt) cond = zext->GetOperandValue();
}
}
if (auto* cmp = dynamic_cast<BinaryInst*>(cond)) {
Value* other = nullptr;
if (cmp->GetLhs() == phi) other = cmp->GetRhs();
@ -454,40 +141,11 @@ static std::unique_ptr<Instruction> CloneInstruction(
}
// 展开 countdown 循环
static bool UnrollSimple(Function* func, BasicBlock* header, BasicBlock* body_entry,
static bool UnrollSimple(Function* func, BasicBlock* header, BasicBlock* body,
BasicBlock* exit_bb, PhiInst* phi, int trip_count,
Context& ctx) {
auto& fb = const_cast<std::vector<std::unique_ptr<BasicBlock>>&>(func->GetBlocks());
// 找所有循环体块
std::unordered_set<BasicBlock*> loop_blocks;
FindLoopBodyBlocks(body_entry, header, loop_blocks);
// 找 latch含回边的块
BasicBlock* latch = nullptr;
for (auto* bb : loop_blocks) {
for (const auto& inst : bb->GetInstructions()) {
if (auto* br = dynamic_cast<BranchInst*>(inst.get()))
if (br->GetTarget() == header) latch = bb;
}
}
// if-convert 循环体内的 diamond 模式
IfConvertLoopBody(loop_blocks, ctx);
// 合并不再有 phi 的链式块
MergeLoopBodyBlocks(loop_blocks, func);
// 重新找 bodyif-conversion 后 body_entry 可能已变)
BasicBlock* body = nullptr;
for (auto* bb : loop_blocks) {
for (const auto& inst : bb->GetInstructions()) {
if (auto* br = dynamic_cast<BranchInst*>(inst.get()))
if (br->GetTarget() == header) body = bb;
}
}
if (!body) body = body_entry;
// 收集 body 指令(不含回边)
std::vector<Instruction*> body_insts;
for (const auto& inst : body->GetInstructions()) {
@ -509,46 +167,14 @@ static bool UnrollSimple(Function* func, BasicBlock* header, BasicBlock* body_en
if (preheader) break;
}
// 收集循环体中所有 phi 的初始值和步进值(用于跨迭代值映射)
struct PhiMapping {
PhiInst* phi;
Value* init_val;
Instruction* latch_val; // 指令在 latch 中产生下一个迭代的值
};
std::vector<PhiMapping> phi_mappings;
for (const auto& inst : header->GetInstructions()) {
auto* hdr_phi = dynamic_cast<PhiInst*>(inst.get());
if (!hdr_phi) break;
if (hdr_phi == phi) continue; // len 特殊处理
Value* init = nullptr;
Value* latch_val = nullptr;
for (size_t i = 0; i < hdr_phi->GetNumOperands(); i += 2) {
auto* pred = dynamic_cast<BasicBlock*>(hdr_phi->GetOperand(i + 1));
if (pred != latch && pred != body) init = hdr_phi->GetOperand(i);
else latch_val = hdr_phi->GetOperand(i);
}
if (init && latch_val)
phi_mappings.push_back({hdr_phi, init, dynamic_cast<Instruction*>(latch_val)});
}
// 克隆 N 次
std::vector<std::unique_ptr<BasicBlock>> new_blocks;
std::unordered_map<PhiInst*, Value*> prev_values; // phi → 上次迭代产生的新值
for (auto& pm : phi_mappings)
prev_values[pm.phi] = pm.init_val;
for (int iter = 0; iter < trip_count; ++iter) {
auto new_bb = std::make_unique<BasicBlock>(ctx.NextTemp() + "_unroll");
std::unordered_map<Value*, Value*> vm;
// phi 替换len → 常量,其他 → 上次迭代的值
vm[phi] = ctx.GetConstInt(trip_count - iter);
for (auto& pm : phi_mappings)
vm[pm.phi] = prev_values[pm.phi];
for (auto* inst : body_insts) {
// 跳过 len = len - 1
if (auto* bin = dynamic_cast<BinaryInst*>(inst))
if (bin->GetOpcode() == Opcode::Sub && bin->GetLhs() == phi) continue;
@ -558,13 +184,6 @@ static bool UnrollSimple(Function* func, BasicBlock* header, BasicBlock* body_en
new_bb->InsertInstructionBeforeTerminator(std::move(cloned));
}
// 更新下次迭代的 phi 值
for (auto& pm : phi_mappings) {
auto it = vm.find(pm.latch_val);
if (it != vm.end())
prev_values[pm.phi] = it->second;
}
// 最后一份 body 后跳到 exit
if (iter == trip_count - 1) {
auto br_exit = std::make_unique<BranchInst>(Type::GetVoidType(), exit_bb);
@ -588,7 +207,7 @@ static bool UnrollSimple(Function* func, BasicBlock* header, BasicBlock* body_en
}
}
// 删除 header + 循环体块,插入新块
// 删除 header + body,插入新块
auto ipos = fb.begin();
if (preheader) {
for (auto it = fb.begin(); it != fb.end(); ++it)
@ -596,15 +215,197 @@ static bool UnrollSimple(Function* func, BasicBlock* header, BasicBlock* body_en
}
for (auto& nb : new_blocks)
ipos = fb.insert(ipos, std::move(nb)) + 1;
// 移除 header 和所有循环体块
fb.erase(std::remove_if(fb.begin(), fb.end(),
[&](const std::unique_ptr<BasicBlock>& bb) {
return bb.get() == header || loop_blocks.count(bb.get());
return bb.get() == header || bb.get() == body;
}), fb.end());
return true;
}
// 展开 icmp ne (zext(X), 0) → X
static Value* UnwrapZExtCond(Value* cond) {
if (auto* outer = dynamic_cast<BinaryInst*>(cond)) {
if (outer->GetOpcode() == Opcode::Ne) {
auto* rc = dynamic_cast<ConstantInt*>(outer->GetRhs());
if (rc && rc->GetValue() == 0)
if (auto* zext = dynamic_cast<CastInst*>(outer->GetLhs()))
if (zext->GetOpcode() == Opcode::ZExt)
return zext->GetOperandValue();
}
}
return cond;
}
// 简化含 if-else 的循环体:将简单 diamond 转成算术 select 使体变单 BB
// 模式: body → CondBr → if_then → Br → merge
// ↘ merge
// 其中 if_then 只有一条 add 指令
static bool SimplifyLoopBody(Function* func, Context& ctx) {
bool changed = false;
for (const auto& bb : func->GetBlocks()) {
const auto& insts = bb->GetInstructions();
if (insts.empty()) continue;
auto* cbr = dynamic_cast<CondBranchInst*>(insts.back().get());
if (!cbr) continue;
BasicBlock* T = cbr->GetTrueTarget();
BasicBlock* F = cbr->GetFalseTarget();
BasicBlock* B = bb.get();
// 只处理一个分支有单条 add、另一分支直达 merge 的模式
auto check_simple_if = [](BasicBlock* then_bb, BasicBlock* merge_bb) -> Instruction* {
const auto& ti = then_bb->GetInstructions();
// 恰好两条指令: add + br
if (ti.size() != 2) return nullptr;
auto* add = dynamic_cast<BinaryInst*>(ti[0].get());
if (!add || add->GetOpcode() != Opcode::Add) return nullptr;
auto* br = dynamic_cast<BranchInst*>(ti[1].get());
if (!br || br->GetTarget() != merge_bb) return nullptr;
return add;
};
// 模式: T 有 add+Br 到 F
if (auto* add = check_simple_if(T, F)) {
// 检查 F 有 phi 合并来自 B 和 T 的值
for (const auto& mi : F->GetInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(mi.get());
if (!phi) break;
Value* val_t = nullptr;
Value* val_f = nullptr;
for (size_t i = 0; i < phi->GetNumOperands(); i += 2) {
auto* pred = dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1));
if (pred == T) val_t = phi->GetOperand(i);
if (pred == B) val_f = phi->GetOperand(i);
}
if (!val_t || !val_f) continue;
// if-convert: fv + (tv - fv) * zext(cond)
Value* cond = UnwrapZExtCond(cbr->GetCond());
// 移除 B 的 CondBr
B->TakeInstruction(cbr);
// 移动 T 的 add 指令到 B
auto add_owned = T->TakeInstruction(add);
B->InsertInstructionBeforeTerminator(std::move(add_owned));
// 生成算术 select
auto* zext = B->Append<CastInst>(Opcode::ZExt,
Type::GetInt32Type(), cond, ctx.NextTemp());
auto* diff = B->Append<BinaryInst>(Opcode::Sub, Type::GetInt32Type(),
val_t, val_f, ctx.NextTemp());
auto* masked = B->Append<BinaryInst>(Opcode::Mul, Type::GetInt32Type(),
diff, zext, ctx.NextTemp());
auto* sel = B->Append<BinaryInst>(Opcode::Add, Type::GetInt32Type(),
val_f, masked, ctx.NextTemp());
// 更新 phi: 移除 T 条目,将 B 条目值设为 sel
std::vector<std::pair<Value*, Value*>> keep;
for (size_t i = 0; i < phi->GetNumOperands(); i += 2) {
auto* pred = dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1));
if (pred != T && pred != B)
keep.push_back({phi->GetOperand(i), phi->GetOperand(i + 1)});
}
phi->ClearOperands();
phi->AddOperand(sel);
phi->AddOperand(B);
for (auto& [v, p] : keep) {
phi->AddOperand(v);
phi->AddOperand(p);
}
// B 跳转到 F
B->Append<BranchInst>(Type::GetVoidType(), F);
// 清理 T移除其 BrT 变为空块)
auto& t_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(T->GetInstructions());
if (!t_insts.empty()) T->TakeInstruction(t_insts.back().get());
changed = true;
goto next_block;
}
}
next_block:;
}
// 合并单前驱无 phi 块
for (const auto& bb : func->GetBlocks()) {
// 计算前驱
std::vector<BasicBlock*> preds;
for (const auto& other : func->GetBlocks()) {
if (other.get() == bb.get()) continue;
const auto& oi = other->GetInstructions();
if (oi.empty()) continue;
if (auto* br = dynamic_cast<BranchInst*>(oi.back().get()))
if (br->GetTarget() == bb.get()) preds.push_back(other.get());
}
if (preds.size() != 1) continue;
BasicBlock* pred = preds[0];
// 检查无 phi
bool has_phi = false;
for (const auto& mi : bb->GetInstructions())
if (dynamic_cast<PhiInst*>(mi.get())) { has_phi = true; break; }
if (has_phi) continue;
// pred 必须以无条件 Br 指向此块
auto* pred_term = pred->GetInstructions().back().get();
auto* br = dynamic_cast<BranchInst*>(pred_term);
if (!br || br->GetTarget() != bb.get()) continue;
// 合并: pred 移除 terminator, 移入 bb 的全部指令
pred->TakeInstruction(pred_term);
auto& bb_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(bb->GetInstructions());
std::vector<Instruction*> to_move;
for (auto& mi : bb_insts) to_move.push_back(mi.get());
for (auto* mi : to_move) {
auto taken = bb->TakeInstruction(mi);
pred->InsertInstructionBeforeTerminator(std::move(taken));
}
// 更新后继块 phi 中对 bb 的引用
for (const auto& succ : func->GetBlocks()) {
auto& si = const_cast<std::vector<std::unique_ptr<Instruction>>&>(succ->GetInstructions());
for (auto& mi : si) {
auto* phi = dynamic_cast<PhiInst*>(mi.get());
if (!phi) break;
for (size_t i = 0; i < phi->GetNumOperands(); i += 2)
if (dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1)) == bb.get())
phi->SetOperand(i + 1, pred);
}
}
changed = true;
break;
}
// 清理冗余 phi(单一条目或所有值相同)
for (const auto& bb : func->GetBlocks()) {
auto& insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(bb->GetInstructions());
for (size_t i = 0; i < insts.size(); ) {
auto* phi = dynamic_cast<PhiInst*>(insts[i].get());
if (!phi) break;
Value* unique_val = nullptr;
bool all_same = true;
for (size_t j = 0; j < phi->GetNumOperands(); j += 2) {
Value* v = phi->GetOperand(j);
if (!unique_val) unique_val = v;
else if (unique_val != v) { all_same = false; break; }
}
if (all_same && unique_val) {
phi->ReplaceAllUsesWith(unique_val);
phi->ClearOperands();
insts.erase(insts.begin() + i);
continue;
}
++i;
}
}
return changed;
}
} // namespace
void RunLoopUnroll(Module& module) {
@ -612,6 +413,10 @@ void RunLoopUnroll(Module& module) {
for (auto& func : module.GetFunctions()) {
if (func->IsExternal()) continue;
if (!func->GetType()->IsInt32()) continue;
// 先简化循环体: if-convert + merge, 让多块体变单块体
while (SimplifyLoopBody(func.get())) {}
bool changed = true;
while (changed) {
changed = false;
@ -624,42 +429,15 @@ void RunLoopUnroll(Module& module) {
auto* cbr = dynamic_cast<CondBranchInst*>(tgt_inst.get());
if (!cbr) continue;
BasicBlock *t = cbr->GetTrueTarget(), *f = cbr->GetFalseTarget();
// 确定循环体入口和出口
BasicBlock *body_entry = nullptr, *exit_bb = nullptr;
// 模式1: bb 是 preheaderbb → header → body/exit
// 找 header 的哪个后继包含回到 header 的路径
for (auto* cand : {t, f}) {
std::unordered_set<BasicBlock*> temp;
FindLoopBodyBlocks(cand, target, temp);
bool has_backedge = false;
for (auto* lb : temp) {
for (const auto& li : lb->GetInstructions()) {
if (auto* lbr = dynamic_cast<BranchInst*>(li.get()))
if (lbr->GetTarget() == target) has_backedge = true;
}
}
if (has_backedge) body_entry = cand;
}
if (!body_entry) {
// 回退到原有逻辑body_entry 是等于 bb 的那个
if (t == bb.get()) { body_entry = t; exit_bb = f; }
else if (f == bb.get()) { body_entry = f; exit_bb = t; }
else continue;
}
// 找 exit
if (!exit_bb) {
if (body_entry == t) exit_bb = f;
else exit_bb = t;
}
// 不能自指
if (!body_entry || !exit_bb || body_entry == target || exit_bb == target) continue;
BasicBlock *body = nullptr, *exit_bb = nullptr;
if (t == bb.get()) { body = t; exit_bb = f; }
else if (f == bb.get()) { body = f; exit_bb = t; }
if (!body || !exit_bb || body == target || exit_bb == target) continue;
int tc = 0;
auto* phi = DetectSimpleCountdown(target, body_entry, exit_bb, tc);
auto* phi = DetectSimpleCountdown(target, body, exit_bb, tc);
if (!phi) continue;
if (UnrollSimple(func.get(), target, body_entry, exit_bb, phi, tc,
if (UnrollSimple(func.get(), target, body, exit_bb, phi, tc,
module.GetContext())) {
++unrolled; changed = true; goto next_func;
}

Loading…
Cancel
Save