You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/ir/passes/LoopFission.cpp

172 lines
5.6 KiB

// 循环分裂:
// - 针对单块循环中两段彼此独立的 store 语句组做保守分裂
// - 仅处理单归纳变量、无其他 loop-carried phi 的情形
#include "ir/IR.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>
#include "LoopPassUtils.h"
namespace ir {
namespace passes {
namespace {
// Follows a chain of GEP instructions back to the value the chain starts from.
// Each step replaces the current value with GepInst::GetBase(); the walk stops
// at the first value that is not a GepInst (including nullptr) and returns it.
Value* StripPointerBase(Value* value) {
  Value* root = value;
  for (auto* gep = dynamic_cast<GepInst*>(root); gep != nullptr;
       gep = dynamic_cast<GepInst*>(root)) {
    root = gep->GetBase();
  }
  return root;
}
// Structural filter for loops this pass is willing to split.
// Accepts only a loop that
//   - has no child loops and consists of exactly two blocks,
//   - uses one block as both body and latch,
//   - has exactly one header phi, which is the induction phi,
//   - has a strictly positive induction step, and
//   - whose body ends in a BranchInst targeting the header.
bool IsFissionCandidate(const CanonicalLoopMatch& match) {
  // Loop shape: innermost, header + single body/latch block.
  const bool shape_ok = match.loop->GetChildren().size() == 0 &&
                        match.loop->GetBlocks().size() == 2 &&
                        match.body == match.latch;
  if (!shape_ok) {
    return false;
  }

  // The induction variable must be the only header phi.
  const bool only_induction_phi =
      match.header_phis.size() == 1 &&
      match.header_phis.front() == match.induction.phi;
  if (!only_induction_phi) {
    return false;
  }

  if (match.induction.step <= 0) {
    return false;
  }

  // The last instruction of the body must be a branch back to the header.
  auto* back_edge =
      dynamic_cast<BranchInst*>(match.body->MutableInstructions().back().get());
  if (back_edge == nullptr) {
    return false;
  }
  return back_edge->GetTarget() == match.header;
}
// Reports whether `inst` directly uses a value defined by one of `defs`.
// Only the immediate operands of `inst` are inspected; operands that are not
// Instructions are ignored.
bool DependsOnAny(Instruction* inst, const std::unordered_set<Instruction*>& defs) {
  for (size_t operand_idx = 0; operand_idx < inst->GetNumOperands(); ++operand_idx) {
    auto* producer = dynamic_cast<Instruction*>(inst->GetOperand(operand_idx));
    if (producer == nullptr) {
      continue;
    }
    if (defs.find(producer) != defs.end()) {
      return true;
    }
  }
  return false;
}
bool RunFissionOnLoop(Function& func, const CanonicalLoopMatch& match,
Context& ctx) {
if (!IsFissionCandidate(match)) return false;
std::vector<Instruction*> body_insts;
for (const auto& inst_ptr : match.body->GetInstructions()) {
if (!inst_ptr.get()->IsTerminator()) {
body_insts.push_back(inst_ptr.get());
}
}
if (body_insts.size() < 3) return false;
auto* iv_next = dynamic_cast<Instruction*>(match.induction.next);
if (!iv_next || iv_next->GetParent() != match.body) return false;
std::vector<size_t> store_positions;
for (size_t i = 0; i < body_insts.size(); ++i) {
if (dynamic_cast<StoreInst*>(body_insts[i]) != nullptr) {
store_positions.push_back(i);
}
}
if (store_positions.size() != 2) return false;
const size_t first_store_idx = store_positions[0];
const size_t second_store_idx = store_positions[1];
if (body_insts.back() != iv_next) return false;
if (second_store_idx + 1 != body_insts.size() - 1) return false;
auto* first_store = static_cast<StoreInst*>(body_insts[first_store_idx]);
auto* second_store = static_cast<StoreInst*>(body_insts[second_store_idx]);
if (StripPointerBase(first_store->GetPtr()) == StripPointerBase(second_store->GetPtr())) {
return false;
}
std::vector<Instruction*> group1(body_insts.begin(),
body_insts.begin() + first_store_idx + 1);
std::vector<Instruction*> group2(body_insts.begin() + first_store_idx + 1,
body_insts.begin() + second_store_idx + 1);
std::unordered_set<Instruction*> group1_defs(group1.begin(), group1.end());
std::unordered_set<Instruction*> group2_defs(group2.begin(), group2.end());
group1_defs.erase(iv_next);
group2_defs.erase(iv_next);
for (auto* inst : group2) {
if (DependsOnAny(inst, group1_defs)) return false;
}
for (auto* inst : group1) {
if (DependsOnAny(inst, group2_defs)) return false;
}
auto* original_exit = match.exit;
std::string block_suffix = ctx.NextTemp();
if (!block_suffix.empty() && block_suffix.front() == '%') {
block_suffix.erase(0, 1);
}
auto* preheader2 =
func.CreateBlock(match.header->GetName() + ".fission.pre." + block_suffix);
auto* header2 =
func.CreateBlock(match.header->GetName() + ".fission.hdr." + block_suffix);
auto* body2 =
func.CreateBlock(match.body->GetName() + ".fission.body." + block_suffix);
preheader2->Append<BranchInst>(Type::GetVoidType(), header2);
auto* iv2 = header2->PrependPhi(Type::GetInt32Type(), ctx.NextTemp());
iv2->AddIncoming(match.induction.init, preheader2);
auto* cmp2 = header2->Append<CmpInst>(
match.header_cmp->GetCmpOp(), Type::GetInt32Type(), iv2, match.bound,
ctx.NextTemp());
header2->Append<CondBranchInst>(Type::GetVoidType(), cmp2, body2, original_exit);
ValueMap remap;
remap.emplace(match.induction.phi, iv2);
for (auto* inst : group2) {
auto cloned = CloneInstruction(inst, remap, ".f2");
if (!cloned) return false;
auto* raw = cloned.get();
body2->MutableInstructions().push_back(std::move(cloned));
raw->SetParent(body2);
remap[inst] = raw;
}
auto next2_cloned = CloneInstruction(iv_next, remap, ".f2");
if (!next2_cloned) return false;
auto* next2 = next2_cloned.get();
body2->MutableInstructions().push_back(std::move(next2_cloned));
next2->SetParent(body2);
body2->Append<BranchInst>(Type::GetVoidType(), header2);
iv2->AddIncoming(next2, body2);
const bool exit_is_true = (match.header_branch->GetTrueBlock() == original_exit);
match.header_branch->SetOperand(exit_is_true ? 1 : 2, preheader2);
match.header->RemoveSuccessor(original_exit);
match.header->AddSuccessor(preheader2);
preheader2->AddPredecessor(match.header);
original_exit->RemovePredecessor(match.header);
for (auto* inst : group2) {
match.body->RemoveInstruction(inst);
}
return true;
}
} // namespace
// Entry point of the loop-fission pass for a single function.
// Builds dominator and loop information for `func`, then walks the loops it
// reports and attempts fission on each one that matches the canonical form.
// Returns true as soon as one loop has been split (at most one loop is
// transformed per call) and false if nothing changed or `func` is external.
bool RunLoopFission(Function& func, Context& ctx) {
  if (func.IsExternal()) {
    return false;
  }

  analysis::DominatorTree dom_tree(func);
  analysis::LoopInfo loop_info(func, dom_tree);

  for (const auto& candidate : loop_info.GetLoops()) {
    auto matched = MatchCanonicalLoop(candidate.get());
    // Stop after the first successful split; presumably the analyses above
    // no longer describe the modified CFG - TODO confirm.
    if (matched.has_value() && RunFissionOnLoop(func, *matched, ctx)) {
      return true;
    }
  }
  return false;
}
} // namespace passes
} // namespace ir