forked from plf6vcqwa/test
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
327 lines
9.7 KiB
327 lines
9.7 KiB
#include "ir/PassManager.h"
|
|
|
|
#include "ir/Analysis.h"
|
|
#include "ir/IR.h"
|
|
#include "LoopMemoryUtils.h"
|
|
#include "LoopPassUtils.h"
|
|
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
#include <vector>
|
|
|
|
namespace ir {
|
|
namespace {
|
|
|
|
struct FissionLoopInfo {
|
|
Loop* loop = nullptr;
|
|
BasicBlock* preheader = nullptr;
|
|
BasicBlock* header = nullptr;
|
|
BasicBlock* body = nullptr;
|
|
BasicBlock* exit = nullptr;
|
|
CondBrInst* branch = nullptr;
|
|
BinaryInst* compare = nullptr;
|
|
Opcode compare_opcode = Opcode::ICmpLT;
|
|
Value* bound = nullptr;
|
|
loopmem::SimpleInductionVar induction_var;
|
|
PhiInst* iv = nullptr;
|
|
BinaryInst* step_inst = nullptr;
|
|
};
|
|
|
|
/// Tells whether a block name carries one of the marker substrings that
/// identify compiler-generated loop blocks.
///
/// @param name  block name to inspect.
/// @return true when `name` contains "unroll." or "fission." at any position.
bool HasSyntheticLoopTag(const std::string& name) {
  static const char* const kSyntheticTags[] = {"unroll.", "fission."};
  for (const char* tag : kSyntheticTags) {
    if (name.find(tag) != std::string::npos) {
      return true;
    }
  }
  return false;
}
|
|
|
|
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
|
|
if (!loop.preheader || !loop.header || !body) {
|
|
return true;
|
|
}
|
|
return HasSyntheticLoopTag(loop.preheader->GetName()) ||
|
|
HasSyntheticLoopTag(loop.header->GetName()) ||
|
|
HasSyntheticLoopTag(body->GetName());
|
|
}
|
|
|
|
/// Returns the comparison opcode to use after the two operands of a compare
/// are exchanged, i.e. the opcode `op'` such that `a op b` equals `b op' a`.
///
/// LT and GT map to each other, as do LE and GE. Any other opcode is returned
/// unchanged.
Opcode SwapCompareOpcode(Opcode opcode) {
  if (opcode == Opcode::ICmpLT) {
    return Opcode::ICmpGT;
  }
  if (opcode == Opcode::ICmpGT) {
    return Opcode::ICmpLT;
  }
  if (opcode == Opcode::ICmpLE) {
    return Opcode::ICmpGE;
  }
  if (opcode == Opcode::ICmpGE) {
    return Opcode::ICmpLE;
  }
  return opcode;
}
|
|
|
|
/// Checks whether `loop` has the narrow shape this pass knows how to split
/// and, if it does, records the pieces the rewrite needs.
///
/// A loop is accepted only when all of the following hold:
///  - it has a preheader and a header and is innermost;
///  - loopmem::GetCanonicalLoopBlocks yields a body and an exit block;
///  - none of its blocks is tagged as synthetic (IsAlreadyTransformedLoop);
///  - the header starts with exactly one phi, and that phi is a simple
///    induction variable according to loopmem::MatchSimpleInductionVariable;
///  - the header ends in a conditional branch whose "then" target is the body
///    and whose condition is a BinaryInst comparing the IV against a
///    loop-invariant value (either operand order);
///  - the header contains nothing besides that phi, the compare and the
///    branch;
///  - the IV's latch value is a BinaryInst located in the body;
///  - every other non-terminator body instruction is cloneable and is not a
///    call, memset or alloca.
///
/// @param loop  candidate loop.
/// @param info  output; written only when the function returns true.
/// @return true when the loop matched and `info` was filled in.
bool MatchFissionLoop(Loop& loop, FissionLoopInfo& info) {
  if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
    return false;
  }

