You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nudt-compiler-cpp/src/ir/passes/LoopIdiom.cpp

466 lines
16 KiB

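// Loop idiom recognition:
// - Replaces canonical loops that fill contiguous memory with a constant by a
//   single call to a runtime bulk-fill routine (__fill_i32).
// - The plain fill pattern is currently only matched for innermost loops with
//   step=1, init=0 and a single store.
// - Loops that call __fill_i32 on a guarded prefix or suffix of evenly spaced
//   rows are collapsed into a single __fill_rows_i32 call.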
#include "ir/IR.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "LoopPassUtils.h"
namespace ir {
namespace passes {
namespace {
// A plain constant-fill loop recognized by BuildFillLoopCandidate. The loop
// stores `fill_value` to base_ptr[i] (or base_ptr[offset + i] when `offset`
// is set) while the induction variable `i` is below `bound`.
struct FillLoopCandidate {
  analysis::Loop* loop{nullptr};
  BasicBlock* preheader{nullptr};  // block that receives the __fill_i32 call
  BasicBlock* header{nullptr};
  BasicBlock* exit{nullptr};       // block the preheader jumps to afterwards
  PhiInst* induction{nullptr};     // loop counter `i`
  Value* bound{nullptr};           // exclusive upper limit compared against `i`
  Value* base_ptr{nullptr};        // i32 pointer that is filled
  Value* offset{nullptr};          // extra start offset, or nullptr for none
  int fill_value{0};               // constant written to every element
};
// A guarded row-fill loop recognized by BuildGuardedRowFillCandidate. While
// `i` < `bound`, an iteration may call __fill_i32(base_ptr + linear, bound,
// fill_value), where `linear` is a second header phi starting at
// `linear_init` and advancing by `linear_step` per iteration. The call only
// happens on the prefix (i < threshold, `prefix == true`) or on the suffix
// (i >= threshold, `prefix == false`) of the iterations.
struct GuardedRowFillCandidate {
  analysis::Loop* loop{nullptr};
  BasicBlock* preheader{nullptr};  // block that receives __fill_rows_i32
  BasicBlock* header{nullptr};     // holds both phis and the `i < bound` test
  BasicBlock* body{nullptr};       // holds the threshold guard
  BasicBlock* action{nullptr};     // holds the guarded __fill_i32 call
  BasicBlock* latch{nullptr};
  BasicBlock* exit{nullptr};
  PhiInst* induction{nullptr};     // loop counter `i` (step 1)
  Value* bound{nullptr};           // loop limit, also the per-row fill length
  PhiInst* linear{nullptr};        // row start index phi
  Value* linear_init{nullptr};     // value of `linear` on entry from preheader
  int linear_step{0};              // positive per-iteration increment of `linear`
  Value* base_ptr{nullptr};        // i32 pointer rows are indexed from
  Value* threshold{nullptr};       // guard value compared against `i`
  bool prefix{false};              // true: rows below threshold; false: at/above
  int fill_value{0};               // constant written to every element
};
// Returns true when `needle` is `value` itself or is reachable from `value`
// through instruction operands. Non-instruction values are leaves. `visiting`
// records instructions already explored so that cycles (through phis) and
// shared sub-expressions are walked only once.
bool ExprDependsOn(Value* value, Value* needle,
                   std::unordered_set<Value*>& visiting) {
  if (value == needle) return true;
  auto* node = dynamic_cast<Instruction*>(value);
  if (node == nullptr || !visiting.insert(value).second) {
    return false;
  }
  for (size_t idx = 0; idx < node->GetNumOperands(); ++idx) {
    if (ExprDependsOn(node->GetOperand(idx), needle, visiting)) {
      return true;
    }
  }
  return false;
}
// Convenience overload: operand-graph dependency test with a fresh visited set.
bool ExprDependsOn(Value* value, Value* needle) {
  std::unordered_set<Value*> seen{};
  return ExprDependsOn(value, needle, seen);
}
// Returns the value `phi` receives along the edge from `block`, i.e. the
// incoming value paired with the first matching incoming block. Returns
// nullptr if either argument is null or `block` is not an incoming block.
Value* GetIncomingForBlock(PhiInst* phi, BasicBlock* block) {
  if (phi == nullptr || block == nullptr) {
    return nullptr;
  }
  size_t slot = 0;
  while (slot < phi->GetNumIncoming() && phi->GetIncomingBlock(slot) != block) {
    ++slot;
  }
  if (slot == phi->GetNumIncoming()) {
    return nullptr;
  }
  return phi->GetIncomingValue(slot);
}
// Returns a value equivalent to `value` that can be used in the builder's
// current insertion block (the loop preheader in this pass).
//
// - Constants, arguments, globals, functions and instructions defined outside
//   `loop` are returned unchanged.
// - Instructions inside `loop` are cloned (suffix ".idiom") together with their
//   in-loop operands, and each clone is inserted into builder.GetInsertBlock().
// - `remap` caches original -> materialized values, so shared sub-expressions
//   are cloned only once.
//
// Returns nullptr if the expression cannot be copied out of the loop. This
// happens when it depends on a phi (whose value is tied to the loop's control
// flow), on a call or store (side effects), or when CloneInstruction fails.
// Clones created before a failure are left in place unused.
Value* MaterializeInvariantExpr(Value* value, analysis::Loop* loop, IRBuilder& builder,
                                ValueMap& remap) {
  if (value == nullptr) return nullptr;
  auto it = remap.find(value);
  if (it != remap.end()) return it->second;
  if (dynamic_cast<ConstantValue*>(value) || dynamic_cast<Argument*>(value) ||
      dynamic_cast<GlobalVariable*>(value) || dynamic_cast<Function*>(value)) {
    return value;
  }
  auto* inst = dynamic_cast<Instruction*>(value);
  if (!inst || !loop->Contains(inst->GetParent())) return value;
  if (dynamic_cast<PhiInst*>(inst) || dynamic_cast<CallInst*>(inst) ||
      dynamic_cast<StoreInst*>(inst)) {
    return nullptr;
  }
  for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
    auto* operand = inst->GetOperand(i);
    auto* materialized = MaterializeInvariantExpr(operand, loop, builder, remap);
    // Never hand a null operand to CloneInstruction.
    if (materialized == nullptr) return nullptr;
    remap[operand] = materialized;
  }
  auto cloned = CloneInstruction(inst, remap, ".idiom");
  if (!cloned) return nullptr;
  auto* raw = cloned.get();
  InsertInstruction(builder.GetInsertBlock(), std::move(cloned));
  remap[inst] = raw;
  return raw;
}
// Returns true if any use of `inst` lies outside `loop`. Conservatively,
// users that are not instructions, and instructions not attached to a block,
// also count as outside uses.
bool HasOutsideUse(Instruction* inst, analysis::Loop* loop) {
  for (const auto& use : inst->GetUses()) {
    auto* user_inst = dynamic_cast<Instruction*>(use.GetUser());
    auto* user_block = user_inst ? user_inst->GetParent() : nullptr;
    if (user_block == nullptr || !loop->Contains(user_block)) {
      return true;
    }
  }
  return false;
}
// If `index` is an i32 `iv + off` or `off + iv` with `off` loop-invariant
// (per IsLoopInvariantValue), returns `off`. Otherwise returns nullptr. A bare
// `iv` index also yields nullptr; callers test for that case separately.
Value* MatchContiguousOffset(Value* index, PhiInst* iv, analysis::Loop* loop) {
  if (index == iv) return nullptr;
  auto* add = dynamic_cast<BinaryInst*>(index);
  const bool is_i32_add = add != nullptr && add->GetOpcode() == Opcode::Add &&
                          add->GetType() && add->GetType()->IsInt32();
  if (!is_i32_add) return nullptr;
  auto* lhs = add->GetLhs();
  auto* rhs = add->GetRhs();
  if (lhs == iv && IsLoopInvariantValue(rhs, loop)) return rhs;
  if (rhs == iv && IsLoopInvariantValue(lhs, loop)) return lhs;
  return nullptr;
}
// Matches the plain constant-fill idiom rewritten by RunFillLoop:
//
//   header:      i = phi [0, preheader], [i + 1, latch]
//                if (i < bound) goto body else goto exit
//   body/latch:  p = gep base, i        (or gep base, i + off / off + i)
//                store C, p             (C: int constant, base: i32*)
//
// Requirements:
// - The loop is innermost and has exactly two blocks (header plus a combined
//   body/latch).
// - The body holds only add/gep/store instructions with exactly one store.
// - The exit starts with no phi.
// - No value computed in the loop is used after it, because the rewrite makes
//   the loop unreachable.
// - `base` and `off` are defined outside the loop, because RunFillLoop uses
//   them directly in the preheader.
//
// On success fills `*out` and returns true. `*out` is only written on success.
bool BuildFillLoopCandidate(Function& func, analysis::Loop* loop,
                            FillLoopCandidate* out) {
  (void)func;
  auto defined_outside_loop = [loop](Value* value) {
    auto* inst = dynamic_cast<Instruction*>(value);
    if (inst == nullptr) return true;  // constants, arguments, globals
    return inst->GetParent() != nullptr && !loop->Contains(inst->GetParent());
  };
  auto match = MatchCanonicalLoop(loop);
  if (!match.has_value()) return false;
  if (!match->loop->GetChildren().empty()) return false;
  if (match->body != match->latch || loop->GetBlocks().size() != 2) return false;
  if (match->header_phis.size() != 1 ||
      match->header_phis.front() != match->induction.phi) {
    return false;
  }
  if (match->induction.step != 1) return false;
  if (match->header_cmp->GetCmpOp() != CmpOp::Lt) return false;
  auto* init_ci = dynamic_cast<ConstantInt*>(match->induction.init);
  if (!init_ci || init_ci->GetValue() != 0) return false;
  // Exit phis would need an incoming value for the new preheader -> exit edge.
  if (!match->exit->GetInstructions().empty() &&
      dynamic_cast<PhiInst*>(match->exit->GetInstructions().front().get()) != nullptr) {
    return false;
  }
  // The header becomes unreachable after the rewrite, so none of its values
  // (notably the induction variable) may be used after the loop.
  for (const auto& inst_ptr : match->header->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator()) continue;
    if (HasOutsideUse(inst, loop)) return false;
  }
  StoreInst* store = nullptr;
  for (const auto& inst_ptr : match->body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (dynamic_cast<PhiInst*>(inst) != nullptr || inst->IsTerminator()) continue;
    switch (inst->GetOpcode()) {
      case Opcode::Add:
      case Opcode::Gep:
      case Opcode::Store:
        break;
      default:
        return false;
    }
    if (inst != match->induction.next && HasOutsideUse(inst, loop)) {
      return false;
    }
    if (auto* maybe_store = dynamic_cast<StoreInst*>(inst)) {
      if (store) return false;
      store = maybe_store;
    }
  }
  if (!store) return false;
  auto* fill_ci = dynamic_cast<ConstantInt*>(store->GetValue());
  if (!fill_ci) return false;
  auto* gep = dynamic_cast<GepInst*>(store->GetPtr());
  if (!gep || !gep->GetBase() || !gep->GetBase()->GetType() ||
      !gep->GetBase()->GetType()->IsPtrInt32()) {
    return false;
  }
  // A base computed in the loop (possibly from `i`) is neither available in
  // the preheader nor guaranteed to describe one contiguous range.
  if (!defined_outside_loop(gep->GetBase())) return false;
  Value* offset = MatchContiguousOffset(gep->GetIndex(), match->induction.phi, loop);
  if (gep->GetIndex() != match->induction.phi && offset == nullptr) {
    return false;
  }
  if (offset != nullptr && !defined_outside_loop(offset)) return false;
  out->loop = loop;
  out->preheader = match->preheader;
  out->header = match->header;
  out->exit = match->exit;
  out->induction = match->induction.phi;
  out->bound = match->bound;
  out->base_ptr = gep->GetBase();
  out->offset = offset;
  out->fill_value = fill_ci->GetValue();
  return true;
}
// Returns the module's `__fill_i32` function. If none exists, creates an
// external declaration void __fill_i32(i32*, i32, i32).
Function* GetOrCreateFillI32(Module& module) {
  auto* existing = module.FindFunction("__fill_i32");
  if (existing != nullptr) {
    return existing;
  }
  auto* decl = module.CreateFunction(
      "__fill_i32", Type::GetVoidType(),
      {Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()});
  decl->SetExternal(true);
  return decl;
}
// Returns the module's `__fill_rows_i32` function. If none exists, creates an
// external declaration void __fill_rows_i32(i32*, i32, i32, i32, i32, i32).
Function* GetOrCreateFillRowsI32(Module& module) {
  auto* existing = module.FindFunction("__fill_rows_i32");
  if (existing != nullptr) {
    return existing;
  }
  auto* decl = module.CreateFunction(
      "__fill_rows_i32", Type::GetVoidType(),
      {Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type(),
       Type::GetInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()});
  decl->SetExternal(true);
  return decl;
}
// Matches a 4-block innermost loop that calls __fill_i32 on a prefix or suffix
// of evenly spaced rows. RunGuardedRowFillLoop collapses it into one
// __fill_rows_i32 call:
//
//   header:  i   = phi [0, preheader], [i + 1, latch]
//            lin = phi [lin0, preheader], [lin + S, latch]     (S: const > 0)
//            if (i < bound) goto body else goto exit
//   body:    if (i < threshold) / (i >= threshold) -> action | latch
//   action:  call __fill_i32(gep base, lin, bound, C); br latch
//   latch:   br header
//
// `prefix` is set when the call runs for i < threshold. It is cleared when the
// call runs for i >= threshold.
//
// Because the rewrite leaves the loop unreachable, the matcher also requires:
// - i starts at 0, since the rewrite computes row offsets as lin0 + i * S;
// - the fill call is the loop's only instruction with side effects;
// - no loop value is used after the loop, and the exit has no phis;
// - `base` and `bound` are defined outside the loop;
// - `threshold` can be recomputed in the preheader from pure, non-trapping
//   integer/address arithmetic.
//
// On success fills `*out` and returns true. `*out` is only written on success.
bool BuildGuardedRowFillCandidate(Function& func, analysis::Loop* loop,
                                  GuardedRowFillCandidate* out) {
  (void)func;
  if (!loop) return false;
  if (!loop->GetChildren().empty()) return false;
  if (loop->GetBlocks().size() != 4) return false;

  // Values used directly in the preheader by the rewrite must exist there.
  auto defined_outside_loop = [loop](Value* value) {
    auto* inst = dynamic_cast<Instruction*>(value);
    if (inst == nullptr) return true;  // constants, arguments, globals
    return inst->GetParent() != nullptr && !loop->Contains(inst->GetParent());
  };

  // True if `value` can be cloned into the preheader by
  // MaterializeInvariantExpr without changing behavior. Every in-loop
  // instruction it depends on must be a gep or a binary op that cannot trap.
  // Phis, loads, calls and so on are rejected. A load in particular may observe
  // memory the fills themselves overwrite.
  std::unordered_set<Value*> hoist_visited;
  auto hoistable = [&](auto& self, Value* value) -> bool {
    auto* inst = dynamic_cast<Instruction*>(value);
    if (inst == nullptr) return true;
    if (inst->GetParent() == nullptr) return false;
    if (!loop->Contains(inst->GetParent())) return true;
    if (!hoist_visited.insert(value).second) return true;
    bool safe = dynamic_cast<GepInst*>(inst) != nullptr;
    if (auto* bin = dynamic_cast<BinaryInst*>(inst)) {
      switch (bin->GetOpcode()) {
        case Opcode::Add:
        case Opcode::Sub:
        case Opcode::Mul:
          safe = true;
          break;
        default: {
          // Other binary ops (e.g. division or remainder) are only speculated
          // with a constant right operand that cannot trap.
          auto* rhs_ci = dynamic_cast<ConstantInt*>(bin->GetRhs());
          safe = rhs_ci != nullptr && rhs_ci->GetValue() != 0 &&
                 rhs_ci->GetValue() != -1;
          break;
        }
      }
    }
    if (!safe) return false;
    for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
      if (!self(self, inst->GetOperand(i))) return false;
    }
    return true;
  };

  auto* header = loop->GetHeader();
  auto* preheader = loop->GetPreheader();
  if (!header || !preheader) return false;
  if (loop->GetLatches().size() != 1) return false;
  auto* latch = loop->GetLatches().front();
  if (!latch) return false;
  auto* header_term = header->HasTerminator()
                          ? dynamic_cast<CondBranchInst*>(
                                header->MutableInstructions().back().get())
                          : nullptr;
  if (!header_term) return false;
  auto* header_cmp = dynamic_cast<CmpInst*>(header_term->GetCond());
  if (!header_cmp || header_cmp->GetCmpOp() != CmpOp::Lt) return false;
  auto induction = MatchCanonicalInduction(header, preheader, latch);
  if (!induction.has_value() || induction->step != 1) return false;
  auto* init_ci = dynamic_cast<ConstantInt*>(induction->init);
  if (!init_ci || init_ci->GetValue() != 0) return false;

  BasicBlock* body = nullptr;
  BasicBlock* exit = nullptr;
  if (loop->Contains(header_term->GetTrueBlock()) &&
      !loop->Contains(header_term->GetFalseBlock())) {
    body = header_term->GetTrueBlock();
    exit = header_term->GetFalseBlock();
  } else if (loop->Contains(header_term->GetFalseBlock()) &&
             !loop->Contains(header_term->GetTrueBlock())) {
    body = header_term->GetFalseBlock();
    exit = header_term->GetTrueBlock();
  } else {
    return false;
  }
  if (body == header || body == latch) return false;
  // Exit phis would need an incoming value for the new preheader -> exit edge.
  if (!exit->GetInstructions().empty() &&
      dynamic_cast<PhiInst*>(exit->GetInstructions().front().get()) != nullptr) {
    return false;
  }

  Value* bound = nullptr;
  if (header_cmp->GetLhs() == induction->phi &&
      IsLoopInvariantValue(header_cmp->GetRhs(), loop)) {
    bound = header_cmp->GetRhs();
  } else if (header_cmp->GetRhs() == induction->phi &&
             IsLoopInvariantValue(header_cmp->GetLhs(), loop)) {
    bound = header_cmp->GetLhs();
  } else {
    return false;
  }
  if (!defined_outside_loop(bound)) return false;

  auto header_phis = CollectHeaderPhis(header);
  if (header_phis.size() != 2) return false;
  PhiInst* linear_phi = nullptr;
  for (auto* phi : header_phis) {
    if (phi != induction->phi) {
      linear_phi = phi;
      break;
    }
  }
  if (!linear_phi || !linear_phi->GetType() || !linear_phi->GetType()->IsInt32()) {
    return false;
  }
  auto* linear_init = GetIncomingForBlock(linear_phi, preheader);
  auto* linear_next = GetIncomingForBlock(linear_phi, latch);
  auto* linear_next_bin = dynamic_cast<BinaryInst*>(linear_next);
  if (!linear_init || !linear_next_bin ||
      linear_next_bin->GetOpcode() != Opcode::Add ||
      linear_next_bin->GetLhs() != linear_phi) {
    return false;
  }
  auto* linear_step_ci = dynamic_cast<ConstantInt*>(linear_next_bin->GetRhs());
  if (!linear_step_ci || linear_step_ci->GetValue() <= 0) return false;

  auto* guard = body->HasTerminator()
                    ? dynamic_cast<CondBranchInst*>(body->MutableInstructions().back().get())
                    : nullptr;
  if (!guard) return false;
  BasicBlock* action = nullptr;
  if (guard->GetTrueBlock() == latch && loop->Contains(guard->GetFalseBlock())) {
    action = guard->GetFalseBlock();
  } else if (guard->GetFalseBlock() == latch &&
             loop->Contains(guard->GetTrueBlock())) {
    action = guard->GetTrueBlock();
  } else {
    return false;
  }
  if (action == header || action == body || action == latch) return false;
  auto* action_term = action->HasTerminator()
                          ? dynamic_cast<BranchInst*>(action->MutableInstructions().back().get())
                          : nullptr;
  if (!action_term || action_term->GetTarget() != latch) return false;
  // The latch must simply return to the header (no second loop exit).
  auto* latch_term = latch->HasTerminator()
                         ? dynamic_cast<BranchInst*>(latch->MutableInstructions().back().get())
                         : nullptr;
  if (!latch_term || latch_term->GetTarget() != header) return false;

  // The action block must contain exactly one call: the __fill_i32 being
  // replaced. Other instructions there are checked by the loop-wide scan below.
  CallInst* fill_call = nullptr;
  for (const auto& inst_ptr : action->GetInstructions()) {
    if (auto* call = dynamic_cast<CallInst*>(inst_ptr.get())) {
      if (fill_call) return false;
      fill_call = call;
    }
  }
  if (!fill_call || fill_call->GetNumArgs() != 3) return false;
  auto* callee = fill_call->GetCallee();
  if (!callee || callee->GetName() != "__fill_i32") return false;
  auto* fill_value = dynamic_cast<ConstantInt*>(fill_call->GetArg(2));
  if (!fill_value) return false;
  auto* fill_gep = dynamic_cast<GepInst*>(fill_call->GetArg(0));
  if (!fill_gep || fill_call->GetArg(1) != bound) return false;
  if (fill_gep->GetIndex() != linear_phi) return false;
  if (!fill_gep->GetBase() || !fill_gep->GetBase()->GetType() ||
      !fill_gep->GetBase()->GetType()->IsPtrInt32()) {
    return false;
  }
  if (!defined_outside_loop(fill_gep->GetBase())) return false;

  auto* guard_cmp = dynamic_cast<CmpInst*>(guard->GetCond());
  if (!guard_cmp || !guard_cmp->GetType() || !guard_cmp->GetType()->IsInt32()) {
    return false;
  }
  if (guard_cmp->GetLhs() != induction->phi) return false;
  Value* threshold = guard_cmp->GetRhs();
  if (ExprDependsOn(threshold, induction->phi) ||
      ExprDependsOn(threshold, linear_phi) || !hoistable(hoistable, threshold)) {
    return false;
  }
  // Lt with the call on the true edge, or Ge with it on the false edge, fills
  // rows i < threshold. The opposite pairings fill rows i >= threshold.
  const auto guard_op = guard_cmp->GetCmpOp();
  const bool action_on_true = action == guard->GetTrueBlock();
  bool prefix = false;
  if (guard_op == CmpOp::Lt) {
    prefix = action_on_true;
  } else if (guard_op == CmpOp::Ge) {
    prefix = !action_on_true;
  } else {
    return false;
  }

  // The whole loop is deleted, so the only side effect it may have is the fill
  // call, and nothing it computes may be used afterwards.
  for (auto* block : loop->GetBlocks()) {
    for (const auto& inst_ptr : block->GetInstructions()) {
      auto* inst = inst_ptr.get();
      if (HasOutsideUse(inst, loop)) return false;
      if (inst == fill_call || inst->IsTerminator()) continue;
      const bool pure = dynamic_cast<PhiInst*>(inst) != nullptr ||
                        dynamic_cast<BinaryInst*>(inst) != nullptr ||
                        dynamic_cast<CmpInst*>(inst) != nullptr ||
                        dynamic_cast<GepInst*>(inst) != nullptr;
      if (!pure) return false;
    }
  }

  out->loop = loop;
  out->preheader = preheader;
  out->header = header;
  out->body = body;
  out->action = action;
  out->latch = latch;
  out->exit = exit;
  out->induction = induction->phi;
  out->bound = bound;
  out->linear = linear_phi;
  out->linear_init = linear_init;
  out->linear_step = linear_step_ci->GetValue();
  out->base_ptr = fill_gep->GetBase();
  out->fill_value = fill_value->GetValue();
  out->threshold = threshold;
  out->prefix = prefix;
  return true;
}
// Rewrites the fill loop described by `cand`. The preheader's branch into the
// loop is replaced by
//   __fill_i32(base_ptr [+ offset], bound, fill_value); br exit
// The CFG and the induction phi are updated so the loop is no longer entered.
// Always returns true.
// NOTE(review): if `bound` can be <= 0 this relies on __fill_i32 doing
// nothing for non-positive counts — confirm against the runtime.
bool RunFillLoop(Function& func, const FillLoopCandidate& cand,
                 Module& module, Context& ctx) {
  (void)func;
  auto* fill_fn = GetOrCreateFillI32(module);
  BasicBlock* const pre = cand.preheader;
  if (pre->HasTerminator()) {
    auto* old_branch = pre->MutableInstructions().back().get();
    pre->RemoveInstruction(old_branch);
  }
  IRBuilder builder(ctx, pre);
  Value* dest = nullptr;
  if (cand.offset != nullptr) {
    dest = builder.CreateGep(cand.base_ptr, cand.offset, ctx.NextTemp());
  } else {
    dest = cand.base_ptr;
  }
  builder.CreateCall(fill_fn, {dest, cand.bound, ctx.GetConstInt(cand.fill_value)},
                     "");
  pre->Append<BranchInst>(Type::GetVoidType(), cand.exit);
  // Disconnect the loop: the header loses its preheader edge, and the exit
  // gains one.
  cand.induction->RemoveIncomingBlock(pre);
  pre->RemoveSuccessor(cand.header);
  cand.header->RemovePredecessor(pre);
  pre->AddSuccessor(cand.exit);
  cand.exit->AddPredecessor(pre);
  return true;
}
// Replaces the guarded row-fill loop described by `cand` with one call in the
// preheader:
//   __fill_rows_i32(base_ptr, linear_init + start * linear_step, rows,
//                   linear_step, bound, fill_value)
// Here (start, rows) is (0, threshold) for a prefix fill and
// (threshold, bound - threshold) for a suffix fill. The preheader then
// branches straight to the exit, and the loop is disconnected from the CFG.
//
// `threshold` is recomputed in the preheader via MaterializeInvariantExpr. If
// that fails, the preheader's branch to the header is restored and false is
// returned. In that case the loop is unchanged, but any partial clones and
// the __fill_rows_i32 declaration remain.
//
// NOTE(review): threshold is passed unclamped. If it can fall outside
// [0, bound] at run time, a prefix fill is only correct if __fill_rows_i32
// clamps `rows` itself. A negative suffix start yields a start offset before
// linear_init. Confirm against the runtime and the matcher's assumptions.
bool RunGuardedRowFillLoop(Function& func, const GuardedRowFillCandidate& cand,
                           Module& module, Context& ctx) {
  (void)func;
  auto* fill_rows_fn = GetOrCreateFillRowsI32(module);
  auto* preheader = cand.preheader;
  const bool had_terminator = preheader->HasTerminator();
  if (had_terminator) {
    preheader->RemoveInstruction(preheader->MutableInstructions().back().get());
  }
  IRBuilder builder(ctx, preheader);
  ValueMap remap;
  auto* threshold =
      MaterializeInvariantExpr(cand.threshold, cand.loop, builder, remap);
  if (!threshold) {
    // Do not leave the preheader without a terminator. A preheader's only
    // successor is the loop header, so re-create that branch.
    if (had_terminator) {
      preheader->Append<BranchInst>(Type::GetVoidType(), cand.header);
    }
    return false;
  }
  Value* start_index = cand.prefix ? ctx.GetConstInt(0) : threshold;
  Value* rows = cand.prefix ? threshold : nullptr;
  if (!cand.prefix) {
    rows = builder.CreateSub(cand.bound, start_index, ctx.NextTemp());
  }
  auto* start_offset_mul =
      builder.CreateMul(start_index, ctx.GetConstInt(cand.linear_step), ctx.NextTemp());
  auto* start_offset =
      builder.CreateAdd(cand.linear_init, start_offset_mul, ctx.NextTemp());
  builder.CreateCall(fill_rows_fn,
                     {cand.base_ptr, start_offset, rows,
                      ctx.GetConstInt(cand.linear_step), cand.bound,
                      ctx.GetConstInt(cand.fill_value)},
                     "");
  preheader->Append<BranchInst>(Type::GetVoidType(), cand.exit);
  cand.induction->RemoveIncomingBlock(preheader);
  cand.linear->RemoveIncomingBlock(preheader);
  preheader->RemoveSuccessor(cand.header);
  cand.header->RemovePredecessor(preheader);
  preheader->AddSuccessor(cand.exit);
  cand.exit->AddPredecessor(preheader);
  return true;
}
} // namespace
// Entry point of the loop-idiom pass for one function. Each loop reported by
// LoopInfo is tried against the guarded row-fill pattern first, then the plain
// fill pattern. Returns true right after the first successful rewrite; the
// dominator tree and loop info are stale at that point (presumably the pass
// driver reruns the pass — verify). Returns false for external functions or
// when nothing was rewritten.
bool RunLoopIdiom(Function& func, Module& module, Context& ctx) {
  if (func.IsExternal()) {
    return false;
  }
  analysis::DominatorTree dom_tree(func);
  analysis::LoopInfo loop_info(func, dom_tree);
  for (const auto& owned_loop : loop_info.GetLoops()) {
    auto* const current = owned_loop.get();
    {
      GuardedRowFillCandidate rows_cand;
      if (BuildGuardedRowFillCandidate(func, current, &rows_cand) &&
          RunGuardedRowFillLoop(func, rows_cand, module, ctx)) {
        return true;
      }
    }
    FillLoopCandidate fill_cand;
    if (BuildFillLoopCandidate(func, current, &fill_cand) &&
        RunFillLoop(func, fill_cand, module, ctx)) {
      return true;
    }
  }
  return false;
}
} // namespace passes
} // namespace ir