// Loop idiom optimization:
// - Replaces canonical loops that fill contiguous memory with a constant by a
//   call to a runtime bulk-fill routine.
// - Currently handles only innermost loops with step=1, init=0 and a single
//   store.
//
// NOTE(review): the text inside angle brackets appears to have been lost from
// this copy of the file (three bare `#include` lines, `dynamic_cast(...)`
// without a target type, `std::unordered_set` / `std::vector` without an
// element type). The code tokens below are left exactly as found; the missing
// header names and template arguments must be restored before this compiles.
#include "ir/IR.h"
#include
#include
#include
#include "LoopPassUtils.h"

namespace ir {
namespace passes {
namespace {

// Everything RunFillLoop needs to rewrite a simple constant-fill loop.
// Filled in by BuildFillLoopCandidate.
struct FillLoopCandidate {
  analysis::Loop* loop = nullptr;   // the matched loop
  BasicBlock* preheader = nullptr;  // block that receives the fill call
  BasicBlock* header = nullptr;     // loop header
  BasicBlock* exit = nullptr;       // block the preheader is redirected to
  PhiInst* induction = nullptr;     // induction phi of the loop
  Value* bound = nullptr;           // loop bound; passed as 2nd call argument
  Value* base_ptr = nullptr;        // base operand of the store's GEP
  Value* offset = nullptr;          // invariant addend of the index, or nullptr
                                    // when the index is the induction phi itself
  int fill_value = 0;               // constant written by the store
};

// Everything RunGuardedRowFillLoop needs to rewrite a loop whose body
// conditionally calls `__fill_i32`. Filled in by BuildGuardedRowFillCandidate.
struct GuardedRowFillCandidate {
  analysis::Loop* loop = nullptr;   // the matched loop
  BasicBlock* preheader = nullptr;  // block that receives the new call
  BasicBlock* header = nullptr;     // loop header
  BasicBlock* body = nullptr;       // in-loop successor of the header; ends in
                                    // the guard branch
  BasicBlock* action = nullptr;     // block holding the `__fill_i32` call
  BasicBlock* latch = nullptr;      // the loop's single latch
  BasicBlock* exit = nullptr;       // out-of-loop successor of the header
  PhiInst* induction = nullptr;     // induction phi (step 1)
  Value* bound = nullptr;           // loop bound; also the per-call fill count
  PhiInst* linear = nullptr;        // second header phi, used as the GEP index
  Value* linear_init = nullptr;     // value of `linear` coming from the preheader
  int linear_step = 0;              // positive constant added to `linear` per
                                    // iteration
  Value* base_ptr = nullptr;        // base operand of the GEP passed to the call
  Value* threshold = nullptr;       // right-hand side of the guard comparison
  bool prefix = false;              // true: action runs when induction < threshold
                                    // false: action runs when induction >= threshold
  int fill_value = 0;               // constant 3rd argument of the matched call
};

// Returns true when `needle` is `value` itself or is reachable from `value`
// by following operands. `visiting` records the values already expanded so
// that each one is walked at most once (this also stops cycles through phis).
bool ExprDependsOn(Value* value, Value* needle, std::unordered_set& visiting) {
  if (value == needle) return true;
  auto* inst = dynamic_cast(value);
  // Values that fail the cast have no operands to follow.
  if (!inst) return false;
  // Already expanded on this query: do not walk it again.
  if (!visiting.insert(value).second) return false;
  for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
    if (ExprDependsOn(inst->GetOperand(i), needle, visiting)) {
      return true;
    }
  }
  return false;
}

// Convenience overload that starts the search with an empty visited set.
bool ExprDependsOn(Value* value, Value* needle) {
  std::unordered_set visiting;
  return ExprDependsOn(value, needle, visiting);
}

// Returns the value `phi` receives from `block`, or nullptr when either
// argument is null or `block` is not one of the phi's incoming blocks.
Value* GetIncomingForBlock(PhiInst* phi, BasicBlock* block) {
  if (!phi || !block) return nullptr;
  for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
    if (phi->GetIncomingBlock(i) == block) {
      return phi->GetIncomingValue(i);
    }
  }
  return nullptr;
}

// Makes `value` available at the builder's insertion block.
//
// - If `value` already has an entry in `remap`, that entry is returned.
// - If `value` passes any of the four leading casts, or is not an instruction
//   located in a block of `loop`, it is returned unchanged.
// - Otherwise its operands are materialized recursively, the instruction is
//   cloned with the name suffix ".idiom", the clone is inserted into the
//   builder's insertion block and recorded in `remap`.
//
// Returns nullptr when CloneInstruction yields no instruction.
//
// NOTE(review): the target types of the four leading casts are missing in
// this copy; presumably they are constant/global/argument-like value
// classes -- confirm against the original.
// NOTE(review): a nullptr returned by the recursive call is stored into
// `remap` without a check before CloneInstruction is invoked.
Value* MaterializeInvariantExpr(Value* value, analysis::Loop* loop, IRBuilder& builder, ValueMap& remap) {
  auto it = remap.find(value);
  if (it != remap.end()) return it->second;
  if (dynamic_cast(value) || dynamic_cast(value) || dynamic_cast(value) || dynamic_cast(value)) {
    return value;
  }
  auto* inst = dynamic_cast(value);
  // Defined outside the loop (or not an instruction): usable as is.
  if (!inst || !loop->Contains(inst->GetParent())) return value;
  for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
    auto* operand = inst->GetOperand(i);
    remap[operand] = MaterializeInvariantExpr(operand, loop, builder, remap);
  }
  auto cloned = CloneInstruction(inst, remap, ".idiom");
  if (!cloned) return nullptr;
  // Keep a raw pointer: ownership moves into the block on insertion.
  auto* raw = cloned.get();
  InsertInstruction(builder.GetInsertBlock(), std::move(cloned));
  remap[inst] = raw;
  return raw;
}