  BasicBlock* body = nullptr;
  BasicBlock* exit = nullptr;
  if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
    return false;
  }
  if (IsAlreadyTransformedLoop(loop, body)) {
    return false;
  }

  // Walk the run of phis at the top of the header. Exactly one is allowed and
  // it has to be the induction variable: any additional phi would be a second
  // loop-carried value that the rewrite does not know how to split.
  std::vector<PhiInst*> phis;
  loopmem::SimpleInductionVar induction_var;
  bool found_iv = false;
  for (const auto& inst_ptr : loop.header->GetInstructions()) {
    auto* phi = dyncast<PhiInst>(inst_ptr.get());
    if (!phi) {
      break;
    }
    phis.push_back(phi);
    if (!found_iv &&
        loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
      found_iv = true;
    }
  }
  if (!found_iv || phis.size() != 1) {
    return false;
  }

  auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
  auto* compare = branch ? dyncast<BinaryInst>(branch->GetCondition()) : nullptr;
  if (!branch || branch->GetThenBlock() != body || !compare) {
    return false;
  }

  // The header may hold only the IV phi, the exit compare and the terminator.
  // BuildSecondLoop clones body instructions only and remaps nothing but the
  // IV, so a value computed in the header and consumed by the moved half
  // would be read by the new loop as whatever the first loop's final header
  // execution left behind.
  // NOTE(review): GetCanonicalLoopBlocks may already guarantee this; the check
  // is kept local because that helper is not visible from this file.
  for (const auto& inst_ptr : loop.header->GetInstructions()) {
    auto* inst = inst_ptr.get();
    const bool is_expected = inst->IsTerminator() || inst == compare ||
                             dyncast<PhiInst>(inst) == phis.front();
    if (!is_expected) {
      return false;
    }
  }

  // Normalise the exit test to "iv <op> bound"; when the IV is on the right
  // hand side the opcode has to be mirrored.
  Opcode compare_opcode = compare->GetOpcode();
  Value* bound = nullptr;
  if (compare->GetLhs() == induction_var.phi &&
      looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
    bound = compare->GetRhs();
  } else if (compare->GetRhs() == induction_var.phi &&
             looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
    bound = compare->GetLhs();
    compare_opcode = SwapCompareOpcode(compare_opcode);
  } else {
    return false;
  }

  // The increment must be a plain binary instruction sitting in the body so
  // that it can be cloned together with the moved instructions.
  auto* step_inst = dyncast<BinaryInst>(induction_var.latch_value);
  if (!step_inst || step_inst->GetParent() != body) {
    return false;
  }

  // Every payload instruction must be something the rewrite can duplicate.
  for (const auto& inst_ptr : body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator() || inst == step_inst) {
      continue;
    }
    if (!looputils::IsCloneableInstruction(inst) || dyncast<CallInst>(inst) ||
        dyncast<MemsetInst>(inst) || dyncast<AllocaInst>(inst)) {
      return false;
    }
  }

  info.loop = &loop;
  info.preheader = loop.preheader;
  info.header = loop.header;
  info.body = body;
  info.exit = exit;
  info.branch = branch;
  info.compare = compare;
  info.compare_opcode = compare_opcode;
  info.bound = bound;
  info.induction_var = induction_var;
  info.iv = induction_var.phi;
  info.step_inst = step_inst;
  return true;
}
|
|
|
|
/// Reports whether a group of instructions touches memory.
///
/// @param group  instructions forming one side of a candidate cut.
/// @return true when at least one element is a LoadInst or a StoreInst;
///         false for an empty group.
bool ContainsInterestingPayload(const std::vector<Instruction*>& group) {
  for (Instruction* candidate : group) {
    const bool touches_memory =
        dyncast<LoadInst>(candidate) != nullptr || dyncast<StoreInst>(candidate) != nullptr;
    if (touches_memory) {
      return true;
    }
  }
  return false;
}
|
|
|
|
/// Substitutes the new induction variable for the old one in a single value.
///
/// @param value   value currently flowing into an exit phi.
/// @param old_iv  induction phi of the original loop.
/// @param new_iv  induction phi of the newly built loop.
/// @return `new_iv` when `value` is `old_iv`; otherwise `value` unchanged.
Value* RemapExitValue(Value* value, PhiInst* old_iv, PhiInst* new_iv) {
  if (value != old_iv) {
    return value;
  }
  return new_iv;
}
|
|
|
|
/// Builds a second loop that runs right after the matched loop and executes
/// clones of `second_group`, then splices it between the original header and
/// the original exit block.
///
/// The new loop consists of a "fission.header" block (IV phi, compare against
/// the same bound, conditional branch) and a "fission.body" block (clones of
/// the selected instructions plus a clone of the IV step, then a jump back to
/// the new header). The original header's exit edge is redirected to the new
/// header, and phis in the exit block that had an incoming edge from the
/// original header are rewritten to come from the new header instead.
///
/// The original body is not modified here; removing the moved instructions
/// from it is left to the caller.
///
/// @param function      function that owns the loop; receives the new blocks.
/// @param info          description of the matched loop (see MatchFissionLoop).
/// @param second_group  body instructions that should run in the new loop.
/// @return true on success. A false return from the first check leaves the
///         function untouched; the two later failure paths are reached only
///         after blocks have been created and partially wired.
bool BuildSecondLoop(Function& function, const FissionLoopInfo& info,
                     const std::vector<Instruction*>& second_group) {
  // Validate the only precondition that can be checked up front before
  // anything is created, so that this failure cannot leave two orphaned
  // blocks behind in the function.
  const int preheader_index = looputils::GetPhiIncomingIndex(info.iv, info.preheader);
  if (preheader_index < 0) {
    return false;
  }

  auto* second_header =
      function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.header"));
  auto* second_body =
      function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.body"));

  // The new IV starts from the same initial value as the original one, but
  // its entry edge comes from the original header (the new loop's
  // predecessor) rather than from the preheader.
  auto* second_iv = second_header->Append<PhiInst>(
      info.iv->GetType(), nullptr, looputils::NextSyntheticName(function, "fission.iv."));
  second_iv->AddIncoming(info.iv->GetIncomingValue(preheader_index), info.header);

  auto* second_cmp = second_header->Append<BinaryInst>(
      info.compare_opcode, Type::GetBoolType(), second_iv, info.bound, nullptr,
      looputils::NextSyntheticName(function, "fission.cmp."));
  second_header->Append<CondBrInst>(second_cmp, second_body, info.exit, nullptr);
  second_header->AddPredecessor(info.header);
  second_header->AddSuccessor(second_body);
  second_header->AddSuccessor(info.exit);

  // Clone the selected instructions, together with the IV step, in their
  // original body order. Seeding the map with the IV makes the clones refer
  // to the new loop's induction variable.
  std::unordered_map<Value*, Value*> remap;
  remap[info.iv] = second_iv;
  std::unordered_set<Instruction*> selected(second_group.begin(), second_group.end());
  selected.insert(info.step_inst);
  for (const auto& inst_ptr : info.body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator() || selected.find(inst) == selected.end()) {
      continue;
    }
    looputils::CloneInstruction(function, inst, second_body, remap, "fission.");
  }

  // The back-edge value of the new IV must be the cloned step. A null result,
  // or the original step coming back unchanged, means the step was never
  // cloned; using the original here would reference a value that lives in the
  // first loop's body.
  // NOTE(review): assumes CloneInstruction records original -> clone in
  // `remap`; confirm against looputils.
  auto* cloned_step_value = looputils::RemapValue(remap, info.step_inst);
  if (!cloned_step_value || cloned_step_value == info.step_inst) {
    return false;
  }
  second_iv->AddIncoming(cloned_step_value, second_body);
  second_body->Append<UncondBrInst>(second_header, nullptr);
  second_body->AddPredecessor(second_header);
  second_body->AddSuccessor(second_header);
  second_header->AddPredecessor(second_body);

  // Splice: original header now exits into the new header instead of `exit`.
  if (!looputils::RedirectSuccessorEdge(info.header, info.exit, second_header)) {
    return false;
  }
  info.exit->RemovePredecessor(info.header);
  info.exit->AddPredecessor(second_header);

  // Patch the leading phis of the exit block. Phi operands are addressed as
  // (value, block) pairs: slot 2*i is the value and slot 2*i+1 the block of
  // incoming edge i.
  for (const auto& inst_ptr : info.exit->GetInstructions()) {
    auto* phi = dyncast<PhiInst>(inst_ptr.get());
    if (!phi) {
      break;
    }
    const int incoming = looputils::GetPhiIncomingIndex(phi, info.header);
    if (incoming < 0) {
      continue;
    }
    phi->SetOperand(static_cast<std::size_t>(2 * incoming),
                    RemapExitValue(phi->GetIncomingValue(incoming), info.iv, second_iv));
    phi->SetOperand(static_cast<std::size_t>(2 * incoming + 1), second_header);
  }

  return true;
}
|
|
|
|
/// Repeatedly splits eligible loops of `function` in two until no further
/// split is found.
///
/// Each round rebuilds the dominator tree and loop information, visits the
/// loops in post order and performs at most one split; the analyses are then
/// recomputed because the CFG changed. For a matched loop the body payload
/// (all instructions except the terminator and the IV step) is cut at the
/// first position where
///  - both sides contain at least one load or store, and
///  - loopmem reports neither a scalar nor a memory dependence across the cut.
/// The instructions after the cut are rebuilt in a new loop by
/// BuildSecondLoop and then erased from the original body.
///
/// @param function  function to transform; external functions and functions
///                  without an entry block are skipped.
/// @return true when at least one loop was split.
bool RunLoopFissionOnFunction(Function& function) {
  if (function.IsExternal() || !function.GetEntryBlock()) {
    return false;
  }

  using InstList = std::vector<Instruction*>;
  using InstSet = std::unordered_set<Instruction*>;

  bool changed = false;
  bool progressed = true;
  while (progressed) {
    progressed = false;
    DominatorTree dom_tree(function);
    LoopInfo loop_info(function, dom_tree);

    for (auto* loop : loop_info.GetLoopsInPostOrder()) {
      FissionLoopInfo info;
      if (!MatchFissionLoop(*loop, info)) {
        continue;
      }

      // Payload = everything in the body that is neither the terminator nor
      // the IV step (the step is needed by both halves).
      InstList payload;
      for (const auto& inst_ptr : info.body->GetInstructions()) {
        auto* inst = inst_ptr.get();
        if (inst->IsTerminator() || inst == info.step_inst) {
          continue;
        }
        payload.push_back(inst);
      }
      if (payload.size() < 2) {
        continue;
      }

      // Collected only once the cheap size check has passed.
      const auto accesses = loopmem::CollectMemoryAccesses(*loop, info.iv);

      // Take the first legal cut position, scanning from the front.
      bool found_cut = false;
      InstList first_group;
      InstList second_group;
      for (std::size_t cut = 1; cut < payload.size(); ++cut) {
        const auto split = payload.begin() + static_cast<InstList::difference_type>(cut);
        InstList first(payload.begin(), split);
        InstList second(split, payload.end());
        if (!ContainsInterestingPayload(first) || !ContainsInterestingPayload(second)) {
          continue;
        }

        const InstSet first_set(first.begin(), first.end());
        const InstSet second_set(second.begin(), second.end());
        if (loopmem::HasScalarDependenceAcrossCut(first, second_set) ||
            loopmem::HasMemoryDependenceAcrossCut(accesses, first_set, second_set,
                                                  info.induction_var.stride)) {
          continue;
        }

        first_group = std::move(first);
        second_group = std::move(second);
        found_cut = true;
        break;
      }
      if (!found_cut) {
        continue;
      }

      // NOTE(review): BuildSecondLoop can fail after it has already created
      // blocks; in that case this loop keeps going on analyses that predate
      // the change and `changed` is not set. Behaviour kept as it was.
      if (!BuildSecondLoop(function, info, second_group)) {
        continue;
      }

      // The instructions to drop from the original body are exactly
      // `second_group`: the payload is the body minus terminator and step,
      // and the first group plus the step is what stays. Erase back to front
      // so that an instruction is removed before the definitions it uses;
      // `second_group` is in body order, so definitions precede their users.
      for (auto it = second_group.rbegin(); it != second_group.rend(); ++it) {
        info.body->EraseInstruction(*it);
      }

      changed = true;
      progressed = true;
      break;  // Analyses are stale now; recompute before looking further.
    }
  }
  return changed;
}
|
|
|
|
} // namespace
|
|
|
|
/// Runs loop fission over every function of `module`.
///
/// Null entries in the module's function list are skipped.
///
/// @param module  module to transform.
/// @return true when at least one function was changed.
bool RunLoopFission(Module& module) {
  bool any_function_changed = false;
  for (const auto& fn : module.GetFunctions()) {
    if (!fn) {
      continue;
    }
    if (RunLoopFissionOnFunction(*fn)) {
      any_function_changed = true;
    }
  }
  return any_function_changed;
}
|
|
|
|
} // namespace ir
|