forked from p4jyxwm3q/nudt-compiler-cpp
Compare commits
5 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
78168fe917 | 4 weeks ago |
|
|
e55421f447 | 1 month ago |
|
|
69892ef133 | 1 month ago |
|
|
407be0fca1 | 1 month ago |
|
|
08ce9d96ab | 1 month ago |
@ -0,0 +1,137 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// Returns true when `value` is a strictly positive power of two (1, 2, 4, ...).
bool IsPowerOfTwoPositive(int value) {
  if (value <= 0) {
    return false;
  }
  // Clearing the lowest set bit leaves zero only when exactly one bit is set.
  const int without_lowest_bit = value & (value - 1);
  return without_lowest_bit == 0;
}
|
||||
|
||||
// Returns the position of `inst` within `block`'s instruction list.
// Yields 0 when either pointer is null, and the list size when `inst` is not
// present in the block.
std::size_t FindInstructionIndex(BasicBlock* block, Instruction* inst) {
  if (block == nullptr || inst == nullptr) {
    return 0;
  }
  auto& insts = block->GetInstructions();
  std::size_t position = 0;
  while (position < insts.size() && insts[position].get() != inst) {
    ++position;
  }
  return position;
}
|
||||
|
||||
// Returns true when `value` is the integer constant 0 or the i1 constant
// false; every other value (including non-constants) yields false.
bool IsZero(Value* value) {
  auto* int_constant = dyncast<ConstantInt>(value);
  if (int_constant != nullptr) {
    return int_constant->GetValue() == 0;
  }
  auto* bool_constant = dyncast<ConstantI1>(value);
  if (bool_constant != nullptr) {
    return !bool_constant->GetValue();
  }
  return false;
}
|
||||
|
||||
// Given a two-operand instruction `cmp` and one of its operands `value`,
// returns the opposite operand. Returns nullptr when `cmp` is null, does not
// have exactly two operands, or does not use `value` at all. The left-hand
// side is checked first.
Value* OtherCompareOperand(BinaryInst* cmp, Value* value) {
  const bool is_two_operand = cmp != nullptr && cmp->GetNumOperands() == 2;
  if (!is_two_operand) {
    return nullptr;
  }
  if (cmp->GetLhs() == value) {
    return cmp->GetRhs();
  }
  if (cmp->GetRhs() == value) {
    return cmp->GetLhs();
  }
  return nullptr;
}
|
||||
|
||||
// Simplifies remainders whose divisor is a positive power-of-two constant.
//
//  * `x % 1` is replaced by the constant 0 everywhere.
//  * `x % 2^k`, when every use is an `== 0` / `!= 0` comparison, has those
//    comparisons redirected to `x & (2^k - 1)`. The low k bits are zero
//    exactly when x is a multiple of 2^k, which also holds for negative x in
//    two's complement, so the zero test is preserved.
//
// Remainders left without uses are erased at the end. Returns true when the
// function was modified.
bool SimplifyPowerOfTwoRemTests(Function& function) {
  bool changed = false;
  std::vector<Instruction*> dead_rems;

  for (const auto& block_ptr : function.GetBlocks()) {
    auto* block = block_ptr.get();
    if (!block) {
      continue;
    }

    // Snapshot the remainder instructions before rewriting anything. The
    // rewrite below inserts a new instruction into this block, and inserting
    // into the instruction list while a range-for is walking that same list
    // invalidates the loop's iterators. Remainders are only erased after all
    // blocks have been processed, so the raw pointers collected here stay
    // valid for the whole inner loop.
    std::vector<BinaryInst*> rems;
    for (const auto& inst_ptr : block->GetInstructions()) {
      auto* candidate = dyncast<BinaryInst>(inst_ptr.get());
      if (candidate && candidate->GetOpcode() == Opcode::Rem) {
        rems.push_back(candidate);
      }
    }

    for (auto* rem : rems) {
      // The divisor is re-read here (not at snapshot time) because an earlier
      // rewrite in this block may have replaced an operand.
      auto* divisor = dyncast<ConstantInt>(rem->GetRhs());
      if (!divisor || !IsPowerOfTwoPositive(divisor->GetValue())) {
        continue;
      }

      const int mask_value = divisor->GetValue() - 1;
      if (mask_value == 0) {
        // Divisor is 1: the remainder is always zero.
        rem->ReplaceAllUsesWith(looputils::ConstInt(0));
        dead_rems.push_back(rem);
        changed = true;
        continue;
      }

      // Only rewrite when every use is an equality test against zero; any
      // other use needs the real (sign-aware) remainder value.
      std::vector<BinaryInst*> compare_uses;
      bool all_uses_are_zero_tests = !rem->GetUses().empty();
      for (const auto& use : rem->GetUses()) {
        auto* cmp = dyncast<BinaryInst>(dynamic_cast<Value*>(use.GetUser()));
        if (!cmp || (cmp->GetOpcode() != Opcode::ICmpEQ &&
                     cmp->GetOpcode() != Opcode::ICmpNE) ||
            !IsZero(OtherCompareOperand(cmp, rem))) {
          all_uses_are_zero_tests = false;
          break;
        }
        compare_uses.push_back(cmp);
      }
      if (!all_uses_are_zero_tests || compare_uses.empty()) {
        continue;
      }

      // Place the mask directly after the remainder so it precedes every
      // comparison that is about to use it.
      const auto insert_index = FindInstructionIndex(block, rem) + 1;
      auto* masked = block->Insert<BinaryInst>(
          insert_index, Opcode::And, Type::GetInt32Type(), rem->GetLhs(),
          looputils::ConstInt(mask_value), nullptr,
          looputils::NextSyntheticName(function, "pow2.mask."));

      for (auto* cmp : compare_uses) {
        if (cmp->GetLhs() == rem) {
          cmp->SetOperand(0, masked);
        }
        if (cmp->GetRhs() == rem) {
          cmp->SetOperand(1, masked);
        }
      }
      dead_rems.push_back(rem);
      changed = true;
    }
  }

  // Erase the remainders whose uses were all redirected.
  for (auto* rem : dead_rems) {
    if (rem->GetUses().empty() && rem->GetParent()) {
      rem->GetParent()->EraseInstruction(rem);
    }
  }
  return changed;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: applies the power-of-two remainder simplification to
// every non-external function in `module`. Returns true when any function
// was modified.
bool RunArithmeticSimplify(Module& module) {
  bool any_change = false;
  for (const auto& function : module.GetFunctions()) {
    const bool has_body = function && !function->IsExternal();
    if (!has_body) {
      continue;
    }
    if (SimplifyPowerOfTwoRemTests(*function)) {
      any_change = true;
    }
  }
  return any_change;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,239 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// Returns the last instruction of `block` when it is a terminator; returns
// nullptr for a null or empty block, or when the last instruction is not a
// terminator.
Instruction* GetTerminator(BasicBlock* block) {
  if (block == nullptr) {
    return nullptr;
  }
  const auto& insts = block->GetInstructions();
  if (insts.empty()) {
    return nullptr;
  }
  auto* last = insts.back().get();
  if (last == nullptr || !last->IsTerminator()) {
    return nullptr;
  }
  return last;
}
|
||||
|
||||
// Returns the index of the last instruction in `block` (0 for an empty
// block). `block` must not be null; the caller is expected to have verified
// that the last instruction is the terminator.
std::size_t GetTerminatorIndex(BasicBlock* block) {
  const std::size_t count = block->GetInstructions().size();
  if (count == 0) {
    return 0;
  }
  return count - 1;
}
|
||||
|
||||
// Creates a fresh 32-bit integer constant holding `value`.
// NOTE(review): the object is heap-allocated with a raw `new` and nothing
// visible here takes ownership of it -- confirm the IR owns/frees constants,
// or consider the shared constant helper used by the other passes.
ConstantInt* ConstInt(int value) {
  auto* constant = new ConstantInt(Type::GetInt32Type(), value);
  return constant;
}
|
||||
|
||||
// Returns the phi of `block` when the block starts with exactly one phi.
// Returns nullptr when `block` is null, has no leading phi, or has more than
// one. Scanning stops at the first non-phi instruction.
PhiInst* GetSinglePhi(BasicBlock* block) {
  if (block == nullptr) {
    return nullptr;
  }

  PhiInst* first_phi = nullptr;
  std::size_t phi_count = 0;
  for (const auto& inst_ptr : block->GetInstructions()) {
    auto* as_phi = dyncast<PhiInst>(inst_ptr.get());
    if (as_phi == nullptr) {
      break;
    }
    if (++phi_count > 1) {
      return nullptr;
    }
    first_phi = as_phi;
  }
  return first_phi;
}
|
||||
|
||||
// Returns true when `block` contains exactly one non-terminator instruction.
// When `out` is non-null it receives that instruction, or nullptr when the
// block has none. `out` is left untouched when `block` is null or when a
// second non-terminator is found.
bool HasOnlyOneNonTerminator(BasicBlock* block, Instruction** out) {
  if (block == nullptr) {
    return false;
  }

  Instruction* sole = nullptr;
  for (const auto& inst_ptr : block->GetInstructions()) {
    Instruction* current = inst_ptr.get();
    const bool counts = current != nullptr && !current->IsTerminator();
    if (!counts) {
      continue;
    }
    if (sole != nullptr) {
      return false;
    }
    sole = current;
  }

  if (out != nullptr) {
    *out = sole;
  }
  return sole != nullptr;
}
|
||||
|
||||
// Returns the index of the first incoming edge of `phi` that comes from
// `block`, or -1 when there is none or either argument is null.
int IncomingIndexFor(PhiInst* phi, BasicBlock* block) {
  if (phi == nullptr || block == nullptr) {
    return -1;
  }
  int index = 0;
  while (index < phi->GetNumIncomings()) {
    if (phi->GetIncomingBlock(index) == block) {
      return index;
    }
    ++index;
  }
  return -1;
}
|
||||
|
||||
// Returns true when every use of `value` belongs to `expected_user`.
// A value with no uses also yields true. Null arguments yield false.
bool IsUsedOnlyBy(Value* value, User* expected_user) {
  if (!value || !expected_user) {
    return false;
  }
  bool exclusive = true;
  for (const auto& use : value->GetUses()) {
    if (use.GetUser() != expected_user) {
      exclusive = false;
      break;
    }
  }
  return exclusive;
}
|
||||
|
||||
struct ConditionalAccumulation {
|
||||
Value* base = nullptr;
|
||||
Value* delta = nullptr;
|
||||
Opcode opcode = Opcode::Add;
|
||||
};
|
||||
|
||||
// Checks whether `phi` merges an untouched value coming from `pred` with
// `update` coming from `update_block`, where `update` is `base + delta`,
// `delta + base` or `base - delta`, all of type i32, and `update` has no user
// other than `phi`. On success fills `*match` and returns true; on failure
// `*match` is left untouched.
bool MatchConditionalAccumulation(PhiInst* phi, BasicBlock* pred,
                                  BasicBlock* update_block,
                                  BinaryInst* update,
                                  ConditionalAccumulation* match) {
  const bool have_inputs = phi != nullptr && pred != nullptr &&
                           update_block != nullptr && update != nullptr &&
                           match != nullptr;
  if (!have_inputs) {
    return false;
  }
  if (phi->GetNumIncomings() != 2 || !phi->GetType()->IsInt32() ||
      !update->GetType()->IsInt32()) {
    return false;
  }

  const int from_pred = IncomingIndexFor(phi, pred);
  const int from_update = IncomingIndexFor(phi, update_block);
  if (from_pred < 0 || from_update < 0) {
    return false;
  }

  auto* base = phi->GetIncomingValue(from_pred);
  if (phi->GetIncomingValue(from_update) != update) {
    return false;
  }
  if (base == nullptr || !base->GetType()->IsInt32()) {
    return false;
  }
  if (!IsUsedOnlyBy(update, phi)) {
    return false;
  }

  auto* left = update->GetLhs();
  auto* right = update->GetRhs();
  const auto is_int32 = [](auto* operand) {
    return operand != nullptr && operand->GetType()->IsInt32();
  };

  const auto opcode = update->GetOpcode();
  if (opcode == Opcode::Add) {
    // Addition commutes, so the base may sit on either side.
    if (left == base && is_int32(right)) {
      *match = {base, right, Opcode::Add};
      return true;
    }
    if (right == base && is_int32(left)) {
      *match = {base, left, Opcode::Add};
      return true;
    }
    return false;
  }

  // Subtraction only matches with the base on the left-hand side.
  if (opcode == Opcode::Sub && left == base && is_int32(right)) {
    *match = {base, right, Opcode::Sub};
    return true;
  }
  return false;
}
|
||||
|
||||
// Tries to turn the triangle
//
//     pred:   br cond, update_block, join
//     update: u = base (+|-) delta ; br join
//     join:   phi [base, pred], [u, update]
//
// into straight-line code in `pred`:
//
//     acc = base (+|-) (delta & (0 - zext(cond)))
//
// Only the shape with the update on the "then" edge is handled. Returns true
// when the conversion was applied.
// NOTE(review): the inserted instructions use fixed names ("%ifconv.*"), so
// several conversions in one function reuse the same names -- confirm that
// duplicate value names are acceptable to later stages.
bool TryConvertConditionalAccumulation(Function& function, BasicBlock* pred) {
  auto* branch = dyncast<CondBrInst>(GetTerminator(pred));
  if (branch == nullptr) {
    return false;
  }
  auto* condition = branch->GetCondition();
  if (condition == nullptr || !condition->GetType()->IsInt1()) {
    return false;
  }

  auto* update_block = branch->GetThenBlock();
  auto* join = branch->GetElseBlock();
  if (update_block == nullptr || join == nullptr || update_block == join) {
    return false;
  }
  // The update block must be entered only from `pred` and leave only to `join`.
  if (update_block->GetPredecessors().size() != 1 ||
      update_block->GetPredecessors().front() != pred) {
    return false;
  }
  if (update_block->GetSuccessors().size() != 1 ||
      update_block->GetSuccessors().front() != join) {
    return false;
  }

  auto* update_term = dyncast<UncondBrInst>(GetTerminator(update_block));
  if (update_term == nullptr || update_term->GetDest() != join) {
    return false;
  }

  Instruction* body_inst = nullptr;
  if (!HasOnlyOneNonTerminator(update_block, &body_inst)) {
    return false;
  }
  auto* update = dyncast<BinaryInst>(body_inst);
  if (update == nullptr) {
    return false;
  }
  const auto update_opcode = update->GetOpcode();
  if (update_opcode != Opcode::Add && update_opcode != Opcode::Sub) {
    return false;
  }

  auto* phi = GetSinglePhi(join);
  ConditionalAccumulation accum;
  if (!MatchConditionalAccumulation(phi, pred, update_block, update, &accum)) {
    return false;
  }

  // Build the select-free replacement directly in front of pred's terminator.
  std::size_t next_slot = GetTerminatorIndex(pred);
  auto* enabled = pred->Insert<ZextInst>(next_slot, condition,
                                         Type::GetInt32Type(), nullptr,
                                         "%ifconv.zext");
  ++next_slot;
  // 0 - zext(cond) is all-ones when the condition holds and zero otherwise.
  auto* mask = pred->Insert<BinaryInst>(next_slot, Opcode::Sub,
                                        Type::GetInt32Type(), ConstInt(0),
                                        enabled, nullptr, "%ifconv.mask");
  ++next_slot;
  auto* masked_delta = pred->Insert<BinaryInst>(
      next_slot, Opcode::And, Type::GetInt32Type(), accum.delta, mask,
      nullptr, "%ifconv.delta");
  ++next_slot;
  auto* replacement = pred->Insert<BinaryInst>(
      next_slot, accum.opcode, Type::GetInt32Type(), accum.base,
      masked_delta, nullptr, "%ifconv.acc");

  phi->ReplaceAllUsesWith(replacement);
  join->EraseInstruction(phi);

  // Bypass the update block and let the cleanup helper drop it.
  passutils::ReplaceTerminatorWithBr(pred, join);
  pred->RemoveSuccessor(update_block);
  update_block->RemovePredecessor(pred);
  passutils::RemoveUnreachableBlocks(function);
  return true;
}
|
||||
|
||||
// Applies conditional-accumulation if-conversion to `function` until no more
// candidates are found. The reachable-block list is recomputed after each
// successful conversion because a conversion removes blocks. Returns true
// when anything changed.
bool RunIfConversionOnFunction(Function& function) {
  if (function.IsExternal()) {
    return false;
  }

  bool changed = false;
  for (;;) {
    bool converted = false;
    auto reachable = passutils::CollectReachableBlocks(function);
    for (auto* candidate : reachable) {
      if (TryConvertConditionalAccumulation(function, candidate)) {
        converted = true;
        break;  // The snapshot is stale now; rescan from scratch.
      }
    }
    if (!converted) {
      break;
    }
    changed = true;
  }
  return changed;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: runs if-conversion on every function in `module`.
// Returns true when any function was modified.
bool RunIfConversion(Module& module) {
  bool any_change = false;
  for (const auto& function : module.GetFunctions()) {
    if (function == nullptr) {
      continue;
    }
    if (RunIfConversionOnFunction(*function)) {
      any_change = true;
    }
  }
  return any_change;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,145 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// Returns true when `value` is an integer, i1 or float constant.
bool IsScalarConstant(Value* value) {
  if (dyncast<ConstantInt>(value) != nullptr) {
    return true;
  }
  if (dyncast<ConstantI1>(value) != nullptr) {
    return true;
  }
  return dyncast<ConstantFloat>(value) != nullptr;
}
|
||||
|
||||
bool IsScalarType(const std::shared_ptr<Type>& type) {
|
||||
return type && (type->IsInt32() || type->IsInt1() || type->IsFloat());
|
||||
}
|
||||
|
||||
// Returns true when `global` holds a scalar with a scalar-constant
// initializer and every use of it is a load whose pointer operand is the
// global itself (so nothing visible ever writes to it or takes its address).
bool IsReadonlyScalarGlobal(GlobalValue* global) {
  if (global == nullptr) {
    return false;
  }
  if (!IsScalarType(global->GetObjectType())) {
    return false;
  }
  if (!IsScalarConstant(global->GetInitializer())) {
    return false;
  }

  for (const auto& use : global->GetUses()) {
    auto* load = dyncast<LoadInst>(dyncast<Instruction>(use.GetUser()));
    const bool reads_global = load != nullptr && load->GetPtr() == global;
    if (!reads_global) {
      return false;
    }
  }
  return true;
}
|
||||
|
||||
// Replaces every load of a read-only scalar global with the global's constant
// initializer and erases the loads that end up unused. Returns true when at
// least one load was folded.
bool PropagateReadonlyScalarGlobals(Module& module) {
  std::vector<LoadInst*> folded_loads;

  for (const auto& entry : module.GetGlobalValues()) {
    auto* global = entry.get();
    if (!IsReadonlyScalarGlobal(global)) {
      continue;
    }

    // Iterate over a copy of the use list rather than the live one.
    const auto use_snapshot = global->GetUses();
    for (const auto& use : use_snapshot) {
      auto* load = dyncast<LoadInst>(use.GetUser());
      if (load != nullptr && load->GetPtr() == global) {
        load->ReplaceAllUsesWith(global->GetInitializer());
        folded_loads.push_back(load);
      }
    }
  }

  for (auto* load : folded_loads) {
    const bool erasable = load != nullptr && load->GetParent() != nullptr &&
                          load->GetUses().empty();
    if (erasable) {
      load->GetParent()->EraseInstruction(load);
    }
  }
  return !folded_loads.empty();
}
|
||||
|
||||
// Gathers every call instruction that calls `function` directly.
// `*all_uses_are_calls` is set to true when every use of `function` is the
// callee operand (operand 0) of such a call. As soon as any other kind of use
// is seen it is set to false and an empty vector is returned.
// `all_uses_are_calls` must not be null.
std::vector<CallInst*> CollectDirectCalls(Function& function, bool* all_uses_are_calls) {
  std::vector<CallInst*> direct_calls;
  *all_uses_are_calls = true;
  for (const auto& use : function.GetUses()) {
    auto* call = use.GetOperandIndex() == 0
                     ? dyncast<CallInst>(use.GetUser())
                     : nullptr;
    if (call == nullptr || call->GetCallee() != &function) {
      *all_uses_are_calls = false;
      return {};
    }
    direct_calls.push_back(call);
  }
  return direct_calls;
}
|
||||
|
||||
// For each scalar parameter of `function` that receives the same scalar
// constant at every call site, replaces the parameter's uses inside the
// function with that constant. Skips external functions, "main", functions
// without parameters, and functions that are used other than as a direct
// callee or have no call sites. Returns true when any parameter was replaced.
bool PropagateConstantArguments(Function& function) {
  if (function.IsExternal() || function.GetName() == "main" ||
      function.GetArguments().empty()) {
    return false;
  }

  bool only_direct_calls = false;
  const auto call_sites = CollectDirectCalls(function, &only_direct_calls);
  if (!only_direct_calls || call_sites.empty()) {
    return false;
  }

  // Returns the scalar constant passed at `position` by every call site, or
  // nullptr when some call passes a non-constant or a different constant.
  const auto common_constant = [&call_sites](std::size_t position) -> Value* {
    Value* shared = nullptr;
    for (auto* call : call_sites) {
      const auto actuals = call->GetArguments();
      if (position >= actuals.size() || !IsScalarConstant(actuals[position])) {
        return nullptr;
      }
      if (shared == nullptr) {
        shared = actuals[position];
      } else if (!passutils::AreEquivalentValues(shared, actuals[position])) {
        return nullptr;
      }
    }
    return shared;
  };

  bool changed = false;
  for (std::size_t index = 0; index < function.GetArguments().size(); ++index) {
    auto* argument = function.GetArgument(index);
    const bool worth_checking = argument != nullptr &&
                                IsScalarType(argument->GetType()) &&
                                !argument->GetUses().empty();
    if (!worth_checking) {
      continue;
    }
    if (Value* constant = common_constant(index)) {
      argument->ReplaceAllUsesWith(constant);
      changed = true;
    }
  }
  return changed;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: first folds loads of read-only scalar globals, then
// propagates call-site constants into function parameters. Returns true when
// the module was modified.
bool RunInterproceduralConstProp(Module& module) {
  bool any_change = PropagateReadonlyScalarGlobals(module);
  for (const auto& function : module.GetFunctions()) {
    if (function == nullptr) {
      continue;
    }
    if (PropagateConstantArguments(*function)) {
      any_change = true;
    }
  }
  return any_change;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,264 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "ir/passes/LoopPassUtils.h"
|
||||
|
||||
#include <queue>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// Returns true when `value` is an integer constant equal to `expected`.
bool IsConstInt(Value* value, int expected) {
  auto* as_constant = dyncast<ConstantInt>(value);
  if (as_constant == nullptr) {
    return false;
  }
  return as_constant->GetValue() == expected;
}
|
||||
|
||||
// Returns true when `value` is `base + 1` or `1 + base`.
bool IsAddOneOf(Value* value, Value* base) {
  auto* add = dyncast<BinaryInst>(value);
  if (add == nullptr || add->GetOpcode() != Opcode::Add) {
    return false;
  }
  if (add->GetLhs() == base && IsConstInt(add->GetRhs(), 1)) {
    return true;
  }
  return add->GetRhs() == base && IsConstInt(add->GetLhs(), 1);
}
|
||||
|
||||
bool HasForbiddenSideEffects(const Loop& loop) {
|
||||
for (auto* block : loop.block_list) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Store:
|
||||
case Opcode::Memset:
|
||||
case Opcode::Call:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns true when `value` has at least one user that is not an instruction
// located inside `loop` (non-instruction users count as outside).
bool HasUseOutsideLoop(Value* value, const Loop& loop) {
  for (const auto& use : value->GetUses()) {
    auto* user = dyncast<Instruction>(use.GetUser());
    const bool inside = user != nullptr && loop.Contains(user->GetParent());
    if (!inside) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// Returns true when, inside `loop`, the induction phi is consumed only by the
// exit comparison and by its own increment -- i.e. the loop body never
// observes the iteration number. Users outside the loop are allowed; a
// non-instruction user makes the answer false.
bool InductionOnlyControlsRepeatCount(PhiInst* induction, BinaryInst* compare,
                                      BinaryInst* next, const Loop& loop) {
  for (const auto& use : induction->GetUses()) {
    auto* user = dyncast<Instruction>(use.GetUser());
    if (user == nullptr) {
      return false;
    }
    const bool is_loop_control = user == compare || user == next;
    if (!is_loop_control && loop.Contains(user->GetParent())) {
      return false;
    }
  }
  return true;
}
|
||||
|
||||
// Decides whether the header phi `accumulator` is a pure additive
// accumulator: it starts at 0 on the preheader edge, and on every path
// through the loop its next value is the current value plus terms that do
// not themselves depend on the accumulator. The caller relies on this to
// replace N identical iterations by one iteration multiplied by N.
//
// The check walks every in-loop value derived from the accumulator:
//  * the only instructions allowed to consume a derived value are adds and
//    phis;
//  * every such add must have exactly one derived operand;
//  * every in-loop phi reached this way must merge derived values only. A
//    merge like `phi [acc + 1, then], [5, else]` resets the running sum on
//    one path, so scaling a single iteration's result would be wrong;
//  * the value flowing back along the latch edge must be derived.
bool IsAdditiveAccumulator(PhiInst* accumulator, BasicBlock* preheader,
                           BasicBlock* latch, const Loop& loop) {
  if (!accumulator || !accumulator->IsInt32()) {
    return false;
  }
  const int preheader_index = looputils::GetPhiIncomingIndex(accumulator, preheader);
  const int latch_index = looputils::GetPhiIncomingIndex(accumulator, latch);
  if (preheader_index < 0 || latch_index < 0) {
    return false;
  }
  if (!IsConstInt(accumulator->GetIncomingValue(preheader_index), 0)) {
    return false;
  }

  auto* latch_value = accumulator->GetIncomingValue(latch_index);
  if (latch_value == accumulator) {
    // The phi just forwards itself: nothing is accumulated.
    return false;
  }

  std::unordered_set<Value*> derived;
  std::vector<BinaryInst*> additive_steps;
  std::vector<PhiInst*> merge_phis;
  std::queue<Value*> worklist;
  derived.insert(accumulator);
  worklist.push(accumulator);

  // Marks `value` as derived; returns true the first time it is seen.
  auto remember = [&](Value* value) {
    if (derived.insert(value).second) {
      worklist.push(value);
      return true;
    }
    return false;
  };

  while (!worklist.empty()) {
    auto* value = worklist.front();
    worklist.pop();
    for (const auto& use : value->GetUses()) {
      auto* inst = dyncast<Instruction>(use.GetUser());
      if (!inst || !loop.Contains(inst->GetParent())) {
        continue;  // Users outside the loop do not affect the recurrence.
      }

      if (auto* phi = dyncast<PhiInst>(inst)) {
        // `accumulator` is already in `derived`, so it is never recorded as a
        // merge phi; its preheader edge was validated above.
        if (remember(phi)) {
          merge_phis.push_back(phi);
        }
        continue;
      }

      auto* binary = dyncast<BinaryInst>(inst);
      if (!binary || binary->GetOpcode() != Opcode::Add) {
        return false;
      }
      if (remember(binary)) {
        additive_steps.push_back(binary);
      }
    }
  }

  if (derived.find(latch_value) == derived.end()) {
    return false;
  }

  // Each add must combine one derived value with one independent term.
  for (auto* add : additive_steps) {
    const bool lhs_derived = derived.find(add->GetLhs()) != derived.end();
    const bool rhs_derived = derived.find(add->GetRhs()) != derived.end();
    if (lhs_derived == rhs_derived) {
      return false;
    }
  }

  // Every path merged by an in-loop phi must carry the running sum forward;
  // an incoming value that is not derived would overwrite the accumulator.
  for (auto* phi : merge_phis) {
    for (int i = 0; i < phi->GetNumIncomings(); ++i) {
      if (derived.find(phi->GetIncomingValue(i)) == derived.end()) {
        return false;
      }
    }
  }
  return true;
}
|
||||
|
||||
// Collects the run of phi instructions at the start of `header`, in order.
// Returns an empty vector for a null block.
std::vector<PhiInst*> GetHeaderPhis(BasicBlock* header) {
  std::vector<PhiInst*> leading_phis;
  if (header == nullptr) {
    return leading_phis;
  }
  for (const auto& inst_ptr : header->GetInstructions()) {
    auto* as_phi = dyncast<PhiInst>(inst_ptr.get());
    if (as_phi == nullptr) {
      break;  // Phis are only collected from the leading run.
    }
    leading_phis.push_back(as_phi);
  }
  return leading_phis;
}
|
||||
|
||||
// Recognises a counted loop `for (i = 0; i < bound; ++i)` whose body ignores
// `i`, has no store/memset/call, and only feeds additive accumulators that
// start at zero. Such a loop computes the same increment every iteration, so
// it is rewritten to run at most once, and every accumulator that is used
// outside the loop is replaced there by `accumulator * bound`.
// Returns true when the loop was rewritten.
// NOTE(review): the multiply is inserted after the exit block's phis, and
// outside users (possibly including exit-block phis) are redirected to it --
// confirm the exit block never has a phi that consumes an accumulator.
bool TryReduceRepeatLoop(Function& function, Loop& loop) {
  const bool canonical_shape = loop.header && loop.preheader &&
                               loop.latches.size() == 1 &&
                               loop.exit_blocks.size() == 1;
  if (!canonical_shape || HasForbiddenSideEffects(loop)) {
    return false;
  }

  auto* latch = loop.latches.front();
  auto* exit = loop.exit_blocks.front();

  // The header must end in `br (i < bound), <in-loop block>, exit`.
  auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
  if (branch == nullptr) {
    return false;
  }
  auto* compare = dyncast<BinaryInst>(branch->GetCondition());
  if (compare == nullptr || compare->GetOpcode() != Opcode::ICmpLT) {
    return false;
  }
  if (!loop.Contains(branch->GetThenBlock()) || branch->GetElseBlock() != exit) {
    return false;
  }

  auto* induction = dyncast<PhiInst>(compare->GetLhs());
  auto* bound = compare->GetRhs();
  if (induction == nullptr || induction->GetParent() != loop.header ||
      !looputils::IsLoopInvariantValue(loop, bound)) {
    return false;
  }

  // The induction variable must start at 0 and advance by exactly 1.
  const int induction_preheader_index =
      looputils::GetPhiIncomingIndex(induction, loop.preheader);
  const int induction_latch_index = looputils::GetPhiIncomingIndex(induction, latch);
  if (induction_preheader_index < 0 || induction_latch_index < 0) {
    return false;
  }
  if (!IsConstInt(induction->GetIncomingValue(induction_preheader_index), 0)) {
    return false;
  }

  auto* induction_next =
      dyncast<BinaryInst>(induction->GetIncomingValue(induction_latch_index));
  if (!IsAddOneOf(induction_next, induction)) {
    return false;
  }
  if (!InductionOnlyControlsRepeatCount(induction, compare, induction_next, loop)) {
    return false;
  }

  // Every other header phi must be an additive accumulator; only those that
  // are visible outside the loop need to be scaled.
  std::vector<PhiInst*> accumulators;
  for (auto* phi : GetHeaderPhis(loop.header)) {
    if (phi == induction) {
      continue;
    }
    if (!IsAdditiveAccumulator(phi, loop.preheader, latch, loop)) {
      return false;
    }
    if (HasUseOutsideLoop(phi, loop)) {
      accumulators.push_back(phi);
    }
  }
  if (accumulators.empty()) {
    return false;
  }

  // Make the counted loop stop after one executed body: the latch edge now
  // feeds `bound` into the induction phi, so the second test is
  // `bound < bound`. The first test is still `0 < bound`, so loops with a
  // non-positive trip count keep executing zero times.
  induction->SetOperand(static_cast<std::size_t>(2 * induction_latch_index), bound);

  std::size_t insert_index = looputils::GetFirstNonPhiIndex(exit);
  for (auto* accumulator : accumulators) {
    auto* scaled = exit->Insert<BinaryInst>(
        insert_index, Opcode::Mul, Type::GetInt32Type(), accumulator, bound,
        nullptr, looputils::NextSyntheticName(function, "repeat.reduce"));
    ++insert_index;

    // Work on a copy of the use list: SetOperand edits the live one.
    const auto use_snapshot = accumulator->GetUses();
    for (const auto& use : use_snapshot) {
      auto* user = use.GetUser();
      auto* user_inst = dyncast<Instruction>(user);
      if (user_inst == scaled) {
        continue;  // The multiply itself must keep reading the raw phi.
      }
      const bool outside_loop =
          user_inst == nullptr || !loop.Contains(user_inst->GetParent());
      if (outside_loop) {
        user->SetOperand(use.GetOperandIndex(), scaled);
      }
    }
  }

  return true;
}
|
||||
|
||||
// Builds dominator and loop information for `function` and tries the
// repeat-loop reduction on every loop in post order. Returns true when any
// loop was rewritten.
bool RunOnFunction(Function& function) {
  DominatorTree dom_tree(function);
  LoopInfo loop_info(function, dom_tree);

  bool any_reduced = false;
  for (auto* loop : loop_info.GetLoopsInPostOrder()) {
    if (TryReduceRepeatLoop(function, *loop)) {
      any_reduced = true;
    }
  }
  return any_reduced;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: runs the repeat-loop reduction on every non-external
// function in `module`. Returns true when the module was modified.
bool RunLoopRepeatReduction(Module& module) {
  bool any_change = false;
  for (const auto& function : module.GetFunctions()) {
    const bool has_body = function && !function->IsExternal();
    if (!has_body) {
      continue;
    }
    if (RunOnFunction(*function)) {
      any_change = true;
    }
  }
  return any_change;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,249 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct TailCallSite {
|
||||
BasicBlock* block = nullptr;
|
||||
CallInst* call = nullptr;
|
||||
ReturnInst* ret = nullptr;
|
||||
};
|
||||
|
||||
// Returns true when the first instruction of the entry block is a phi.
// Only the first instruction is inspected. Returns false when the function
// has no entry block or the block is empty.
bool HasEntryPhi(Function& function) {
  auto* entry = function.GetEntryBlock();
  if (entry == nullptr) {
    return false;
  }
  const auto& insts = entry->GetInstructions();
  if (insts.empty()) {
    return false;
  }
  return dyncast<PhiInst>(insts.front().get()) != nullptr;
}
|
||||
|
||||
// Returns true when `call` has exactly one use and that use is `ret`.
bool IsOnlyUsedByReturn(CallInst* call, ReturnInst* ret) {
  if (call == nullptr || ret == nullptr) {
    return false;
  }
  const auto& uses = call->GetUses();
  if (uses.size() != 1) {
    return false;
  }
  return uses.front().GetUser() == ret;
}
|
||||
|
||||
// Checks whether `block` ends with `call function(...)` immediately followed
// by a return, in one of two shapes:
//  * `ret <call>`  -- the call result is returned and has no other user;
//  * `ret` (void)  -- the call is void and its result is unused.
// Returns the matched site, or a default (all-null) TailCallSite otherwise.
TailCallSite MatchTailRecursiveCall(Function& function, BasicBlock* block) {
  if (block == nullptr) {
    return {};
  }
  auto& insts = block->GetInstructions();
  const std::size_t count = insts.size();
  if (count < 2) {
    return {};
  }

  auto* ret = dyncast<ReturnInst>(insts[count - 1].get());
  if (ret == nullptr) {
    return {};
  }
  auto* preceding_call = dyncast<CallInst>(insts[count - 2].get());

  if (!ret->HasReturnValue()) {
    // Void shape: the instruction before the return is the recursive call.
    const bool void_self_call = preceding_call != nullptr &&
                                preceding_call->GetCallee() == &function &&
                                preceding_call->GetType()->IsVoid() &&
                                preceding_call->GetUses().empty();
    if (!void_self_call) {
      return {};
    }
    return {block, preceding_call, ret};
  }

  // Value shape: the returned value must be that same preceding call.
  auto* returned_call = dyncast<CallInst>(ret->GetReturnValue());
  const bool returns_self_call = returned_call != nullptr &&
                                 returned_call == preceding_call &&
                                 returned_call->GetParent() == block &&
                                 returned_call->GetCallee() == &function &&
                                 IsOnlyUsedByReturn(returned_call, ret);
  if (!returns_self_call) {
    return {};
  }
  return {block, returned_call, ret};
}
|
||||
|
||||
// Returns every tail-recursive call site found in `function`, in block order.
std::vector<TailCallSite> CollectTailCallSites(Function& function) {
  std::vector<TailCallSite> matches;
  for (const auto& block_ptr : function.GetBlocks()) {
    const TailCallSite candidate =
        MatchTailRecursiveCall(function, block_ptr.get());
    const bool matched = candidate.block && candidate.call && candidate.ret;
    if (matched) {
      matches.push_back(candidate);
    }
  }
  return matches;
}
|
||||
|
||||
// Creates a new entry block that branches unconditionally to `header`,
// places it at the front of the function's block list, and marks it as the
// entry block. Returns the new block.
// NOTE(review): anything in the old entry block (e.g. stack allocations, if
// the IR keeps them there) now sits inside a loop header -- confirm later
// stages do not require such instructions to be in the entry block.
BasicBlock* InsertPreheader(Function& function, BasicBlock* header) {
  auto owned = std::make_unique<BasicBlock>(
      &function, looputils::NextSyntheticBlockName(function, "tailrec.entry"));
  BasicBlock* const preheader = owned.get();

  auto& block_list = function.GetBlocks();
  block_list.insert(block_list.begin(), std::move(owned));
  function.SetEntryBlock(preheader);

  // Wire the fall-through edge preheader -> header.
  preheader->Append<UncondBrInst>(header, nullptr);
  preheader->AddSuccessor(header);
  header->AddPredecessor(preheader);
  return preheader;
}
|
||||
|
||||
// Creates one phi per function argument in `header`, each seeded with the
// argument itself on the `preheader` edge, and redirects every use the
// arguments had before the phis existed to the matching phi. Returns the
// phis in argument order.
std::vector<PhiInst*> CreateArgumentPhis(Function& function, BasicBlock* header,
                                         BasicBlock* preheader) {
  const std::size_t arg_count = function.GetArguments().size();

  // Phase 1: record the existing uses first, so that the phis' own incoming
  // operands (added in phase 2) are not rewritten in phase 3.
  std::vector<std::vector<Use>> prior_uses;
  prior_uses.reserve(arg_count);
  for (const auto& arg : function.GetArguments()) {
    prior_uses.push_back(arg->GetUses());
  }

  // Phase 2: create the phis after any phi already present in the header.
  std::vector<PhiInst*> arg_phis;
  arg_phis.reserve(arg_count);
  std::size_t slot = looputils::GetFirstNonPhiIndex(header);
  for (const auto& arg : function.GetArguments()) {
    auto* phi = header->Insert<PhiInst>(
        slot, arg->GetType(), nullptr,
        looputils::NextSyntheticName(function, "tailrec.arg."));
    ++slot;
    phi->AddIncoming(arg.get(), preheader);
    arg_phis.push_back(phi);
  }

  // Phase 3: point every recorded use at the phi of its argument.
  for (std::size_t i = 0; i < function.GetArguments().size(); ++i) {
    for (const auto& use : prior_uses[i]) {
      auto* user = use.GetUser();
      if (user != nullptr) {
        user->SetOperand(use.GetOperandIndex(), arg_phis[i]);
      }
    }
  }

  return arg_phis;
}
|
||||
|
||||
// Replaces the last instruction of `block` with an unconditional branch to
// `dest` and records the new CFG edge. The old instruction's operands are
// cleared before it is destroyed. `block` must be non-null and non-empty.
void ReplaceTerminatorWithBranch(BasicBlock* block, BasicBlock* dest) {
  auto& terminator_slot = block->GetInstructions().back();
  terminator_slot->ClearAllOperands();

  auto branch = std::make_unique<UncondBrInst>(dest, nullptr);
  branch->SetParent(block);
  terminator_slot = std::move(branch);

  block->AddSuccessor(dest);
  dest->AddPredecessor(block);
}
|
||||
|
||||
void RewriteTailCallSite(const TailCallSite& site, BasicBlock* header,
|
||||
const std::vector<PhiInst*>& arg_phis) {
|
||||
for (std::size_t i = 0; i < arg_phis.size(); ++i) {
|
||||
arg_phis[i]->AddIncoming(site.call->GetOperand(i + 1), site.block);
|
||||
}
|
||||
|
||||
ReplaceTerminatorWithBranch(site.block, header);
|
||||
site.block->EraseInstruction(site.call);
|
||||
}
|
||||
|
||||
// Depth-first search over the direct-call graph: returns true when `root` is
// reachable from `current` through one or more calls. `visiting` holds the
// functions already expanded, so each function is explored at most once.
bool ReachesFunction(
    Function* root, Function* current,
    const std::unordered_map<Function*, std::vector<Function*>>& direct_callees,
    std::unordered_set<Function*>& visiting) {
  if (root == nullptr || current == nullptr || current->IsExternal()) {
    return false;
  }
  const bool first_visit = visiting.insert(current).second;
  if (!first_visit) {
    return false;
  }

  const auto entry = direct_callees.find(current);
  if (entry == direct_callees.end()) {
    return false;
  }
  for (Function* callee : entry->second) {
    if (callee == root ||
        ReachesFunction(root, callee, direct_callees, visiting)) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// Rebuilds the direct-call graph of `module` and refreshes the "recursive"
// flag in each non-external function's effect info. All other effect flags
// are written back unchanged.
void RecomputeRecursiveFlags(Module& module) {
  // Step 1: for every defined function, list the distinct non-external
  // functions it calls directly.
  std::unordered_map<Function*, std::vector<Function*>> direct_callees;
  for (const auto& function_ptr : module.GetFunctions()) {
    auto* caller = function_ptr.get();
    if (caller == nullptr || caller->IsExternal()) {
      continue;
    }
    auto& callees = direct_callees[caller];
    for (const auto& block_ptr : caller->GetBlocks()) {
      for (const auto& inst_ptr : block_ptr->GetInstructions()) {
        auto* call = dyncast<CallInst>(inst_ptr.get());
        auto* callee = call ? call->GetCallee() : nullptr;
        if (!callee || callee->IsExternal()) {
          continue;
        }
        const bool already_listed =
            std::find(callees.begin(), callees.end(), callee) != callees.end();
        if (!already_listed) {
          callees.push_back(callee);
        }
      }
    }
  }

  // Step 2: a function is recursive when it can reach itself in that graph.
  for (const auto& function_ptr : module.GetFunctions()) {
    auto* target = function_ptr.get();
    if (target == nullptr || target->IsExternal()) {
      continue;
    }
    std::unordered_set<Function*> visiting;
    const bool is_recursive =
        ReachesFunction(target, target, direct_callees, visiting);
    target->SetEffectInfo(target->ReadsGlobalMemory(),
                          target->WritesGlobalMemory(),
                          target->ReadsParamMemory(),
                          target->WritesParamMemory(), target->HasIO(),
                          target->HasUnknownEffects(), is_recursive);
  }
}
|
||||
|
||||
// Converts every tail-recursive self call in `function` into a loop: a new
// entry block is placed in front of the old one, the old entry becomes the
// loop header with one phi per argument, and each tail call becomes a branch
// back to that header. Skipped for external functions, functions without an
// entry block, and entry blocks that already start with a phi. Returns true
// when the function was rewritten.
bool RunOnFunction(Function& function) {
  if (function.IsExternal()) {
    return false;
  }
  if (function.GetEntryBlock() == nullptr || HasEntryPhi(function)) {
    return false;
  }

  const auto sites = CollectTailCallSites(function);
  if (sites.empty()) {
    return false;
  }

  // The sites were collected before the CFG is touched; the rewrite below
  // only appends incoming edges and replaces each site's own terminator.
  auto* header = function.GetEntryBlock();
  auto* preheader = InsertPreheader(function, header);
  const auto arg_phis = CreateArgumentPhis(function, header, preheader);

  for (const auto& site : sites) {
    RewriteTailCallSite(site, header, arg_phis);
  }
  return true;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: eliminates tail recursion in every function of `module`
// and, when anything changed, recomputes the per-function "recursive" flags.
// Returns true when the module was modified.
bool RunTailRecursionElim(Module& module) {
  bool any_change = false;
  for (const auto& function_ptr : module.GetFunctions()) {
    if (!function_ptr) {
      continue;
    }
    if (RunOnFunction(*function_ptr)) {
      any_change = true;
    }
  }
  if (!any_change) {
    return false;
  }
  RecomputeRecursiveFlags(module);
  return true;
}
|
||||
|
||||
} // namespace ir
|
||||
Loading…
Reference in new issue