// Returns true when `inst` has a use whose user fails the cast, has no parent
// block, or sits in a block that is not part of `loop`.
bool HasOutsideUse(Instruction* inst, analysis::Loop* loop) {
  for (const auto& use : inst->GetUses()) {
    auto* user = dynamic_cast(use.GetUser());
    if (!user) return true;
    if (!user->GetParent() || !loop->Contains(user->GetParent())) {
      return true;
    }
  }
  return false;
}

// Matches an index of the form `iv + inv` or `inv + iv`, where the sum is an
// int32 Add and `inv` satisfies IsLoopInvariantValue for `loop`; returns
// `inv`.
//
// Returns nullptr both when `index` is `iv` itself and when nothing matches,
// so the caller has to tell those two cases apart on its own.
Value* MatchContiguousOffset(Value* index, PhiInst* iv, analysis::Loop* loop) {
  if (index == iv) return nullptr;
  auto* bin = dynamic_cast(index);
  if (!bin || bin->GetOpcode() != Opcode::Add || !bin->GetType() || !bin->GetType()->IsInt32()) {
    return nullptr;
  }
  if (bin->GetLhs() == iv && IsLoopInvariantValue(bin->GetRhs(), loop)) {
    return bin->GetRhs();
  }
  if (bin->GetRhs() == iv && IsLoopInvariantValue(bin->GetLhs(), loop)) {
    return bin->GetLhs();
  }
  return nullptr;
}

