You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

327 lines
9.7 KiB

#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// Description of a loop that MatchFissionLoop accepted as a fission candidate.
// Every field is populated together by MatchFissionLoop on success; a
// default-constructed instance holds only null pointers.
struct FissionLoopInfo {
  // The matched loop.
  Loop* loop{nullptr};
  // The loop's preheader; supplies the induction variable's initial value.
  BasicBlock* preheader{nullptr};
  // The loop header, which ends in the conditional branch stored in `branch`.
  BasicBlock* header{nullptr};
  // Body block reported by loopmem::GetCanonicalLoopBlocks (the branch's
  // "then" target).
  BasicBlock* body{nullptr};
  // Exit block reported by loopmem::GetCanonicalLoopBlocks.
  BasicBlock* exit{nullptr};
  // Terminator of `header`.
  CondBrInst* branch{nullptr};
  // Condition of `branch`; compares the induction variable with `bound`.
  BinaryInst* compare{nullptr};
  // Opcode of `compare`, normalized so that the induction variable is the
  // left-hand operand and `bound` the right-hand one.
  Opcode compare_opcode{Opcode::ICmpLT};
  // Loop-invariant operand of `compare`.
  Value* bound{nullptr};
  // Result of loopmem::MatchSimpleInductionVariable for the header phi.
  loopmem::SimpleInductionVar induction_var;
  // Shortcut for induction_var.phi.
  PhiInst* iv{nullptr};
  // induction_var.latch_value as a BinaryInst; it lives in `body`.
  BinaryInst* step_inst{nullptr};
};
// Reports whether `name` contains either of the markers "unroll." or
// "fission." anywhere in the string.
bool HasSyntheticLoopTag(const std::string& name) {
  static const char* const kSyntheticTags[] = {"unroll.", "fission."};
  for (const char* tag : kSyntheticTags) {
    if (name.find(tag) != std::string::npos) {
      return true;
    }
  }
  return false;
}
// Returns true when the loop should be skipped: either one of the preheader,
// header or body blocks is missing, or one of their names carries a synthetic
// tag (see HasSyntheticLoopTag).
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
  const bool has_all_blocks = loop.preheader && loop.header && body;
  if (!has_all_blocks) {
    // A loop without the full set of blocks is reported as "transformed" so
    // that callers leave it alone.
    return true;
  }
  if (HasSyntheticLoopTag(loop.preheader->GetName())) {
    return true;
  }
  if (HasSyntheticLoopTag(loop.header->GetName())) {
    return true;
  }
  return HasSyntheticLoopTag(body->GetName());
}
// Returns the comparison opcode that gives the same result once the two
// operands are exchanged (LT <-> GT, LE <-> GE). Any other opcode is returned
// unchanged.
Opcode SwapCompareOpcode(Opcode opcode) {
  if (opcode == Opcode::ICmpLT) {
    return Opcode::ICmpGT;
  }
  if (opcode == Opcode::ICmpLE) {
    return Opcode::ICmpGE;
  }
  if (opcode == Opcode::ICmpGT) {
    return Opcode::ICmpLT;
  }
  if (opcode == Opcode::ICmpGE) {
    return Opcode::ICmpLE;
  }
  return opcode;
}
// Decides whether `loop` has the shape this pass can split and, if so, fills
// `info` with the pieces the transform needs.
//
// Requirements checked here:
//   * innermost loop with a preheader and a header;
//   * loopmem::GetCanonicalLoopBlocks yields a body and an exit block;
//   * none of the blocks carries a synthetic "unroll."/"fission." tag;
//   * the header starts with exactly one phi, and that phi is a simple
//     induction variable;
//   * the header ends in a conditional branch whose "then" target is the body
//     and whose condition compares the induction variable against a
//     loop-invariant value (on either side);
//   * the induction step is a BinaryInst located in the body;
//   * every other non-terminator body instruction is cloneable and is not a
//     call, memset or alloca.
//
// Returns true on a match. `info` is written only on success.
bool MatchFissionLoop(Loop& loop, FissionLoopInfo& info) {
  if (!(loop.preheader && loop.header && loop.IsInnermost())) {
    return false;
  }

  BasicBlock* loop_body = nullptr;
  BasicBlock* loop_exit = nullptr;
  if (!loopmem::GetCanonicalLoopBlocks(loop, loop_body, loop_exit) ||
      IsAlreadyTransformedLoop(loop, loop_body)) {
    return false;
  }

  // Walk the leading phis of the header: count them and try to match the
  // induction variable until the first match succeeds.
  loopmem::SimpleInductionVar candidate_iv;
  bool iv_matched = false;
  std::size_t header_phi_count = 0;
  for (const auto& owned : loop.header->GetInstructions()) {
    auto* header_phi = dyncast<PhiInst>(owned.get());
    if (!header_phi) {
      break;
    }
    ++header_phi_count;
    if (iv_matched) {
      continue;
    }
    iv_matched = loopmem::MatchSimpleInductionVariable(loop, loop.preheader, header_phi,
                                                       candidate_iv);
  }
  // Exactly one phi is allowed, and it has to be the induction variable.
  if (header_phi_count != 1 || !iv_matched) {
    return false;
  }

  auto* header_branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
  if (!header_branch || header_branch->GetThenBlock() != loop_body) {
    return false;
  }
  auto* condition = dyncast<BinaryInst>(header_branch->GetCondition());
  if (!condition) {
    return false;
  }

  // Normalize the comparison to the form "iv <op> bound".
  Value* lhs = condition->GetLhs();
  Value* rhs = condition->GetRhs();
  Opcode normalized_opcode = condition->GetOpcode();
  Value* loop_bound = nullptr;
  if (lhs == candidate_iv.phi && looputils::IsLoopInvariantValue(loop, rhs)) {
    loop_bound = rhs;
  } else if (rhs == candidate_iv.phi && looputils::IsLoopInvariantValue(loop, lhs)) {
    loop_bound = lhs;
    normalized_opcode = SwapCompareOpcode(normalized_opcode);
  } else {
    return false;
  }

  auto* increment = dyncast<BinaryInst>(candidate_iv.latch_value);
  if (!increment || increment->GetParent() != loop_body) {
    return false;
  }

  // Everything in the body except the terminator and the induction step must
  // be safe to clone into a second loop.
  for (const auto& owned : loop_body->GetInstructions()) {
    auto* current = owned.get();
    if (current->IsTerminator() || current == increment) {
      continue;
    }
    const bool allowed = looputils::IsCloneableInstruction(current) &&
                         !dyncast<CallInst>(current) && !dyncast<MemsetInst>(current) &&
                         !dyncast<AllocaInst>(current);
    if (!allowed) {
      return false;
    }
  }

  info.loop = &loop;
  info.preheader = loop.preheader;
  info.header = loop.header;
  info.body = loop_body;
  info.exit = loop_exit;
  info.branch = header_branch;
  info.compare = condition;
  info.compare_opcode = normalized_opcode;
  info.bound = loop_bound;
  info.induction_var = candidate_iv;
  info.iv = candidate_iv.phi;
  info.step_inst = increment;
  return true;
}
// A group is worth splitting off only if it touches memory: returns true when
// `group` holds at least one load or store instruction.
bool ContainsInterestingPayload(const std::vector<Instruction*>& group) {
  for (auto* candidate : group) {
    if (dyncast<LoadInst>(candidate)) {
      return true;
    }
    if (dyncast<StoreInst>(candidate)) {
      return true;
    }
  }
  return false;
}
// Maps a value flowing into an exit phi: the original induction phi is
// replaced by the second loop's induction phi, anything else passes through.
Value* RemapExitValue(Value* value, PhiInst* old_iv, PhiInst* new_iv) {
  return value == old_iv ? static_cast<Value*>(new_iv) : value;
}
// Builds the second loop produced by fission and splices it between the
// original loop's header and its exit block.
//
// The new loop consists of a header ("fission.header") holding a fresh
// induction phi, a compare against info.bound and a conditional branch, plus
// a body ("fission.body") that receives clones of `second_group` and of the
// induction step, in their original body order. The original header's exit
// edge is redirected to the new header, and phis in the exit block that had an
// incoming entry from the original header are rewritten to come from the new
// header (with the old induction phi replaced by the new one).
//
// This function does not remove anything from the original body; the caller
// does that after a successful return.
//
// Returns false if the induction phi has no incoming entry for the preheader
// (checked before anything is created), if the cloned step value cannot be
// found, or if the exit edge cannot be redirected.
// NOTE(review): the latter two failures happen after blocks and instructions
// have already been added to `function`; nothing here rolls that back.
bool BuildSecondLoop(Function& function, const FissionLoopInfo& info,
                     const std::vector<Instruction*>& second_group) {
  // Validate before mutating the function so that this failure path does not
  // leave orphan, empty blocks behind.
  const int preheader_index = looputils::GetPhiIncomingIndex(info.iv, info.preheader);
  if (preheader_index < 0) {
    return false;
  }

  auto* second_header =
      function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.header"));
  auto* second_body =
      function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.body"));

  // The second loop restarts from the same initial value as the first one; the
  // entry edge into the new header comes from the original header.
  auto* second_iv = second_header->Append<PhiInst>(
      info.iv->GetType(), nullptr,
      looputils::NextSyntheticName(function, "fission.iv."));
  second_iv->AddIncoming(info.iv->GetIncomingValue(preheader_index), info.header);

  // compare_opcode is already normalized to "iv <op> bound".
  auto* second_cmp = second_header->Append<BinaryInst>(
      info.compare_opcode, Type::GetBoolType(), second_iv, info.bound, nullptr,
      looputils::NextSyntheticName(function, "fission.cmp."));
  second_header->Append<CondBrInst>(second_cmp, second_body, info.exit, nullptr);
  second_header->AddPredecessor(info.header);
  second_header->AddSuccessor(second_body);
  second_header->AddSuccessor(info.exit);

  // Clone the selected instructions (and the induction step) in body order,
  // rewriting uses of the old induction phi to the new one.
  std::unordered_map<Value*, Value*> remap;
  remap[info.iv] = second_iv;
  std::unordered_set<Instruction*> selected(second_group.begin(), second_group.end());
  selected.insert(info.step_inst);
  for (const auto& inst_ptr : info.body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator() || selected.find(inst) == selected.end()) {
      continue;
    }
    looputils::CloneInstruction(function, inst, second_body, remap, "fission.");
  }

  auto* cloned_step_value = looputils::RemapValue(remap, info.step_inst);
  if (!cloned_step_value) {
    return false;
  }

  // Close the back edge of the new loop.
  second_iv->AddIncoming(cloned_step_value, second_body);
  second_body->Append<UncondBrInst>(second_header, nullptr);
  second_body->AddPredecessor(second_header);
  second_body->AddSuccessor(second_header);
  second_header->AddPredecessor(second_body);

  // The original loop now falls through into the second loop instead of the
  // exit block.
  if (!looputils::RedirectSuccessorEdge(info.header, info.exit, second_header)) {
    return false;
  }
  info.exit->RemovePredecessor(info.header);
  info.exit->AddPredecessor(second_header);

  // Patch the leading phis of the exit block: entries that used to come from
  // the original header now come from the new header.
  for (const auto& inst_ptr : info.exit->GetInstructions()) {
    auto* phi = dyncast<PhiInst>(inst_ptr.get());
    if (!phi) {
      break;
    }
    const int incoming = looputils::GetPhiIncomingIndex(phi, info.header);
    if (incoming < 0) {
      continue;
    }
    // Operands are addressed as (value, block) pairs: 2*i and 2*i + 1.
    phi->SetOperand(static_cast<std::size_t>(2 * incoming),
                    RemapExitValue(phi->GetIncomingValue(incoming), info.iv, second_iv));
    phi->SetOperand(static_cast<std::size_t>(2 * incoming + 1), second_header);
  }
  return true;
}
// Applies loop fission to `function` until no further loop can be split.
//
// For each loop accepted by MatchFissionLoop, the body's payload (every
// instruction except the terminator and the induction step) is cut at the
// first position where both halves contain a load or store and neither a
// scalar nor a memory dependence crosses the cut. The second half is rebuilt
// as a separate loop by BuildSecondLoop and then removed from the original
// body. After each successful split the dominator tree and loop info are
// recomputed and the search starts over.
//
// Returns true if at least one loop was split.
bool RunLoopFissionOnFunction(Function& function) {
  if (function.IsExternal() || !function.GetEntryBlock()) {
    return false;
  }
  using Offset = std::vector<Instruction*>::difference_type;

  bool changed = false;
  while (true) {
    DominatorTree dom_tree(function);
    LoopInfo loop_info(function, dom_tree);
    bool local_changed = false;
    for (auto* loop : loop_info.GetLoopsInPostOrder()) {
      FissionLoopInfo info;
      if (!MatchFissionLoop(*loop, info)) {
        continue;
      }
      const auto accesses = loopmem::CollectMemoryAccesses(*loop, info.iv);

      // Payload: everything in the body that may be assigned to either half.
      std::vector<Instruction*> payload;
      for (const auto& inst_ptr : info.body->GetInstructions()) {
        auto* inst = inst_ptr.get();
        if (inst->IsTerminator() || inst == info.step_inst) {
          continue;
        }
        payload.push_back(inst);
      }
      if (payload.size() < 2) {
        continue;
      }

      // Take the earliest legal cut; `second_group` becomes the new loop.
      bool found_cut = false;
      std::vector<Instruction*> second_group;
      for (std::size_t cut = 1; cut < payload.size(); ++cut) {
        const auto split = payload.begin() + static_cast<Offset>(cut);
        std::vector<Instruction*> first(payload.begin(), split);
        std::vector<Instruction*> second(split, payload.end());
        if (!ContainsInterestingPayload(first) || !ContainsInterestingPayload(second)) {
          continue;
        }
        std::unordered_set<Instruction*> first_set(first.begin(), first.end());
        std::unordered_set<Instruction*> second_set(second.begin(), second.end());
        if (loopmem::HasScalarDependenceAcrossCut(first, second_set) ||
            loopmem::HasMemoryDependenceAcrossCut(accesses, first_set, second_set,
                                                  info.induction_var.stride)) {
          continue;
        }
        second_group = std::move(second);
        found_cut = true;
        break;
      }
      if (!found_cut) {
        continue;
      }

      // NOTE(review): BuildSecondLoop can fail after it has already modified
      // the function; in that case the remaining loops are still visited with
      // the analyses computed above.
      if (!BuildSecondLoop(function, info, second_group)) {
        continue;
      }

      // The second group now lives in the new loop, so drop the originals.
      // They are erased last-to-first so that an instruction is never removed
      // while a later instruction of the same group still uses its result.
      for (auto it = second_group.rbegin(); it != second_group.rend(); ++it) {
        info.body->EraseInstruction(*it);
      }

      changed = true;
      local_changed = true;
      // The CFG changed; restart with fresh analyses.
      break;
    }
    if (!local_changed) {
      break;
    }
  }
  return changed;
}
} // namespace
// Module-level entry point: runs loop fission on every non-null function in
// `module` and reports whether any of them was modified.
bool RunLoopFission(Module& module) {
  bool any_changed = false;
  for (const auto& fn : module.GetFunctions()) {
    if (!fn) {
      continue;
    }
    if (RunLoopFissionOnFunction(*fn)) {
      any_changed = true;
    }
  }
  return any_changed;
}
} // namespace ir