#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"

#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"

#include <cstddef>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Loop fission: an innermost, canonically shaped counted loop whose body can
// be cut into two groups with no scalar or memory dependence across the cut is
// split into two consecutive loops over the same iteration space. The first
// loop keeps the instructions before the cut; a newly built second loop runs
// clones of the instructions after it.
//
// NOTE(review): in this copy of the file the template argument lists (the
// `<...>` after `dyncast`, `Append`, `static_cast`, `std::vector`,
// `std::unordered_map`, `std::unordered_set`) appear to have been stripped by
// an extraction step, so it does not compile as-is. They are kept verbatim
// below instead of being guessed; restore them from version control.

namespace ir {
namespace {

// Everything MatchFissionLoop extracts about a candidate loop.
struct FissionLoopInfo {
  Loop* loop = nullptr;
  BasicBlock* preheader = nullptr;
  BasicBlock* header = nullptr;
  BasicBlock* body = nullptr;
  BasicBlock* exit = nullptr;
  // Conditional branch terminating the header (then-block == body).
  CondBrInst* branch = nullptr;
  // The branch condition; compares the induction variable with `bound`.
  BinaryInst* compare = nullptr;
  // Compare opcode normalised so the condition reads `iv <op> bound`.
  Opcode compare_opcode = Opcode::ICmpLT;
  // Loop-invariant value the induction variable is compared against.
  Value* bound = nullptr;
  loopmem::SimpleInductionVar induction_var;
  PhiInst* iv = nullptr;
  // The induction variable's latch value (its update inside the body).
  BinaryInst* step_inst = nullptr;
};

// Returns true if `name` contains one of the synthetic tags ("unroll." or
// "fission.") used for blocks created by loop transforms; this pass names its
// own blocks "fission.header" / "fission.body".
bool HasSyntheticLoopTag(const std::string& name) {
  return name.find("unroll.") != std::string::npos ||
         name.find("fission.") != std::string::npos;
}

// Returns true (i.e. "skip this loop") when preheader, header or body is
// missing, or when any of them carries a synthetic tag, so loops produced by
// earlier transforms are left alone.
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
  if (!loop.preheader || !loop.header || !body) {
    return true;
  }
  return HasSyntheticLoopTag(loop.preheader->GetName()) ||
         HasSyntheticLoopTag(loop.header->GetName()) ||
         HasSyntheticLoopTag(body->GetName());
}

// Mirrors a comparison so that `a <op> b` is equivalent to
// `b <SwapCompareOpcode(op)> a`. Only LT/LE/GT/GE are rewritten; any other
// opcode is returned unchanged.
// NOTE(review): returning the opcode unchanged is right for symmetric
// comparisons (eq/ne) but would be wrong for any other asymmetric comparison
// the IR may have (e.g. unsigned orderings) — confirm against the Opcode enum.
Opcode SwapCompareOpcode(Opcode opcode) {
  switch (opcode) {
    case Opcode::ICmpLT:
      return Opcode::ICmpGT;
    case Opcode::ICmpLE:
      return Opcode::ICmpGE;
    case Opcode::ICmpGT:
      return Opcode::ICmpLT;
    case Opcode::ICmpGE:
      return Opcode::ICmpLE;
    default:
      return opcode;
  }
}

// Recognises the loop shape this pass can split:
//  - innermost loop with a preheader and header, whose header/body/exit blocks
//    are accepted by loopmem::GetCanonicalLoopBlocks and are not synthetic;
//  - exactly one header phi, which is a simple induction variable;
//  - a header containing nothing but that phi, the exit compare and the
//    terminator;
//  - a conditional branch whose then-block is the body and whose condition
//    compares the IV against a loop-invariant bound (either operand order);
//  - the IV's latch value is an instruction in the body;
//  - every other body instruction is cloneable and not one of three excluded
//    instruction kinds (their template arguments are stripped in this copy).
// Fills `info` and returns true only on a full match; `info` is untouched
// otherwise.
bool MatchFissionLoop(Loop& loop, FissionLoopInfo& info) {
  if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
    return false;
  }