// Checks whether `loop` is a simple constant-fill loop and, if so, describes
// it in `*out`.
//
// Requirements checked below: MatchCanonicalLoop succeeds; the loop has no
// child loops and exactly two blocks (body == latch); the induction phi is
// the only header phi, with step 1 and a constant init of 0; the header
// compare is `Lt`; the first exit instruction does not pass the cast applied
// to it; the body holds only Add/Gep/Store instructions, exactly one of them
// a store; that store writes a constant through a GEP whose base is a
// pointer-to-int32 and whose index is the induction phi or `phi + invariant`.
//
// `func` is unused. Returns true on a match; `*out` is written only then.
//
// NOTE(review): uses of the induction phi (a header instruction) and of
// `induction.next` outside the loop are not rejected here.
bool BuildFillLoopCandidate(Function& func, analysis::Loop* loop, FillLoopCandidate* out) {
  (void)func;
  auto match = MatchCanonicalLoop(loop);
  if (!match.has_value()) return false;
  if (match->loop->GetChildren().size() != 0) return false;
  if (match->body != match->latch || loop->GetBlocks().size() != 2) return false;
  if (match->header_phis.size() != 1 || match->header_phis.front() != match->induction.phi) {
    return false;
  }
  if (match->induction.step != 1) return false;
  if (match->header_cmp->GetCmpOp() != CmpOp::Lt) return false;
  auto* init_ci = dynamic_cast(match->induction.init);
  if (!init_ci || init_ci->GetValue() != 0) return false;
  // Bail out when the exit block starts with an instruction that passes this
  // cast (target type missing in this copy; presumably a phi -- confirm).
  if (!match->exit->GetInstructions().empty() && dynamic_cast(match->exit->GetInstructions().front().get()) != nullptr) {
    return false;
  }
  StoreInst* store = nullptr;
  // NOTE(review): `body_insts` is filled but never read in this function.
  std::vector body_insts;
  for (const auto& inst_ptr : match->body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    // Skip terminators and whatever passes the cast (type missing here).
    if (dynamic_cast(inst) != nullptr || inst->IsTerminator()) continue;
    body_insts.push_back(inst);
    switch (inst->GetOpcode()) {
      case Opcode::Add:
      case Opcode::Gep:
      case Opcode::Store:
        break;
      default:
        return false;
    }
    // Apart from the induction update, no body value may be used outside
    // the loop.
    if (inst != match->induction.next && HasOutsideUse(inst, loop)) {
      return false;
    }
    if (auto* maybe_store = dynamic_cast(inst)) {
      // A second store disqualifies the loop.
      if (store) return false;
      store = maybe_store;
    }
  }
  if (!store) return false;
  auto* fill_ci = dynamic_cast(store->GetValue());
  if (!fill_ci) return false;
  auto* gep = dynamic_cast(store->GetPtr());
  if (!gep || !gep->GetBase() || !gep->GetBase()->GetType() || !gep->GetBase()->GetType()->IsPtrInt32()) {
    return false;
  }
  Value* offset = MatchContiguousOffset(gep->GetIndex(), match->induction.phi, loop);
  // A null offset is acceptable only when the index is the induction phi.
  if (gep->GetIndex() != match->induction.phi && offset == nullptr) {
    return false;
  }
  out->loop = loop;
  out->preheader = match->preheader;
  out->header = match->header;
  out->exit = match->exit;
  out->induction = match->induction.phi;
  out->bound = match->bound;
  out->base_ptr = gep->GetBase();
  out->offset = offset;
  out->fill_value = fill_ci->GetValue();
  return true;
}

// Returns the module's `__fill_i32` function, creating it on first use as an
// external function returning void and taking (ptr-to-int32, int32, int32).
// RunFillLoop calls it with (start pointer, loop bound, fill constant).
Function* GetOrCreateFillI32(Module& module) {
  if (auto* fn = module.FindFunction("__fill_i32")) return fn;
  auto* fn = module.CreateFunction("__fill_i32", Type::GetVoidType(), {Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()});
  fn->SetExternal(true);
  return fn;
}

// Returns the module's `__fill_rows_i32` function, creating it on first use
// as an external function returning void and taking one ptr-to-int32 followed
// by five int32 parameters. RunGuardedRowFillLoop calls it with
// (base pointer, start offset, row count, linear step, loop bound,
// fill constant).
Function* GetOrCreateFillRowsI32(Module& module) {
  if (auto* fn = module.FindFunction("__fill_rows_i32")) return fn;
  auto* fn = module.CreateFunction(
      "__fill_rows_i32", Type::GetVoidType(),
      {Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type(), Type::GetInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()});
  fn->SetExternal(true);
  return fn;
}

