forked from p4jyxwm3q/nudt-compiler-cpp
Compare commits
6 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
78168fe917 | 3 days ago |
|
|
e55421f447 | 2 weeks ago |
|
|
69892ef133 | 2 weeks ago |
|
|
407be0fca1 | 2 weeks ago |
|
|
08ce9d96ab | 2 weeks ago |
|
|
bcfbf52488 | 2 weeks ago |
@ -0,0 +1,137 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// Returns true when `value` is a strictly positive power of two (1, 2, 4, ...).
// Declared constexpr so the predicate can also be used in constant
// expressions such as static_assert.
constexpr bool IsPowerOfTwoPositive(int value) noexcept {
  // A positive power of two has exactly one bit set, so clearing its lowest
  // set bit must leave zero.
  return value > 0 && (value & (value - 1)) == 0;
}
|
||||
|
||||
// Returns the position of `inst` within `block`'s instruction list.
// Yields 0 when either pointer is null, and the list size when `inst` is not
// found in the block.
std::size_t FindInstructionIndex(BasicBlock* block, Instruction* inst) {
  if (block == nullptr || inst == nullptr) {
    return 0;
  }
  auto& list = block->GetInstructions();
  std::size_t position = 0;
  // Advance until the owning slot holding `inst` is reached or the list ends.
  while (position < list.size() && list[position].get() != inst) {
    ++position;
  }
  return position;
}
|
||||
|
||||
// True when `value` is an integer constant equal to 0 or an i1 constant that
// is false. Any other value (including non-constants) yields false.
bool IsZero(Value* value) {
  auto* as_int = dyncast<ConstantInt>(value);
  if (as_int != nullptr) {
    return as_int->GetValue() == 0;
  }
  auto* as_bool = dyncast<ConstantI1>(value);
  if (as_bool != nullptr) {
    return !as_bool->GetValue();
  }
  return false;
}
|
||||
|
||||
// Given a two-operand instruction `cmp` and one of its operands `value`,
// returns the opposite operand. Returns nullptr when `cmp` is null, does not
// have exactly two operands, or does not use `value` at all.
Value* OtherCompareOperand(BinaryInst* cmp, Value* value) {
  const bool two_operand_shape = cmp != nullptr && cmp->GetNumOperands() == 2;
  if (!two_operand_shape) {
    return nullptr;
  }
  // The left-hand side is checked first, so `value op value` returns the rhs.
  if (cmp->GetLhs() == value) {
    return cmp->GetRhs();
  }
  return cmp->GetRhs() == value ? cmp->GetLhs() : nullptr;
}
|
||||
|
||||
// Simplifies remainder instructions whose divisor is a positive power-of-two
// constant:
//   * `x % 1` has every use replaced by the constant 0;
//   * when every use of `x % 2^k` is an ICmpEQ/ICmpNE against zero, the
//     compares are rewired to test `x & (2^k - 1)` instead. The low k bits
//     are all zero exactly when the remainder is zero, for negative x too.
// Remainders left without uses are erased after the scan.
// Returns true when at least one remainder was rewritten.
bool SimplifyPowerOfTwoRemTests(Function& function) {
  bool changed = false;
  std::vector<Instruction*> dead_rems;

  for (const auto& block_ptr : function.GetBlocks()) {
    auto* block = block_ptr.get();
    if (!block) {
      continue;
    }

    // Iterate over a snapshot of the block. Insert() below grows the block's
    // instruction list, and mutating that container underneath a live
    // range-for over it would invalidate the loop's iterators. Erasure is
    // deferred until after the scan, so the snapshot pointers stay valid.
    std::vector<Instruction*> snapshot;
    for (const auto& inst_ptr : block->GetInstructions()) {
      snapshot.push_back(inst_ptr.get());
    }

    for (auto* inst : snapshot) {
      auto* rem = dyncast<BinaryInst>(inst);
      if (!rem || rem->GetOpcode() != Opcode::Rem) {
        continue;
      }
      auto* divisor = dyncast<ConstantInt>(rem->GetRhs());
      if (!divisor || !IsPowerOfTwoPositive(divisor->GetValue())) {
        continue;
      }

      const int mask_value = divisor->GetValue() - 1;
      if (mask_value == 0) {
        // Divisor is 1: the remainder is always zero.
        rem->ReplaceAllUsesWith(looputils::ConstInt(0));
        dead_rems.push_back(rem);
        changed = true;
        continue;
      }

      // Only rewrite when the remainder feeds nothing but ==0 / !=0 tests;
      // any other user needs the real remainder value.
      std::vector<BinaryInst*> compare_uses;
      bool all_uses_are_zero_tests = !rem->GetUses().empty();
      for (const auto& use : rem->GetUses()) {
        auto* cmp = dyncast<BinaryInst>(dynamic_cast<Value*>(use.GetUser()));
        if (!cmp || (cmp->GetOpcode() != Opcode::ICmpEQ &&
                     cmp->GetOpcode() != Opcode::ICmpNE) ||
            !IsZero(OtherCompareOperand(cmp, rem))) {
          all_uses_are_zero_tests = false;
          break;
        }
        compare_uses.push_back(cmp);
      }
      if (!all_uses_are_zero_tests || compare_uses.empty()) {
        continue;
      }

      // Place the mask right after the remainder it replaces.
      const auto insert_index = FindInstructionIndex(block, rem) + 1;
      auto* masked = block->Insert<BinaryInst>(
          insert_index, Opcode::And, Type::GetInt32Type(), rem->GetLhs(),
          looputils::ConstInt(mask_value), nullptr,
          looputils::NextSyntheticName(function, "pow2.mask."));

      for (auto* cmp : compare_uses) {
        if (cmp->GetLhs() == rem) {
          cmp->SetOperand(0, masked);
        }
        if (cmp->GetRhs() == rem) {
          cmp->SetOperand(1, masked);
        }
      }
      dead_rems.push_back(rem);
      changed = true;
    }
  }

  // Deferred cleanup: drop remainders that no longer have any user.
  for (auto* rem : dead_rems) {
    if (rem->GetUses().empty() && rem->GetParent()) {
      rem->GetParent()->EraseInstruction(rem);
    }
  }
  return changed;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: applies SimplifyPowerOfTwoRemTests to every non-null,
// non-external function in `module`. Returns true if any function changed.
bool RunArithmeticSimplify(Module& module) {
  bool any_change = false;
  for (const auto& fn : module.GetFunctions()) {
    if (fn && !fn->IsExternal()) {
      // The simplification is evaluated first so it always runs.
      any_change = SimplifyPowerOfTwoRemTests(*fn) || any_change;
    }
  }
  return any_change;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,239 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// Returns the last instruction of `block` when it is a terminator; returns
// nullptr for a null block, an empty block, or a non-terminator last
// instruction.
Instruction* GetTerminator(BasicBlock* block) {
  if (block == nullptr) {
    return nullptr;
  }
  if (block->GetInstructions().empty()) {
    return nullptr;
  }
  auto* last = block->GetInstructions().back().get();
  if (last == nullptr || !last->IsTerminator()) {
    return nullptr;
  }
  return last;
}
|
||||
|
||||
// Index of the last instruction slot in `block` (0 for an empty block).
// `block` must be non-null; it is dereferenced unconditionally.
std::size_t GetTerminatorIndex(BasicBlock* block) {
  const auto& list = block->GetInstructions();
  if (list.empty()) {
    return 0;
  }
  return list.size() - 1;
}
|
||||
|
||||
// Creates a new 32-bit integer constant holding `value`.
// NOTE(review): the object is allocated with a raw `new` and nothing in this
// file releases it; confirm who owns IR constants.
ConstantInt* ConstInt(int value) {
  auto* constant = new ConstantInt(Type::GetInt32Type(), value);
  return constant;
}
|
||||
|
||||
// Returns the phi of `block` when the leading run of phi instructions has
// exactly one element. Returns nullptr when the block is null, starts with no
// phi, or starts with two or more phis.
PhiInst* GetSinglePhi(BasicBlock* block) {
  if (block == nullptr) {
    return nullptr;
  }

  const auto& list = block->GetInstructions();
  PhiInst* only = nullptr;
  for (std::size_t i = 0; i < list.size(); ++i) {
    auto* phi = dyncast<PhiInst>(list[i].get());
    if (phi == nullptr) {
      break;  // Scanning stops at the first non-phi instruction.
    }
    if (only != nullptr) {
      return nullptr;  // A second phi disqualifies the block.
    }
    only = phi;
  }
  return only;
}
|
||||
|
||||
// Checks that `block` holds exactly one instruction that is not a terminator.
// On success the instruction is stored through `out` (when non-null) and true
// is returned. When no such instruction exists, `*out` receives nullptr and
// false is returned. When two or more exist, false is returned and `*out` is
// left untouched.
bool HasOnlyOneNonTerminator(BasicBlock* block, Instruction** out) {
  if (!block) {
    return false;
  }

  Instruction* sole = nullptr;
  for (const auto& slot : block->GetInstructions()) {
    Instruction* current = slot.get();
    const bool counts = current != nullptr && !current->IsTerminator();
    if (!counts) {
      continue;
    }
    if (sole) {
      return false;
    }
    sole = current;
  }

  if (out) {
    *out = sole;
  }
  return sole != nullptr;
}
|
||||
|
||||
// Returns the index of the first incoming edge of `phi` that originates from
// `block`, or -1 when either argument is null or no such edge exists.
int IncomingIndexFor(PhiInst* phi, BasicBlock* block) {
  if (!phi || !block) {
    return -1;
  }
  int index = 0;
  while (index < phi->GetNumIncomings()) {
    if (phi->GetIncomingBlock(index) == block) {
      return index;
    }
    ++index;
  }
  return -1;
}
|
||||
|
||||
// True when every use of `value` belongs to `expected_user`. A value without
// any uses also yields true. Null arguments yield false.
bool IsUsedOnlyBy(Value* value, User* expected_user) {
  if (!value || !expected_user) {
    return false;
  }
  bool exclusive = true;
  for (const auto& use : value->GetUses()) {
    if (use.GetUser() != expected_user) {
      exclusive = false;
      break;
    }
  }
  return exclusive;
}
|
||||
|
||||
struct ConditionalAccumulation {
|
||||
Value* base = nullptr;
|
||||
Value* delta = nullptr;
|
||||
Opcode opcode = Opcode::Add;
|
||||
};
|
||||
|
||||
// Matches the pattern
//   phi = [base from pred], [update from update_block]
//   update = base + delta | delta + base | base - delta
// with int32 phi, update, base and delta, and `update` used by `phi` only.
// On success fills `*match` and returns true; otherwise returns false and
// leaves `*match` untouched.
bool MatchConditionalAccumulation(PhiInst* phi, BasicBlock* pred,
                                  BasicBlock* update_block,
                                  BinaryInst* update,
                                  ConditionalAccumulation* match) {
  const bool inputs_present = phi != nullptr && pred != nullptr &&
                              update_block != nullptr && update != nullptr &&
                              match != nullptr;
  if (!inputs_present) {
    return false;
  }
  if (phi->GetNumIncomings() != 2 || !phi->GetType()->IsInt32() ||
      !update->GetType()->IsInt32()) {
    return false;
  }

  const int from_pred = IncomingIndexFor(phi, pred);
  const int from_update = IncomingIndexFor(phi, update_block);
  if (from_pred < 0 || from_update < 0) {
    return false;
  }

  auto* base = phi->GetIncomingValue(from_pred);
  if (phi->GetIncomingValue(from_update) != update) {
    return false;
  }
  if (base == nullptr || !base->GetType()->IsInt32()) {
    return false;
  }
  if (!IsUsedOnlyBy(update, phi)) {
    return false;
  }

  auto* lhs = update->GetLhs();
  auto* rhs = update->GetRhs();
  const auto is_int32 = [](Value* operand) {
    return operand != nullptr && operand->GetType()->IsInt32();
  };
  const auto accept = [match, base](Value* delta, Opcode opcode) {
    match->base = base;
    match->delta = delta;
    match->opcode = opcode;
    return true;
  };

  switch (update->GetOpcode()) {
    case Opcode::Add:
      // Addition commutes, so the base may sit on either side.
      if (lhs == base && is_int32(rhs)) {
        return accept(rhs, Opcode::Add);
      }
      if (rhs == base && is_int32(lhs)) {
        return accept(lhs, Opcode::Add);
      }
      return false;
    case Opcode::Sub:
      // Subtraction only matches with the base on the left.
      if (lhs == base && is_int32(rhs)) {
        return accept(rhs, Opcode::Sub);
      }
      return false;
    default:
      return false;
  }
}
|
||||
|
||||
// Attempts to turn
//   pred:         br cond, update_block, join
//   update_block: update = base +/- delta ; br join
//   join:         phi [base, pred], [update, update_block]
// into straight-line code in `pred`:
//   acc = base +/- (delta & (0 - zext(cond)))
// followed by an unconditional branch to `join`. Returns true when the
// rewrite was performed.
// NOTE(review): the four inserted instructions always receive the same fixed
// names; confirm that repeated conversions within one function cannot produce
// clashing value names.
bool TryConvertConditionalAccumulation(Function& function, BasicBlock* pred) {
  auto* branch = dyncast<CondBrInst>(GetTerminator(pred));
  if (branch == nullptr) {
    return false;
  }
  if (branch->GetCondition() == nullptr ||
      !branch->GetCondition()->GetType()->IsInt1()) {
    return false;
  }

  // Only the "then" side may hold the update; the "else" side must be the join.
  auto* update_block = branch->GetThenBlock();
  auto* join = branch->GetElseBlock();
  if (update_block == nullptr || join == nullptr || update_block == join) {
    return false;
  }
  if (update_block->GetPredecessors().size() != 1 ||
      update_block->GetPredecessors().front() != pred) {
    return false;
  }
  if (update_block->GetSuccessors().size() != 1 ||
      update_block->GetSuccessors().front() != join) {
    return false;
  }

  auto* update_term = dyncast<UncondBrInst>(GetTerminator(update_block));
  if (update_term == nullptr || update_term->GetDest() != join) {
    return false;
  }

  Instruction* only_inst = nullptr;
  if (!HasOnlyOneNonTerminator(update_block, &only_inst)) {
    return false;
  }
  auto* update = dyncast<BinaryInst>(only_inst);
  if (update == nullptr) {
    return false;
  }
  const bool is_add_or_sub = update->GetOpcode() == Opcode::Add ||
                             update->GetOpcode() == Opcode::Sub;
  if (!is_add_or_sub) {
    return false;
  }

  auto* phi = GetSinglePhi(join);
  ConditionalAccumulation accum;
  if (!MatchConditionalAccumulation(phi, pred, update_block, update, &accum)) {
    return false;
  }

  // New instructions go immediately before pred's terminator, in order.
  const std::size_t insert_pos = GetTerminatorIndex(pred);
  // enabled = zext(cond): 1 when the update would have run, else 0.
  auto* enabled = pred->Insert<ZextInst>(insert_pos, branch->GetCondition(),
                                         Type::GetInt32Type(), nullptr,
                                         "%ifconv.zext");
  // mask = 0 - enabled: all ones when enabled, zero otherwise.
  auto* mask = pred->Insert<BinaryInst>(insert_pos + 1, Opcode::Sub,
                                        Type::GetInt32Type(), ConstInt(0),
                                        enabled, nullptr, "%ifconv.mask");
  // masked_delta = delta & mask: delta when enabled, zero otherwise.
  auto* masked_delta = pred->Insert<BinaryInst>(
      insert_pos + 2, Opcode::And, Type::GetInt32Type(), accum.delta, mask,
      nullptr, "%ifconv.delta");
  auto* replacement = pred->Insert<BinaryInst>(
      insert_pos + 3, accum.opcode, Type::GetInt32Type(), accum.base,
      masked_delta, nullptr, "%ifconv.acc");

  phi->ReplaceAllUsesWith(replacement);
  join->EraseInstruction(phi);

  // Bypass the update block and let the cleanup helper drop it.
  passutils::ReplaceTerminatorWithBr(pred, join);
  pred->RemoveSuccessor(update_block);
  update_block->RemovePredecessor(pred);
  passutils::RemoveUnreachableBlocks(function);
  return true;
}
|
||||
|
||||
// Repeatedly applies TryConvertConditionalAccumulation to the reachable
// blocks of `function` until a full sweep converts nothing. The block list is
// recollected after every successful conversion because a conversion removes
// blocks. Returns true when at least one conversion happened.
bool RunIfConversionOnFunction(Function& function) {
  if (function.IsExternal()) {
    return false;
  }

  bool changed = false;
  for (;;) {
    bool converted = false;
    auto reachable = passutils::CollectReachableBlocks(function);
    for (auto* candidate : reachable) {
      if (TryConvertConditionalAccumulation(function, candidate)) {
        converted = true;
        break;  // `reachable` may now be stale; restart the sweep.
      }
    }
    if (!converted) {
      return changed;
    }
    changed = true;
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: runs if-conversion on every non-null function of
// `module`. Returns true if any function changed.
bool RunIfConversion(Module& module) {
  bool any_change = false;
  for (const auto& fn : module.GetFunctions()) {
    if (fn == nullptr) {
      continue;
    }
    if (RunIfConversionOnFunction(*fn)) {
      any_change = true;
    }
  }
  return any_change;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,145 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// True when `value` is an integer, i1, or float constant.
bool IsScalarConstant(Value* value) {
  if (dyncast<ConstantInt>(value) != nullptr) {
    return true;
  }
  if (dyncast<ConstantI1>(value) != nullptr) {
    return true;
  }
  return dyncast<ConstantFloat>(value) != nullptr;
}
|
||||
|
||||
bool IsScalarType(const std::shared_ptr<Type>& type) {
|
||||
return type && (type->IsInt32() || type->IsInt1() || type->IsFloat());
|
||||
}
|
||||
|
||||
// True when `global` has a scalar object type, a scalar constant initializer,
// and every one of its uses is a load that reads directly from it. A global
// without any uses also qualifies.
bool IsReadonlyScalarGlobal(GlobalValue* global) {
  if (global == nullptr) {
    return false;
  }
  if (!IsScalarType(global->GetObjectType())) {
    return false;
  }
  if (!IsScalarConstant(global->GetInitializer())) {
    return false;
  }

  for (const auto& use : global->GetUses()) {
    auto* load = dyncast<LoadInst>(dyncast<Instruction>(use.GetUser()));
    const bool reads_global = load != nullptr && load->GetPtr() == global;
    if (!reads_global) {
      return false;  // Any non-load user disqualifies the global.
    }
  }
  return true;
}
|
||||
|
||||
// For every global accepted by IsReadonlyScalarGlobal, replaces the uses of
// each load from it with the global's initializer, then erases the loads that
// ended up without uses. Returns true when at least one load was replaced.
bool PropagateReadonlyScalarGlobals(Module& module) {
  std::vector<LoadInst*> replaced_loads;

  for (const auto& global_ptr : module.GetGlobalValues()) {
    auto* global = global_ptr.get();
    if (!IsReadonlyScalarGlobal(global)) {
      continue;
    }

    // Work on a copy of the use list rather than the live one.
    const auto uses_snapshot = global->GetUses();
    for (const auto& use : uses_snapshot) {
      auto* load = dyncast<LoadInst>(use.GetUser());
      if (load != nullptr && load->GetPtr() == global) {
        load->ReplaceAllUsesWith(global->GetInitializer());
        replaced_loads.push_back(load);
      }
    }
  }

  for (auto* load : replaced_loads) {
    const bool erasable = load != nullptr && load->GetParent() != nullptr &&
                          load->GetUses().empty();
    if (erasable) {
      load->GetParent()->EraseInstruction(load);
    }
  }
  // Every recorded load corresponds to one replacement.
  return !replaced_loads.empty();
}
|
||||
|
||||
// Gathers the call instructions that use `function` as operand 0 with
// `function` as callee. `*all_uses_are_calls` is set to true when every use
// has that form; otherwise it is set to false and an empty vector is
// returned. `all_uses_are_calls` must be non-null.
std::vector<CallInst*> CollectDirectCalls(Function& function, bool* all_uses_are_calls) {
  *all_uses_are_calls = true;
  std::vector<CallInst*> direct;
  for (const auto& use : function.GetUses()) {
    // Anything other than operand slot 0 is not treated as a direct call.
    auto* call = use.GetOperandIndex() == 0 ? dyncast<CallInst>(use.GetUser())
                                            : nullptr;
    const bool is_direct = call != nullptr && call->GetCallee() == &function;
    if (!is_direct) {
      *all_uses_are_calls = false;
      return {};
    }
    direct.push_back(call);
  }
  return direct;
}
|
||||
|
||||
// For a non-external function other than "main" whose every use is a direct
// call: each used scalar argument that receives an equivalent scalar constant
// at all call sites has its uses replaced by that constant. Returns true
// when at least one argument was replaced.
bool PropagateConstantArguments(Function& function) {
  if (function.IsExternal()) {
    return false;
  }
  if (function.GetName() == "main" || function.GetArguments().empty()) {
    return false;
  }

  bool all_uses_are_calls = false;
  auto calls = CollectDirectCalls(function, &all_uses_are_calls);
  if (!all_uses_are_calls || calls.empty()) {
    return false;
  }

  // Returns the constant passed at position `index` by every call site, or
  // nullptr when some site passes a non-constant, a different constant, or
  // has too few arguments.
  const auto common_constant = [&calls](std::size_t index) -> Value* {
    Value* shared = nullptr;
    for (auto* call : calls) {
      const auto args = call->GetArguments();
      if (index >= args.size() || !IsScalarConstant(args[index])) {
        return nullptr;
      }
      if (shared == nullptr) {
        shared = args[index];
      } else if (!passutils::AreEquivalentValues(shared, args[index])) {
        return nullptr;
      }
    }
    return shared;
  };

  bool changed = false;
  for (std::size_t index = 0; index < function.GetArguments().size(); ++index) {
    auto* argument = function.GetArgument(index);
    const bool candidate = argument != nullptr &&
                           IsScalarType(argument->GetType()) &&
                           !argument->GetUses().empty();
    if (!candidate) {
      continue;
    }
    Value* constant = common_constant(index);
    if (constant == nullptr) {
      continue;
    }
    argument->ReplaceAllUsesWith(constant);
    changed = true;
  }
  return changed;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: first propagates read-only scalar globals, then constant
// call arguments for every non-null function. Returns true if anything
// changed.
bool RunInterproceduralConstProp(Module& module) {
  bool any_change = PropagateReadonlyScalarGlobals(module);
  for (const auto& fn : module.GetFunctions()) {
    if (fn == nullptr) {
      continue;
    }
    if (PropagateConstantArguments(*fn)) {
      any_change = true;
    }
  }
  return any_change;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,264 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "ir/passes/LoopPassUtils.h"
|
||||
|
||||
#include <queue>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// True when `value` is an integer constant equal to `expected`.
bool IsConstInt(Value* value, int expected) {
  auto* as_int = dyncast<ConstantInt>(value);
  if (as_int == nullptr) {
    return false;
  }
  return as_int->GetValue() == expected;
}
|
||||
|
||||
// True when `value` is an Add instruction computing `base + 1` or `1 + base`.
bool IsAddOneOf(Value* value, Value* base) {
  auto* add = dyncast<BinaryInst>(value);
  if (add == nullptr) {
    return false;
  }
  if (add->GetOpcode() != Opcode::Add) {
    return false;
  }
  if (add->GetLhs() == base && IsConstInt(add->GetRhs(), 1)) {
    return true;
  }
  return add->GetRhs() == base && IsConstInt(add->GetLhs(), 1);
}
|
||||
|
||||
bool HasForbiddenSideEffects(const Loop& loop) {
|
||||
for (auto* block : loop.block_list) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Store:
|
||||
case Opcode::Memset:
|
||||
case Opcode::Call:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// True when `value` has a user that is not an instruction, or an instruction
// whose parent block is not contained in `loop`.
bool HasUseOutsideLoop(Value* value, const Loop& loop) {
  for (const auto& use : value->GetUses()) {
    auto* user = dyncast<Instruction>(use.GetUser());
    if (user == nullptr) {
      return true;
    }
    if (!loop.Contains(user->GetParent())) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// True when, inside `loop`, `induction` is used only by `compare` and `next`.
// Users located outside the loop are tolerated; a user that is not an
// instruction makes the check fail.
bool InductionOnlyControlsRepeatCount(PhiInst* induction, BinaryInst* compare,
                                      BinaryInst* next, const Loop& loop) {
  for (const auto& use : induction->GetUses()) {
    auto* user = dyncast<Instruction>(use.GetUser());
    if (user == nullptr) {
      return false;
    }
    const bool is_loop_control = user == compare || user == next;
    if (!is_loop_control && loop.Contains(user->GetParent())) {
      return false;
    }
  }
  return true;
}
|
||||
|
||||
// Decides whether header phi `accumulator` is a pure additive accumulator:
//   * it is int32 and enters the loop from `preheader` with the constant 0;
//   * its latch value differs from the phi itself and is derived from it;
//   * inside the loop, values derived from it are consumed only by phis and
//     Add instructions;
//   * every such Add has exactly one derived operand.
bool IsAdditiveAccumulator(PhiInst* accumulator, BasicBlock* preheader,
                           BasicBlock* latch, const Loop& loop) {
  if (accumulator == nullptr || !accumulator->IsInt32()) {
    return false;
  }
  const int entry_slot = looputils::GetPhiIncomingIndex(accumulator, preheader);
  const int backedge_slot = looputils::GetPhiIncomingIndex(accumulator, latch);
  if (entry_slot < 0 || backedge_slot < 0) {
    return false;
  }
  if (!IsConstInt(accumulator->GetIncomingValue(entry_slot), 0)) {
    return false;
  }

  auto* latch_value = accumulator->GetIncomingValue(backedge_slot);
  if (latch_value == accumulator) {
    return false;
  }

  // Breadth-first closure over the in-loop users of the accumulator.
  // `pending` doubles as the FIFO: entries before `head` are already expanded.
  std::unordered_set<Value*> derived;
  std::vector<Value*> pending;
  std::vector<BinaryInst*> additive_steps;
  const auto remember = [&derived, &pending](Value* value) {
    if (derived.insert(value).second) {
      pending.push_back(value);
    }
  };
  remember(accumulator);

  for (std::size_t head = 0; head < pending.size(); ++head) {
    Value* current = pending[head];
    for (const auto& use : current->GetUses()) {
      auto* user = dyncast<Instruction>(use.GetUser());
      if (user == nullptr || !loop.Contains(user->GetParent())) {
        continue;  // Users outside the loop are not inspected here.
      }

      if (auto* phi = dyncast<PhiInst>(user)) {
        remember(phi);
        continue;
      }

      auto* binary = dyncast<BinaryInst>(user);
      if (binary == nullptr || binary->GetOpcode() != Opcode::Add) {
        return false;  // Anything but phi/Add breaks the additive shape.
      }
      additive_steps.push_back(binary);
      remember(binary);
    }
  }

  if (derived.count(latch_value) == 0) {
    return false;
  }

  // Each Add must combine one derived operand with one non-derived operand.
  for (auto* add : additive_steps) {
    const bool lhs_derived = derived.count(add->GetLhs()) != 0;
    const bool rhs_derived = derived.count(add->GetRhs()) != 0;
    if (lhs_derived == rhs_derived) {
      return false;
    }
  }
  return true;
}
|
||||
|
||||
// Collects the leading run of phi instructions of `header`, in order.
// Returns an empty vector for a null header.
std::vector<PhiInst*> GetHeaderPhis(BasicBlock* header) {
  std::vector<PhiInst*> leading_phis;
  if (header != nullptr) {
    for (const auto& slot : header->GetInstructions()) {
      auto* phi = dyncast<PhiInst>(slot.get());
      if (phi == nullptr) {
        break;  // Collection stops at the first non-phi instruction.
      }
      leading_phis.push_back(phi);
    }
  }
  return leading_phis;
}
|
||||
|
||||
// Recognizes a counted loop of the shape
//   for (i = 0; i < bound; i = i + 1) { acc = acc + <same amount>; ... }
// that has no stores, memsets or calls, and whose only header phis are the
// induction variable and additive accumulators starting at 0. The loop is
// forced to run its body at most once, and every accumulator that is used
// outside the loop is replaced there by `accumulator * bound`, computed in
// the exit block. Returns true when the loop was rewritten.
//
// The rewrite is refused (before anything is mutated) when
//   * an accumulator is consumed by a phi located in the exit block: the
//     scaled value is inserted after the exit block's phis, so it cannot
//     legally feed one of them;
//   * any loop value other than a header phi is used outside the loop: such
//     a value would keep its single-iteration result instead of being scaled.
bool TryReduceRepeatLoop(Function& function, Loop& loop) {
  if (!loop.header || !loop.preheader || loop.latches.size() != 1 ||
      loop.exit_blocks.size() != 1 || HasForbiddenSideEffects(loop)) {
    return false;
  }

  auto* latch = loop.latches.front();
  auto* exit = loop.exit_blocks.front();
  auto* branch =
      dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
  if (!branch) {
    return false;
  }

  // The header must test `induction < bound`, staying in the loop when true
  // and leaving to the single exit block when false.
  auto* compare = dyncast<BinaryInst>(branch->GetCondition());
  if (!compare || compare->GetOpcode() != Opcode::ICmpLT) {
    return false;
  }
  if (!loop.Contains(branch->GetThenBlock()) || branch->GetElseBlock() != exit) {
    return false;
  }

  auto* induction = dyncast<PhiInst>(compare->GetLhs());
  auto* bound = compare->GetRhs();
  if (!induction || induction->GetParent() != loop.header ||
      !looputils::IsLoopInvariantValue(loop, bound)) {
    return false;
  }

  // The induction variable must start at 0 and advance by exactly 1.
  const int induction_preheader_index =
      looputils::GetPhiIncomingIndex(induction, loop.preheader);
  const int induction_latch_index = looputils::GetPhiIncomingIndex(induction, latch);
  if (induction_preheader_index < 0 || induction_latch_index < 0 ||
      !IsConstInt(induction->GetIncomingValue(induction_preheader_index), 0)) {
    return false;
  }

  auto* induction_next =
      dyncast<BinaryInst>(induction->GetIncomingValue(induction_latch_index));
  if (!IsAddOneOf(induction_next, induction) ||
      !InductionOnlyControlsRepeatCount(induction, compare, induction_next, loop)) {
    return false;
  }

  // Every other header phi must be an additive accumulator; only those that
  // are observed outside the loop need scaling.
  std::vector<PhiInst*> accumulators;
  for (auto* phi : GetHeaderPhis(loop.header)) {
    if (phi == induction) {
      continue;
    }
    if (!IsAdditiveAccumulator(phi, loop.preheader, latch, loop)) {
      return false;
    }
    if (HasUseOutsideLoop(phi, loop)) {
      accumulators.push_back(phi);
    }
  }
  if (accumulators.empty()) {
    return false;
  }

  // Safety check 1: the scaled value cannot feed a phi of the exit block.
  for (auto* accumulator : accumulators) {
    for (const auto& use : accumulator->GetUses()) {
      auto* user_phi = dyncast<PhiInst>(use.GetUser());
      if (user_phi != nullptr && user_phi->GetParent() == exit) {
        return false;
      }
    }
  }

  // Safety check 2: no loop value other than a header phi may escape.
  for (auto* block : loop.block_list) {
    for (const auto& inst_ptr : block->GetInstructions()) {
      auto* inst = inst_ptr.get();
      if (inst == nullptr) {
        continue;
      }
      const bool is_header_phi =
          block == loop.header && dyncast<PhiInst>(inst) != nullptr;
      if (!is_header_phi && HasUseOutsideLoop(inst, loop)) {
        return false;
      }
    }
  }

  // Force the counted loop to stop after one executed body: the first test still
  // uses 0 < bound, so non-positive trip counts continue to execute zero times.
  // The operand index 2 * k addresses the value of incoming pair k, as in the
  // original code.
  induction->SetOperand(static_cast<std::size_t>(2 * induction_latch_index), bound);

  std::size_t insert_index = looputils::GetFirstNonPhiIndex(exit);
  for (auto* accumulator : accumulators) {
    auto* scaled = exit->Insert<BinaryInst>(
        insert_index++, Opcode::Mul, Type::GetInt32Type(), accumulator, bound,
        nullptr, looputils::NextSyntheticName(function, "repeat.reduce"));

    // Iterate over a copy: SetOperand below edits the accumulator's use list.
    const auto uses = accumulator->GetUses();
    for (const auto& use : uses) {
      auto* user = use.GetUser();
      auto* user_inst = dyncast<Instruction>(user);
      if (user_inst == scaled) {
        continue;  // The multiplication itself must keep the raw accumulator.
      }
      if (!user_inst || !loop.Contains(user_inst->GetParent())) {
        user->SetOperand(use.GetOperandIndex(), scaled);
      }
    }
  }

  return true;
}
|
||||
|
||||
// Builds dominator and loop information for `function` and tries the repeat
// reduction on each loop in post order. Returns true if any loop changed.
bool RunOnFunction(Function& function) {
  DominatorTree dom_tree(function);
  LoopInfo loop_info(function, dom_tree);

  bool reduced_any = false;
  for (auto* loop : loop_info.GetLoopsInPostOrder()) {
    if (TryReduceRepeatLoop(function, *loop)) {
      reduced_any = true;
    }
  }
  return reduced_any;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: runs the repeat-loop reduction on every non-null,
// non-external function of `module`. Returns true if anything changed.
bool RunLoopRepeatReduction(Module& module) {
  bool any_change = false;
  for (const auto& fn : module.GetFunctions()) {
    if (!fn || fn->IsExternal()) {
      continue;
    }
    if (RunOnFunction(*fn)) {
      any_change = true;
    }
  }
  return any_change;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,375 @@
|
||||
#pragma once
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstddef>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace ir {
|
||||
namespace mathidiom {
|
||||
|
||||
// True when `value` is a float constant comparing equal to `expected`.
inline bool IsFloatConstant(Value* value, float expected) {
  auto* as_float = dyncast<ConstantFloat>(value);
  if (as_float == nullptr) {
    return false;
  }
  return as_float->GetValue() == expected;
}
|
||||
|
||||
// True when `value` is the float constant `expected`, or an IToF conversion
// of an integer constant whose float conversion equals `expected`.
inline bool IsFloatValue(Value* value, float expected) {
  if (IsFloatConstant(value, expected)) {
    return true;
  }

  auto* conversion = dyncast<UnaryInst>(value);
  const bool is_int_to_float =
      conversion != nullptr && conversion->GetOpcode() == Opcode::IToF;
  if (!is_int_to_float) {
    return false;
  }

  auto* source = dyncast<ConstantInt>(conversion->GetOprd());
  if (source == nullptr) {
    return false;
  }
  return static_cast<float>(source->GetValue()) == expected;
}
|
||||
|
||||
// Returns the function owning the block that contains `inst`, or nullptr when
// `inst` is null or is not attached to a block.
inline Function* ParentFunction(const Instruction* inst) {
  if (inst == nullptr) {
    return nullptr;
  }
  auto* block = inst->GetParent();
  if (block == nullptr) {
    return nullptr;
  }
  return block->GetParent();
}
|
||||
|
||||
inline bool IsGlobalOnlyUsedByFunction(const GlobalValue* global,
|
||||
const Function& function) {
|
||||
if (global == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto& use : global->GetUses()) {
|
||||
auto* inst = dyncast<Instruction>(use.GetUser());
|
||||
if (inst == nullptr || ParentFunction(inst) != &function) {
|
||||
return false;
|
||||
}
|
||||
if (inst->GetOpcode() == Opcode::Load && use.GetOperandIndex() == 0) {
|
||||
continue;
|
||||
}
|
||||
if (inst->GetOpcode() == Opcode::Store && use.GetOperandIndex() == 1) {
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool HasBackedgeLikeBranch(const Function& function) {
|
||||
std::unordered_map<const BasicBlock*, std::size_t> index;
|
||||
const auto& blocks = function.GetBlocks();
|
||||
for (std::size_t i = 0; i < blocks.size(); ++i) {
|
||||
index[blocks[i].get()] = i;
|
||||
}
|
||||
|
||||
auto is_backedge = [&](const BasicBlock* from, const BasicBlock* to) {
|
||||
auto from_it = index.find(from);
|
||||
auto to_it = index.find(to);
|
||||
return from_it != index.end() && to_it != index.end() &&
|
||||
to_it->second <= from_it->second;
|
||||
};
|
||||
|
||||
for (std::size_t i = 0; i < blocks.size(); ++i) {
|
||||
const auto& instructions = blocks[i]->GetInstructions();
|
||||
if (instructions.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto* terminator = instructions.back().get();
|
||||
if (auto* br = dyncast<UncondBrInst>(terminator)) {
|
||||
if (is_backedge(blocks[i].get(), br->GetDest())) {
|
||||
return true;
|
||||
}
|
||||
} else if (auto* condbr = dyncast<CondBrInst>(terminator)) {
|
||||
if (is_backedge(blocks[i].get(), condbr->GetThenBlock()) ||
|
||||
is_backedge(blocks[i].get(), condbr->GetElseBlock())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns true when `value` is a strictly positive power of two (1, 2, 4, ...).
// constexpr (and therefore implicitly inline) so the predicate can also be
// evaluated in constant expressions.
constexpr bool IsPowerOfTwoPositive(int value) noexcept {
  // A positive power of two has exactly one bit set, so clearing its lowest
  // set bit must leave zero.
  return value > 0 && (value & (value - 1)) == 0;
}
|
||||
|
||||
// Returns log2(value) for a positive power of two. For other inputs the
// result is the number of halvings needed to bring `value` to 1 or below:
// floor(log2(value)) for value >= 1 and 0 for value <= 1.
// constexpr (and therefore implicitly inline) so it can also be evaluated in
// constant expressions.
constexpr int Log2Exact(int value) noexcept {
  int shift = 0;
  for (; value > 1; value >>= 1) {
    ++shift;
  }
  return shift;
}
|
||||
|
||||
// Depth-limited search through the operands of `value` looking for `needle`.
// `depth` bounds how many operand levels may be followed. `visiting` records
// every value already expanded so it is never expanded twice; entries are
// never removed.
inline bool DependsOnValueImpl(Value* value, Value* needle, int depth,
                               std::unordered_set<Value*>& visiting) {
  if (value == needle) {
    return true;
  }
  if (value == nullptr || depth <= 0) {
    return false;
  }
  if (!visiting.insert(value).second) {
    return false;  // Already expanded earlier in this search.
  }

  auto* inst = dyncast<Instruction>(value);
  if (inst == nullptr) {
    return false;  // Non-instructions have no operands to follow.
  }

  const std::size_t operand_count = inst->GetNumOperands();
  for (std::size_t slot = 0; slot < operand_count; ++slot) {
    if (DependsOnValueImpl(inst->GetOperand(slot), needle, depth - 1, visiting)) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// True when `needle` is reachable from `value` by following instruction
// operands for at most `depth` levels (12 by default).
inline bool DependsOnValue(Value* value, Value* needle, int depth = 12) {
  std::unordered_set<Value*> expanded;
  const bool reachable = DependsOnValueImpl(value, needle, depth, expanded);
  return reachable;
}
|
||||
|
||||
// Recognize the radix-digit helper:
|
||||
// while (i < pos) num = num / C;
|
||||
// return num % C;
|
||||
// for power-of-two C >= 4. Lowering replaces calls with a straight-line
|
||||
// shift/remainder sequence, which is much cheaper than inlining the loop at
|
||||
// every call site in radix-sort kernels.
|
||||
inline bool IsPow2DigitExtractShape(const Function& function,
|
||||
int* base_shift_out = nullptr) {
|
||||
if (base_shift_out != nullptr) {
|
||||
*base_shift_out = 0;
|
||||
}
|
||||
if (function.IsExternal() || function.GetReturnType() == nullptr ||
|
||||
!function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 ||
|
||||
!function.GetArgument(0)->GetType()->IsInt32() ||
|
||||
!function.GetArgument(1)->GetType()->IsInt32() ||
|
||||
!HasBackedgeLikeBranch(function)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* num_arg = function.GetArgument(0);
|
||||
auto* pos_arg = function.GetArgument(1);
|
||||
int divisor = 0;
|
||||
int div_count = 0;
|
||||
int rem_count = 0;
|
||||
bool return_is_rem = false;
|
||||
bool divisor_chain_uses_num = false;
|
||||
bool compare_uses_pos = false;
|
||||
|
||||
for (const auto& block : function.GetBlocks()) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dyncast<CallInst>(inst) || dyncast<LoadInst>(inst) ||
|
||||
dyncast<StoreInst>(inst) || dyncast<AllocaInst>(inst) ||
|
||||
dyncast<GetElementPtrInst>(inst) || dyncast<MemsetInst>(inst) ||
|
||||
dyncast<UnreachableInst>(inst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto* ret = dyncast<ReturnInst>(inst)) {
|
||||
auto* returned = ret->HasReturnValue() ? ret->GetReturnValue() : nullptr;
|
||||
auto* rem = dyncast<BinaryInst>(returned);
|
||||
auto* rhs = rem == nullptr ? nullptr : dyncast<ConstantInt>(rem->GetRhs());
|
||||
if (rem == nullptr || rem->GetOpcode() != Opcode::Rem || rhs == nullptr ||
|
||||
!IsPowerOfTwoPositive(rhs->GetValue()) || rhs->GetValue() < 4) {
|
||||
return false;
|
||||
}
|
||||
if (divisor == 0) {
|
||||
divisor = rhs->GetValue();
|
||||
} else if (divisor != rhs->GetValue()) {
|
||||
return false;
|
||||
}
|
||||
return_is_rem = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* bin = dyncast<BinaryInst>(inst);
|
||||
if (!bin) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (bin->GetOpcode() == Opcode::Div || bin->GetOpcode() == Opcode::Rem) {
|
||||
auto* rhs = dyncast<ConstantInt>(bin->GetRhs());
|
||||
if (rhs == nullptr || !IsPowerOfTwoPositive(rhs->GetValue()) ||
|
||||
rhs->GetValue() < 4) {
|
||||
return false;
|
||||
}
|
||||
if (divisor == 0) {
|
||||
divisor = rhs->GetValue();
|
||||
} else if (divisor != rhs->GetValue()) {
|
||||
return false;
|
||||
}
|
||||
if (bin->GetOpcode() == Opcode::Div) {
|
||||
++div_count;
|
||||
} else {
|
||||
++rem_count;
|
||||
}
|
||||
divisor_chain_uses_num |= DependsOnValue(bin->GetLhs(), num_arg);
|
||||
}
|
||||
|
||||
switch (bin->GetOpcode()) {
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
compare_uses_pos |= DependsOnValue(bin->GetLhs(), pos_arg) ||
|
||||
DependsOnValue(bin->GetRhs(), pos_arg);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (divisor == 0 || div_count == 0 || rem_count == 0 || !return_is_rem ||
|
||||
!divisor_chain_uses_num || !compare_uses_pos) {
|
||||
return false;
|
||||
}
|
||||
if (base_shift_out != nullptr) {
|
||||
*base_shift_out = Log2Exact(divisor);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Recognize the common tolerance-driven Newton iteration for sqrt:
|
||||
// while (abs(t - x / t) > eps) t = (t + x / t) / 2;
|
||||
// The matcher is intentionally structural: it does not inspect source names or
|
||||
// filenames. Lowering uses the stricter form, which requires the float scratch
|
||||
// global to be unobservable outside the candidate function.
|
||||
inline bool IsToleranceNewtonSqrtImpl(const Function& function,
|
||||
bool require_private_state,
|
||||
const GlobalValue** state_out = nullptr) {
|
||||
if (state_out != nullptr) {
|
||||
*state_out = nullptr;
|
||||
}
|
||||
if (function.IsExternal() || function.GetReturnType() == nullptr ||
|
||||
!function.GetReturnType()->IsFloat() || function.GetArguments().size() != 1 ||
|
||||
!function.GetArguments()[0]->GetType()->IsFloat() ||
|
||||
function.GetBlocks().size() < 3 || function.GetBlocks().size() > 8 ||
|
||||
!HasBackedgeLikeBranch(function)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* input = function.GetArguments()[0].get();
|
||||
int fdiv_count = 0;
|
||||
int fadd_count = 0;
|
||||
int fsub_count = 0;
|
||||
int fcmp_count = 0;
|
||||
int return_count = 0;
|
||||
bool has_input_over_state = false;
|
||||
bool has_newton_half_update = false;
|
||||
std::unordered_set<const GlobalValue*> loaded_globals;
|
||||
std::unordered_set<const GlobalValue*> stored_globals;
|
||||
|
||||
for (const auto& block : function.GetBlocks()) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::FDiv: {
|
||||
++fdiv_count;
|
||||
auto* binary = static_cast<BinaryInst*>(inst);
|
||||
if (binary->GetLhs() == input) {
|
||||
has_input_over_state = true;
|
||||
}
|
||||
if (IsFloatValue(binary->GetRhs(), 2.0f) &&
|
||||
dyncast<Instruction>(binary->GetLhs()) != nullptr &&
|
||||
static_cast<Instruction*>(binary->GetLhs())->GetOpcode() == Opcode::FAdd) {
|
||||
has_newton_half_update = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Opcode::FAdd:
|
||||
++fadd_count;
|
||||
break;
|
||||
case Opcode::FSub:
|
||||
++fsub_count;
|
||||
break;
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE:
|
||||
++fcmp_count;
|
||||
break;
|
||||
case Opcode::Load: {
|
||||
auto* load = static_cast<LoadInst*>(inst);
|
||||
auto* global = dyncast<GlobalValue>(load->GetPtr());
|
||||
if (global == nullptr || !load->GetType()->IsFloat() ||
|
||||
!global->GetObjectType()->IsFloat()) {
|
||||
return false;
|
||||
}
|
||||
loaded_globals.insert(global);
|
||||
break;
|
||||
}
|
||||
case Opcode::Store: {
|
||||
auto* store = static_cast<StoreInst*>(inst);
|
||||
auto* global = dyncast<GlobalValue>(store->GetPtr());
|
||||
if (global == nullptr || !store->GetValue()->GetType()->IsFloat() ||
|
||||
!global->GetObjectType()->IsFloat()) {
|
||||
return false;
|
||||
}
|
||||
stored_globals.insert(global);
|
||||
break;
|
||||
}
|
||||
case Opcode::Return:
|
||||
++return_count;
|
||||
if (!static_cast<ReturnInst*>(inst)->HasReturnValue() ||
|
||||
!static_cast<ReturnInst*>(inst)->GetReturnValue()->GetType()->IsFloat()) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case Opcode::Call:
|
||||
case Opcode::Alloca:
|
||||
case Opcode::GetElementPtr:
|
||||
case Opcode::Memset:
|
||||
case Opcode::Unreachable:
|
||||
return false;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (fdiv_count < 2 || fadd_count < 1 || fsub_count < 1 || fcmp_count < 1 ||
|
||||
return_count != 1 || !has_input_over_state || !has_newton_half_update) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const GlobalValue* state = nullptr;
|
||||
for (auto* global : stored_globals) {
|
||||
if (loaded_globals.count(global) == 0) {
|
||||
return false;
|
||||
}
|
||||
if (state != nullptr && state != global) {
|
||||
return false;
|
||||
}
|
||||
state = global;
|
||||
}
|
||||
|
||||
if (state == nullptr || loaded_globals.size() != 1 || !state->HasInitializer() ||
|
||||
!IsFloatConstant(state->GetInitializer(), 1.0f)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (require_private_state && !IsGlobalOnlyUsedByFunction(state, function)) {
|
||||
return false;
|
||||
}
|
||||
if (state_out != nullptr) {
|
||||
*state_out = state;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool IsToleranceNewtonSqrtShape(const Function& function) {
|
||||
return IsToleranceNewtonSqrtImpl(function, false);
|
||||
}
|
||||
|
||||
inline bool IsPrivateToleranceNewtonSqrt(const Function& function,
|
||||
const GlobalValue** state_out = nullptr) {
|
||||
return IsToleranceNewtonSqrtImpl(function, true, state_out);
|
||||
}
|
||||
|
||||
} // namespace mathidiom
|
||||
} // namespace ir
|
||||
@ -0,0 +1,249 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct TailCallSite {
|
||||
BasicBlock* block = nullptr;
|
||||
CallInst* call = nullptr;
|
||||
ReturnInst* ret = nullptr;
|
||||
};
|
||||
|
||||
// Returns true when the first instruction of the function's entry block is a
// phi node. Only that first instruction is examined; a missing or empty
// entry block yields false.
bool HasEntryPhi(Function& function) {
  auto* entry = function.GetEntryBlock();
  if (entry == nullptr) {
    return false;
  }
  const auto& instructions = entry->GetInstructions();
  if (instructions.empty()) {
    return false;
  }
  return dyncast<PhiInst>(instructions.front().get()) != nullptr;
}
|
||||
|
||||
// Returns true when `call` has exactly one use and the user of that use is
// `ret`. Null inputs yield false.
bool IsOnlyUsedByReturn(CallInst* call, ReturnInst* ret) {
  if (call == nullptr || ret == nullptr) {
    return false;
  }
  const auto& uses = call->GetUses();
  if (uses.size() != 1) {
    return false;
  }
  return uses.front().GetUser() == ret;
}
|
||||
|
||||
// Checks whether `block` ends in "call function(...); return" with the call
// in tail position. Two forms are accepted:
//   * value form: the return yields the result of the immediately preceding
//     call, that call lives in `block`, targets `function`, and is used by
//     nothing except the return;
//   * void form: the return carries no value and the immediately preceding
//     instruction is a void-typed call to `function` with no uses.
// Returns a populated TailCallSite on a match and an all-null one otherwise.
TailCallSite MatchTailRecursiveCall(Function& function, BasicBlock* block) {
  if (block == nullptr) {
    return {};
  }
  auto& insts = block->GetInstructions();
  const std::size_t count = insts.size();
  if (count < 2) {
    return {};
  }

  auto* ret = dyncast<ReturnInst>(insts[count - 1].get());
  if (ret == nullptr) {
    return {};
  }

  // The instruction directly in front of the return, if it is a call.
  auto* preceding_call = dyncast<CallInst>(insts[count - 2].get());

  if (!ret->HasReturnValue()) {
    // Void form.
    if (preceding_call == nullptr) {
      return {};
    }
    if (preceding_call->GetCallee() != &function) {
      return {};
    }
    if (!preceding_call->GetType()->IsVoid() ||
        !preceding_call->GetUses().empty()) {
      return {};
    }
    return {block, preceding_call, ret};
  }

  // Value form.
  auto* returned_call = dyncast<CallInst>(ret->GetReturnValue());
  if (returned_call == nullptr || returned_call != preceding_call) {
    return {};
  }
  if (returned_call->GetParent() != block ||
      returned_call->GetCallee() != &function) {
    return {};
  }
  if (!IsOnlyUsedByReturn(returned_call, ret)) {
    return {};
  }
  return {block, returned_call, ret};
}
|
||||
|
||||
// Runs MatchTailRecursiveCall over every block of `function` and returns the
// fully populated matches in block order.
std::vector<TailCallSite> CollectTailCallSites(Function& function) {
  std::vector<TailCallSite> found;
  for (const auto& owned_block : function.GetBlocks()) {
    const TailCallSite candidate =
        MatchTailRecursiveCall(function, owned_block.get());
    const bool is_match = candidate.block != nullptr &&
                          candidate.call != nullptr &&
                          candidate.ret != nullptr;
    if (!is_match) {
      continue;
    }
    found.push_back(candidate);
  }
  return found;
}
|
||||
|
||||
// Creates a new block named from the "tailrec.entry" prefix, places it at the
// front of the function's block list, makes it the entry block, and has it
// branch unconditionally to `header`. CFG edges between the two blocks are
// recorded in both directions. Returns the new block.
BasicBlock* InsertPreheader(Function& function, BasicBlock* header) {
  auto owned_entry = std::make_unique<BasicBlock>(
      &function, looputils::NextSyntheticBlockName(function, "tailrec.entry"));
  BasicBlock* const new_entry = owned_entry.get();

  // Ownership moves into the function; `new_entry` stays valid as an alias.
  auto& block_list = function.GetBlocks();
  block_list.insert(block_list.begin(), std::move(owned_entry));
  function.SetEntryBlock(new_entry);

  new_entry->Append<UncondBrInst>(header, nullptr);
  new_entry->AddSuccessor(header);
  header->AddPredecessor(new_entry);

  return new_entry;
}
|
||||
|
||||
// Inserts one phi per function argument at the first non-phi position of
// `header`, seeds each with (argument, preheader), and redirects every use of
// the argument that existed beforehand to the corresponding phi.
//
// The use lists are snapshotted before any phi exists, so the phi's own
// incoming use of the argument is not rewritten onto itself.
// Returns the phis in argument order.
std::vector<PhiInst*> CreateArgumentPhis(Function& function, BasicBlock* header,
                                         BasicBlock* preheader) {
  const auto& args = function.GetArguments();
  const std::size_t arg_count = args.size();

  // Pass 1: snapshot the pre-existing uses of every argument.
  std::vector<std::vector<Use>> uses_before;
  uses_before.reserve(arg_count);
  for (std::size_t i = 0; i < arg_count; ++i) {
    uses_before.push_back(args[i]->GetUses());
  }

  // Pass 2: create the phis, packed consecutively from the first free slot.
  std::vector<PhiInst*> arg_phis;
  arg_phis.reserve(arg_count);
  const std::size_t first_slot = looputils::GetFirstNonPhiIndex(header);
  for (std::size_t i = 0; i < arg_count; ++i) {
    PhiInst* const phi = header->Insert<PhiInst>(
        first_slot + i, args[i]->GetType(), nullptr,
        looputils::NextSyntheticName(function, "tailrec.arg."));
    phi->AddIncoming(args[i].get(), preheader);
    arg_phis.push_back(phi);
  }

  // Pass 3: point the snapshotted uses at the new phis.
  for (std::size_t i = 0; i < arg_count; ++i) {
    for (const auto& old_use : uses_before[i]) {
      auto* user = old_use.GetUser();
      if (user == nullptr) {
        continue;
      }
      user->SetOperand(old_use.GetOperandIndex(), arg_phis[i]);
    }
  }

  return arg_phis;
}
|
||||
|
||||
// Overwrites the last instruction of `block` with an unconditional branch to
// `dest` and records the new CFG edge on both blocks. The old last
// instruction has its operands cleared before it is replaced.
// `block` must contain at least one instruction.
void ReplaceTerminatorWithBranch(BasicBlock* block, BasicBlock* dest) {
  auto& terminator_slot = block->GetInstructions().back();
  terminator_slot->ClearAllOperands();

  auto jump = std::make_unique<UncondBrInst>(dest, nullptr);
  jump->SetParent(block);
  terminator_slot = std::move(jump);

  block->AddSuccessor(dest);
  dest->AddPredecessor(block);
}
|
||||
|
||||
void RewriteTailCallSite(const TailCallSite& site, BasicBlock* header,
|
||||
const std::vector<PhiInst*>& arg_phis) {
|
||||
for (std::size_t i = 0; i < arg_phis.size(); ++i) {
|
||||
arg_phis[i]->AddIncoming(site.call->GetOperand(i + 1), site.block);
|
||||
}
|
||||
|
||||
ReplaceTerminatorWithBranch(site.block, header);
|
||||
site.block->EraseInstruction(site.call);
|
||||
}
|
||||
|
||||
// Depth-first search over the direct-call graph: returns true if `root` is
// reachable from `current` through one or more call edges.
//
// `visiting` accumulates every function already expanded during this search;
// entries are never removed, so each function is expanded at most once.
// Null inputs, external functions, and functions absent from
// `direct_callees` contribute no edges.
bool ReachesFunction(
    Function* root, Function* current,
    const std::unordered_map<Function*, std::vector<Function*>>& direct_callees,
    std::unordered_set<Function*>& visiting) {
  if (root == nullptr || current == nullptr || current->IsExternal()) {
    return false;
  }

  const bool first_visit = visiting.insert(current).second;
  if (!first_visit) {
    return false;
  }

  const auto entry = direct_callees.find(current);
  if (entry == direct_callees.end()) {
    return false;
  }

  const auto& callees = entry->second;
  return std::any_of(callees.begin(), callees.end(), [&](Function* callee) {
    return callee == root ||
           ReachesFunction(root, callee, direct_callees, visiting);
  });
}
|
||||
|
||||
// Rebuilds the "is recursive" bit of every non-external function in `module`.
//
// Step 1 collects, per function, the distinct non-external functions it
// calls directly. Step 2 marks a function recursive when it can reach itself
// through that graph, re-submitting all other effect flags unchanged.
void RecomputeRecursiveFlags(Module& module) {
  std::unordered_map<Function*, std::vector<Function*>> call_graph;

  // Step 1: direct-callee lists, deduplicated, in first-seen order.
  for (const auto& owned_function : module.GetFunctions()) {
    Function* const caller = owned_function.get();
    if (caller == nullptr || caller->IsExternal()) {
      continue;
    }
    // operator[] also creates an empty list for functions that call nothing.
    auto& targets = call_graph[caller];
    for (const auto& owned_block : caller->GetBlocks()) {
      for (const auto& owned_inst : owned_block->GetInstructions()) {
        auto* call = dyncast<CallInst>(owned_inst.get());
        if (call == nullptr) {
          continue;
        }
        auto* target = call->GetCallee();
        if (target == nullptr || target->IsExternal()) {
          continue;
        }
        const bool already_listed =
            std::find(targets.begin(), targets.end(), target) != targets.end();
        if (!already_listed) {
          targets.push_back(target);
        }
      }
    }
  }

  // Step 2: self-reachability decides the recursive flag.
  for (const auto& owned_function : module.GetFunctions()) {
    Function* const candidate = owned_function.get();
    if (candidate == nullptr || candidate->IsExternal()) {
      continue;
    }
    std::unordered_set<Function*> expanded;
    const bool reaches_itself =
        ReachesFunction(candidate, candidate, call_graph, expanded);
    candidate->SetEffectInfo(
        candidate->ReadsGlobalMemory(), candidate->WritesGlobalMemory(),
        candidate->ReadsParamMemory(), candidate->WritesParamMemory(),
        candidate->HasIO(), candidate->HasUnknownEffects(), reaches_itself);
  }
}
|
||||
|
||||
// Applies tail-recursion elimination to a single function.
//
// Skips external functions, functions without an entry block, and functions
// whose entry block starts with a phi. Otherwise every matched tail call is
// rewritten into a branch back to the original entry block, which gets a
// fresh preheader and one phi per argument.
// Returns true if the function was modified.
bool RunOnFunction(Function& function) {
  if (function.IsExternal()) {
    return false;
  }
  if (function.GetEntryBlock() == nullptr || HasEntryPhi(function)) {
    return false;
  }

  const std::vector<TailCallSite> tail_calls = CollectTailCallSites(function);
  if (tail_calls.empty()) {
    return false;
  }

  // The old entry block becomes the target that tail calls jump back to.
  auto* loop_header = function.GetEntryBlock();
  auto* new_entry = InsertPreheader(function, loop_header);
  const std::vector<PhiInst*> phis =
      CreateArgumentPhis(function, loop_header, new_entry);

  for (const TailCallSite& tail_call : tail_calls) {
    RewriteTailCallSite(tail_call, loop_header, phis);
  }
  return true;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Module-level entry point: runs tail-recursion elimination on every
// function and, if anything changed, recomputes the per-function recursive
// flags. Returns true if at least one function was rewritten.
bool RunTailRecursionElim(Module& module) {
  bool any_rewritten = false;
  for (const auto& owned_function : module.GetFunctions()) {
    if (!owned_function) {
      continue;
    }
    if (RunOnFunction(*owned_function)) {
      any_rewritten = true;
    }
  }

  if (any_rewritten) {
    RecomputeRecursiveFlags(module);
  }
  return any_rewritten;
}
|
||||
|
||||
} // namespace ir
|
||||
Loading…
Reference in new issue