  BasicBlock* body = nullptr;
  BasicBlock* exit = nullptr;
  if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
    return false;
  }
  if (IsAlreadyTransformedLoop(loop, body)) {
    return false;
  }

  // Phis are expected at the top of the header; stop at the first non-phi.
  std::vector phis;
  loopmem::SimpleInductionVar induction_var;
  bool found_iv = false;
  for (const auto& inst_ptr : loop.header->GetInstructions()) {
    auto* phi = dyncast(inst_ptr.get());
    if (!phi) {
      break;
    }
    phis.push_back(phi);
    if (!found_iv && loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
      found_iv = true;
    }
  }
  // The second loop only re-creates the induction variable, so any extra
  // loop-carried phi would be lost.
  if (!found_iv || phis.size() != 1) {
    return false;
  }

  auto* branch = dyncast(looputils::GetTerminator(loop.header));
  auto* compare = branch ? dyncast(branch->GetCondition()) : nullptr;
  if (!branch || branch->GetThenBlock() != body || !compare) {
    return false;
  }

  // Normalise to `iv <op> bound`.
  Opcode compare_opcode = compare->GetOpcode();
  Value* bound = nullptr;
  if (compare->GetLhs() == induction_var.phi &&
      looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
    bound = compare->GetRhs();
  } else if (compare->GetRhs() == induction_var.phi &&
             looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
    bound = compare->GetLhs();
    compare_opcode = SwapCompareOpcode(compare_opcode);
  } else {
    return false;
  }

  // The second loop's header re-creates only the IV phi and the compare. Any
  // other header instruction used by cloned body code would be read after the
  // first loop has finished, i.e. with its last-iteration value instead of the
  // per-iteration one, so conservatively reject such headers.
  for (const auto& inst_ptr : loop.header->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst != induction_var.phi && inst != compare && !inst->IsTerminator()) {
      return false;
    }
  }

  auto* step_inst = dyncast(induction_var.latch_value);
  if (!step_inst || step_inst->GetParent() != body) {
    return false;
  }

  for (const auto& inst_ptr : body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator() || inst == step_inst) {
      continue;
    }
    if (!looputils::IsCloneableInstruction(inst) || dyncast(inst) || dyncast(inst) || dyncast(inst)) {
      return false;
    }
  }

  info.loop = &loop;
  info.preheader = loop.preheader;
  info.header = loop.header;
  info.body = body;
  info.exit = exit;
  info.branch = branch;
  info.compare = compare;
  info.compare_opcode = compare_opcode;
  info.bound = bound;
  info.induction_var = induction_var;
  info.iv = induction_var.phi;
  info.step_inst = step_inst;
  return true;
}

// Returns true if `group` contains at least one instruction of the two kinds
// tested below (template arguments stripped in this copy; presumably memory
// loads/stores given the original `has_memory` naming — confirm). A cut is only
// worth making when both sides contain such an instruction.
bool ContainsInterestingPayload(const std::vector& group) {
  for (auto* inst : group) {
    if (dyncast(inst) || dyncast(inst)) {
      return true;
    }
  }
  return false;
}

// Maps the original induction variable to its counterpart in the second loop;
// every other value is returned unchanged.
Value* RemapExitValue(Value* value, PhiInst* old_iv, PhiInst* new_iv) {
  if (value == old_iv) {
    return new_iv;
  }
  return value;
}

// Builds the second loop on the original loop's exit edge:
//
//   header --(exit edge)--> fission.header --> fission.body --> fission.header
//                                  \--> exit
//
// The new header has a fresh IV that starts from the original initial value,
// and the same compare opcode against the same bound. The new body holds clones
// of `second_group` plus the IV step, in original program order. Exit-block
// phis that received a value from the original header are rewired to receive it
// from the new header, with the original IV replaced by the new one.
//
// The original body is NOT modified here; the caller erases the moved
// instructions after this returns true.
// NOTE(review): a failure after the blocks are created (missing cloned step or
// failed edge redirection) returns false but leaves the partially built blocks
// in the function.
bool BuildSecondLoop(Function& function, const FissionLoopInfo& info, const std::vector& second_group) {
  // Validate before creating any blocks, so this failure leaves the function
  // untouched.
  const int preheader_index = looputils::GetPhiIncomingIndex(info.iv, info.preheader);
  if (preheader_index < 0) {
    return false;
  }

  auto* second_header = function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.header"));
  auto* second_body = function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.body"));

  // New header: iv phi, `iv <op> bound`, and a branch to the new body or exit.
  auto* second_iv = second_header->Append(
      info.iv->GetType(), nullptr, looputils::NextSyntheticName(function, "fission.iv."));
  second_iv->AddIncoming(info.iv->GetIncomingValue(preheader_index), info.header);
  auto* second_cmp = second_header->Append(
      info.compare_opcode, Type::GetBoolType(), second_iv, info.bound, nullptr,
      looputils::NextSyntheticName(function, "fission.cmp."));
  second_header->Append(second_cmp, second_body, info.exit, nullptr);
  second_header->AddPredecessor(info.header);
  second_header->AddSuccessor(second_body);
  second_header->AddSuccessor(info.exit);

  // Header values visible to the body must refer to the second loop's
  // per-iteration versions. The header holds only the IV and the compare
  // (enforced by MatchFissionLoop).
  std::unordered_map remap;
  remap[info.iv] = second_iv;
  remap[info.compare] = second_cmp;

  std::unordered_set selected(second_group.begin(), second_group.end());
  selected.insert(info.step_inst);
  for (const auto& inst_ptr : info.body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator() || selected.find(inst) == selected.end()) {
      continue;
    }
    looputils::CloneInstruction(function, inst, second_body, remap, "fission.");
  }

  auto* cloned_step_value = looputils::RemapValue(remap, info.step_inst);
  if (!cloned_step_value) {
    return false;
  }
  second_iv->AddIncoming(cloned_step_value, second_body);

  // New body: back edge to the new header.
  second_body->Append(second_header, nullptr);
  second_body->AddPredecessor(second_header);
  second_body->AddSuccessor(second_header);
  second_header->AddPredecessor(second_body);

  // Route the original loop's exit edge through the new loop.
  if (!looputils::RedirectSuccessorEdge(info.header, info.exit, second_header)) {
    return false;
  }
  info.exit->RemovePredecessor(info.header);
  info.exit->AddPredecessor(second_header);

  // Phi operands are indexed as (value, block) pairs at 2*i / 2*i+1.
  for (const auto& inst_ptr : info.exit->GetInstructions()) {
    auto* phi = dyncast(inst_ptr.get());
    if (!phi) {
      break;
    }
    const int incoming = looputils::GetPhiIncomingIndex(phi, info.header);
    if (incoming < 0) {
      continue;
    }
    phi->SetOperand(static_cast(2 * incoming),
                    RemapExitValue(phi->GetIncomingValue(incoming), info.iv, second_iv));
    phi->SetOperand(static_cast(2 * incoming + 1), second_header);
  }
  return true;
}