// Checks whether `loop` is a four-block loop (header, body, action, latch)
// whose action block calls `__fill_i32(gep(base, linear), bound, constant)`
// under a guard comparing the induction phi with a threshold, and describes
// it in `*out`.
//
// `func` is unused. Returns true on a match.
//
// NOTE(review): the header compare must be `Lt`, but neither the operand
// order (the phi may be on either side) nor the branch polarity (the body may
// be on the true or the false edge) is distinguished afterwards.
// NOTE(review): the induction init value is not checked in this function,
// while RunGuardedRowFillLoop uses a start index of 0 for the prefix case --
// confirm MatchCanonicalInduction guarantees init == 0.
// NOTE(review): instructions in the body and latch blocks are not inspected
// here.
bool BuildGuardedRowFillCandidate(Function& func, analysis::Loop* loop, GuardedRowFillCandidate* out) {
  (void)func;
  if (!loop) return false;
  if (loop->GetChildren().size() != 0) return false;
  if (loop->GetBlocks().size() != 4) return false;
  auto* header = loop->GetHeader();
  auto* preheader = loop->GetPreheader();
  if (!header || !preheader) return false;
  if (loop->GetLatches().size() != 1) return false;
  auto* latch = loop->GetLatches().front();
  if (!latch) return false;
  // The header must end in a terminator exposing GetCond/GetTrueBlock/
  // GetFalseBlock (cast target missing in this copy).
  auto* header_term = header->HasTerminator() ? dynamic_cast( header->MutableInstructions().back().get()) : nullptr;
  if (!header_term) return false;
  auto* header_cmp = dynamic_cast(header_term->GetCond());
  if (!header_cmp || header_cmp->GetCmpOp() != CmpOp::Lt) return false;
  auto induction = MatchCanonicalInduction(header, preheader, latch);
  if (!induction.has_value() || induction->step != 1) return false;
  Value* bound = nullptr;
  BasicBlock* body = nullptr;
  BasicBlock* exit = nullptr;
  // Exactly one header successor must be inside the loop: that one is the
  // body, the other one is the exit.
  if (loop->Contains(header_term->GetTrueBlock()) && !loop->Contains(header_term->GetFalseBlock())) {
    body = header_term->GetTrueBlock();
    exit = header_term->GetFalseBlock();
  } else if (loop->Contains(header_term->GetFalseBlock()) && !loop->Contains(header_term->GetTrueBlock())) {
    body = header_term->GetFalseBlock();
    exit = header_term->GetTrueBlock();
  } else {
    return false;
  }
  // The bound is the loop-invariant operand compared against the induction
  // phi.
  if (header_cmp->GetLhs() == induction->phi && IsLoopInvariantValue(header_cmp->GetRhs(), loop)) {
    bound = header_cmp->GetRhs();
  } else if (header_cmp->GetRhs() == induction->phi && IsLoopInvariantValue(header_cmp->GetLhs(), loop)) {
    bound = header_cmp->GetLhs();
  } else {
    return false;
  }
  // Besides the induction phi there must be exactly one more header phi: the
  // linear index.
  auto header_phis = CollectHeaderPhis(header);
  if (header_phis.size() != 2) return false;
  PhiInst* linear_phi = nullptr;
  for (auto* phi : header_phis) {
    if (phi != induction->phi) {
      linear_phi = phi;
      break;
    }
  }
  if (!linear_phi || !linear_phi->GetType() || !linear_phi->GetType()->IsInt32()) {
    return false;
  }
  // The linear phi must advance as `linear + positive constant`, with the
  // phi as the left operand of the Add.
  auto* linear_init = GetIncomingForBlock(linear_phi, preheader);
  auto* linear_next = GetIncomingForBlock(linear_phi, latch);
  auto* linear_next_bin = dynamic_cast(linear_next);
  if (!linear_init || !linear_next_bin || linear_next_bin->GetOpcode() != Opcode::Add || linear_next_bin->GetLhs() != linear_phi) {
    return false;
  }
  auto* linear_step_ci = dynamic_cast(linear_next_bin->GetRhs());
  if (!linear_step_ci || linear_step_ci->GetValue() <= 0) return false;
  // The body must end in a two-way branch (the guard): one edge goes straight
  // to the latch, the other one to the in-loop action block.
  auto* guard = body->HasTerminator() ? dynamic_cast(body->MutableInstructions().back().get()) : nullptr;
  if (!guard) return false;
  BasicBlock* action = nullptr;
  if (guard->GetTrueBlock() == latch && loop->Contains(guard->GetFalseBlock())) {
    action = guard->GetFalseBlock();
  } else if (guard->GetFalseBlock() == latch && loop->Contains(guard->GetTrueBlock())) {
    action = guard->GetTrueBlock();
  } else {
    return false;
  }
  // NOTE(review): unlike the header and body above, `back()` is called here
  // without a HasTerminator() check.
  auto* action_term = dynamic_cast(action->MutableInstructions().back().get());
  if (!action_term || action_term->GetTarget() != latch) return false;
  CallInst* fill_call = nullptr;
  GepInst* fill_gep = nullptr;
  // Remember the last instruction passing the first cast as `fill_gep`; every
  // other non-terminator overwrites `fill_call` with the result of the second
  // cast.
  // NOTE(review): an instruction that precedes the call and fails the first
  // cast is overwritten by the later call and does not cause rejection.
  for (const auto& inst_ptr : action->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator()) continue;
    if (auto* gep = dynamic_cast(inst)) {
      fill_gep = gep;
      continue;
    }
    fill_call = dynamic_cast(inst);
  }
  if (!fill_call || !fill_gep || fill_call->GetNumArgs() != 3) return false;
  auto* callee = fill_call->GetCallee();
  if (!callee || callee->GetName() != "__fill_i32") return false;
  auto* fill_value = dynamic_cast(fill_call->GetArg(2));
  if (!fill_value) return false;
  // The call must fill `bound` elements starting at the matched GEP.
  if (fill_call->GetArg(0) != fill_gep || fill_call->GetArg(1) != bound) {
    return false;
  }
  if (fill_gep->GetIndex() != linear_phi) return false;
  if (!fill_gep->GetBase() || !fill_gep->GetBase()->GetType() || !fill_gep->GetBase()->GetType()->IsPtrInt32()) {
    return false;
  }
  auto* guard_cmp = dynamic_cast(guard->GetCond());
  if (!guard_cmp || !guard_cmp->GetType() || !guard_cmp->GetType()->IsInt32()) {
    return false;
  }
  Value* threshold = nullptr;
  bool prefix = false;
  bool suffix = false;
  // Only `induction <op> threshold` is accepted, with the induction phi on
  // the left and a threshold that depends on neither header phi.
  if (guard_cmp->GetLhs() == induction->phi && !ExprDependsOn(guard_cmp->GetRhs(), induction->phi) && !ExprDependsOn(guard_cmp->GetRhs(), linear_phi)) {
    threshold = guard_cmp->GetRhs();
    // prefix: the action runs while induction <  threshold.
    // suffix: the action runs while induction >= threshold.
    if (guard_cmp->GetCmpOp() == CmpOp::Lt && action == guard->GetTrueBlock()) {
      prefix = true;
    } else if (guard_cmp->GetCmpOp() == CmpOp::Ge && action == guard->GetTrueBlock()) {
      suffix = true;
    } else if (guard_cmp->GetCmpOp() == CmpOp::Lt && action == guard->GetFalseBlock()) {
      suffix = true;
    } else if (guard_cmp->GetCmpOp() == CmpOp::Ge && action == guard->GetFalseBlock()) {
      prefix = true;
    } else {
      return false;
    }
  } else {
    return false;
  }
  out->loop = loop;
  out->preheader = preheader;
  out->header = header;
  out->body = body;
  out->action = action;
  out->latch = latch;
  out->exit = exit;
  out->induction = induction->phi;
  out->bound = bound;
  out->linear = linear_phi;
  out->linear_init = linear_init;
  out->linear_step = linear_step_ci->GetValue();
  out->base_ptr = fill_gep->GetBase();
  out->fill_value = fill_value->GetValue();
  out->threshold = threshold;
  out->prefix = prefix;
  // Every path that reaches this point has set one of the two flags.
  return prefix || suffix;
}

// Rewrites the loop described by `cand`: the preheader's terminator is
// removed, a call `__fill_i32(start, bound, fill_value)` is emitted there
// (with `start = gep(base_ptr, offset)` when an offset was matched, otherwise
// `base_ptr`), and the preheader is then wired to the exit block instead of
// the header.
//
// `func` is unused. Always returns true.
bool RunFillLoop(Function& func, const FillLoopCandidate& cand, Module& module, Context& ctx) {
  (void)func;
  auto* fill_fn = GetOrCreateFillI32(module);
  auto* preheader = cand.preheader;
  // Drop the old terminator so new instructions can be added at the end of
  // the preheader.
  if (preheader->HasTerminator()) {
    preheader->RemoveInstruction(preheader->MutableInstructions().back().get());
  }
  IRBuilder builder(ctx, preheader);
  Value* start_ptr = cand.base_ptr;
  if (cand.offset) {
    start_ptr = builder.CreateGep(cand.base_ptr, cand.offset, ctx.NextTemp());
  }
  builder.CreateCall(fill_fn, {start_ptr, cand.bound, ctx.GetConstInt(cand.fill_value)}, "");
  // New terminator targeting the exit block.
  // NOTE(review): Append's template argument appears to be missing in this
  // copy; presumably it names the branch instruction class -- confirm.
  preheader->Append(Type::GetVoidType(), cand.exit);
  // Detach the header from the preheader and record the new
  // preheader -> exit edge.
  cand.induction->RemoveIncomingBlock(preheader);
  preheader->RemoveSuccessor(cand.header);
  cand.header->RemovePredecessor(preheader);
  preheader->AddSuccessor(cand.exit);
  cand.exit->AddPredecessor(preheader);
  return true;
}