// Repeatedly splits loops in `function`. Each round recomputes the dominator
// tree and loop info, walks loops in post-order, and splits the first one that
// matches and has a legal cut. The cut is the earliest split point in body
// order where both halves have interesting payload and there is no scalar or
// memory dependence across it. Stops when a round changes nothing. Returns
// true if any loop was split.
bool RunLoopFissionOnFunction(Function& function) {
  if (function.IsExternal() || !function.GetEntryBlock()) {
    return false;
  }

  bool changed = false;
  while (true) {
    DominatorTree dom_tree(function);
    LoopInfo loop_info(function, dom_tree);
    bool local_changed = false;

    for (auto* loop : loop_info.GetLoopsInPostOrder()) {
      FissionLoopInfo info;
      if (!MatchFissionLoop(*loop, info)) {
        continue;
      }

      const auto accesses = loopmem::CollectMemoryAccesses(*loop, info.iv);

      // Payload = body instructions other than the terminator and the IV step.
      std::vector payload;
      for (const auto& inst_ptr : info.body->GetInstructions()) {
        auto* inst = inst_ptr.get();
        if (inst->IsTerminator() || inst == info.step_inst) {
          continue;
        }
        payload.push_back(inst);
      }
      if (payload.size() < 2) {
        continue;
      }

      // Pick the first legal cut: payload[0, cut) stays, payload[cut, end) moves.
      bool found_cut = false;
      std::vector first_group;
      std::vector second_group;
      for (std::size_t cut = 1; cut < payload.size(); ++cut) {
        std::vector first(payload.begin(), payload.begin() + static_cast(cut));
        std::vector second(payload.begin() + static_cast(cut), payload.end());
        if (!ContainsInterestingPayload(first) || !ContainsInterestingPayload(second)) {
          continue;
        }
        std::unordered_set first_set(first.begin(), first.end());
        std::unordered_set second_set(second.begin(), second.end());
        if (loopmem::HasScalarDependenceAcrossCut(first, second_set) ||
            loopmem::HasMemoryDependenceAcrossCut(accesses, first_set, second_set,
                                                  info.induction_var.stride)) {
          continue;
        }
        found_cut = true;
        first_group = std::move(first);
        second_group = std::move(second);
        break;
      }
      if (!found_cut) {
        continue;
      }

      // Everything in the body except the first group, the IV step and the
      // terminator moves to the second loop.
      std::unordered_set keep(first_group.begin(), first_group.end());
      keep.insert(info.step_inst);
      std::vector to_remove;
      for (const auto& inst_ptr : info.body->GetInstructions()) {
        auto* inst = inst_ptr.get();
        if (inst->IsTerminator() || keep.find(inst) != keep.end()) {
          continue;
        }
        to_remove.push_back(inst);
      }

      if (!BuildSecondLoop(function, info, second_group)) {
        continue;
      }

      // Erase in reverse program order, so in-body users (which follow their
      // operands in the block) are removed before the values they use. This
      // avoids erasing a value that still has a live user.
      for (auto it = to_remove.rbegin(); it != to_remove.rend(); ++it) {
        info.body->EraseInstruction(*it);
      }
      changed = true;
      local_changed = true;
      // The CFG changed; rebuild analyses before looking at further loops.
      break;
    }

    if (!local_changed) {
      break;
    }
  }
  return changed;
}

}  // namespace

// Runs loop fission on every non-null function in `module`.
// Returns true if any function was changed.
bool RunLoopFission(Module& module) {
  bool changed = false;
  for (const auto& function : module.GetFunctions()) {
    if (function) {
      changed |= RunLoopFissionOnFunction(*function);
    }
  }
  return changed;
}

}  // namespace ir