// Rewrites the loop described by `cand` into one call
// `__fill_rows_i32(base_ptr, start_offset, rows, linear_step, bound,
// fill_value)` emitted in the preheader, then wires the preheader to the exit
// block.
//
// Prefix case: start index 0, rows = threshold.
// Suffix case: start index = threshold, rows = bound - threshold.
// In both cases start_offset = linear_init + start_index * linear_step.
//
// `func` is unused. Returns false when the threshold cannot be materialized
// in the preheader, true otherwise.
//
// NOTE(review): the preheader terminator has already been removed when the
// `return false` path is taken, so on that path the preheader is left without
// a terminator.
bool RunGuardedRowFillLoop(Function& func, const GuardedRowFillCandidate& cand, Module& module, Context& ctx) {
  (void)func;
  auto* fill_rows_fn = GetOrCreateFillRowsI32(module);
  auto* preheader = cand.preheader;
  if (preheader->HasTerminator()) {
    preheader->RemoveInstruction(preheader->MutableInstructions().back().get());
  }
  IRBuilder builder(ctx, preheader);
  ValueMap remap;
  // The threshold may be computed inside the loop; make it available in the
  // preheader.
  auto* threshold = MaterializeInvariantExpr(cand.threshold, cand.loop, builder, remap);
  if (!threshold) return false;
  Value* start_index = cand.prefix ? ctx.GetConstInt(0) : threshold;
  Value* rows = cand.prefix ? threshold : nullptr;
  if (!cand.prefix) {
    rows = builder.CreateSub(cand.bound, start_index, ctx.NextTemp());
  }
  auto* start_offset_mul = builder.CreateMul(start_index, ctx.GetConstInt(cand.linear_step), ctx.NextTemp());
  auto* start_offset = builder.CreateAdd(cand.linear_init, start_offset_mul, ctx.NextTemp());
  builder.CreateCall(fill_rows_fn, {cand.base_ptr, start_offset, rows, ctx.GetConstInt(cand.linear_step), cand.bound, ctx.GetConstInt(cand.fill_value)}, "");
  // New terminator targeting the exit block (see the Append note in
  // RunFillLoop).
  preheader->Append(Type::GetVoidType(), cand.exit);
  // Detach the header from the preheader and record the new
  // preheader -> exit edge.
  cand.induction->RemoveIncomingBlock(preheader);
  cand.linear->RemoveIncomingBlock(preheader);
  preheader->RemoveSuccessor(cand.header);
  cand.header->RemovePredecessor(preheader);
  preheader->AddSuccessor(cand.exit);
  cand.exit->AddPredecessor(preheader);
  return true;
}

}  // namespace

// Entry point of the pass. Builds dominator and loop information for `func`
// and tries, for each loop, first the guarded row-fill rewrite and then the
// plain fill rewrite.
//
// Returns true as soon as one loop has been rewritten (at most one rewrite
// per call) and false when `func` is external or no loop matched.
bool RunLoopIdiom(Function& func, Module& module, Context& ctx) {
  if (func.IsExternal()) return false;
  analysis::DominatorTree dom_tree(func);
  analysis::LoopInfo loop_info(func, dom_tree);
  for (const auto& loop_ptr : loop_info.GetLoops()) {
    GuardedRowFillCandidate row_fill;
    if (BuildGuardedRowFillCandidate(func, loop_ptr.get(), &row_fill)) {
      if (RunGuardedRowFillLoop(func, row_fill, module, ctx)) {
        return true;
      }
    }
    FillLoopCandidate cand;
    if (!BuildFillLoopCandidate(func, loop_ptr.get(), &cand)) continue;
    if (RunFillLoop(func, cand, module, ctx)) {
      return true;
    }
  }
  return false;
}

}  // namespace passes
}  // namespace ir