forked from p4jyxwm3q/nudt-compiler-cpp
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
240 lines
6.6 KiB
240 lines
6.6 KiB
#include "ir/PassManager.h"
|
|
|
|
#include "ir/IR.h"
|
|
#include "PassUtils.h"
|
|
|
|
#include <cstddef>
|
|
#include <vector>
|
|
|
|
namespace ir {
|
|
namespace {
|
|
|
|
// Returns the terminator of `block`, meaning its last instruction when that
// instruction reports IsTerminator(). Yields nullptr for a null block, an
// empty block, a null trailing slot, or a non-terminator trailing instruction.
Instruction* GetTerminator(BasicBlock* block) {
  if (block == nullptr) {
    return nullptr;
  }
  const auto& instructions = block->GetInstructions();
  if (instructions.empty()) {
    return nullptr;
  }
  auto* last = instructions.back().get();
  if (last == nullptr || !last->IsTerminator()) {
    return nullptr;
  }
  return last;
}
|
|
|
|
// Returns the index of the last instruction of `block`, which is where the
// terminator sits when the block has one, or 0 for an empty block.
// `block` must be non-null; callers only use this after GetTerminator()
// succeeded.
std::size_t GetTerminatorIndex(BasicBlock* block) {
  const std::size_t count = block->GetInstructions().size();
  if (count == 0) {
    return 0;
  }
  return count - 1;
}
|
|
|
|
// Allocates a new 32-bit integer constant holding `value`.
// NOTE(review): the constant is created with `new` and returned as a raw
// pointer. Ownership presumably passes to the IR once it becomes an operand;
// confirm against the ir::Value lifetime model.
ConstantInt* ConstInt(int value) {
  auto* constant = new ConstantInt(Type::GetInt32Type(), value);
  return constant;
}
|
|
|
|
// Returns the unique phi node at the head of `block`.
// Only the leading run of phi instructions is inspected; the scan stops at
// the first non-phi. Returns nullptr when `block` is null, when it starts
// without a phi, or when it has two or more leading phis.
PhiInst* GetSinglePhi(BasicBlock* block) {
  if (block == nullptr) {
    return nullptr;
  }
  PhiInst* found = nullptr;
  std::size_t phi_count = 0;
  for (const auto& entry : block->GetInstructions()) {
    auto* as_phi = dyncast<PhiInst>(entry.get());
    if (as_phi == nullptr) {
      break;
    }
    ++phi_count;
    if (phi_count > 1) {
      return nullptr;  // More than one phi: not a single-phi join.
    }
    found = as_phi;
  }
  return found;
}
|
|
|
|
// Reports whether `block` holds exactly one non-terminator instruction.
// Null instruction slots and terminators are ignored.
//
// @param block  Block to inspect; a null block yields false.
// @param out    Optional. If non-null, it receives the sole candidate (or
//               nullptr when the block has none). It is left untouched when
//               a second non-terminator causes an early false return.
// @return true iff exactly one non-terminator instruction was found.
bool HasOnlyOneNonTerminator(BasicBlock* block, Instruction** out) {
  if (block == nullptr) {
    return false;
  }
  Instruction* sole = nullptr;
  for (const auto& entry : block->GetInstructions()) {
    auto* current = entry.get();
    const bool is_body_inst = current != nullptr && !current->IsTerminator();
    if (!is_body_inst) {
      continue;
    }
    if (sole != nullptr) {
      return false;  // A second non-terminator: not a single-instruction body.
    }
    sole = current;
  }
  if (out != nullptr) {
    *out = sole;
  }
  return sole != nullptr;
}
|
|
|
|
// Returns the index of the first incoming edge of `phi` whose block is
// `block`, or -1 when either argument is null or no such edge exists.
int IncomingIndexFor(PhiInst* phi, BasicBlock* block) {
  if (phi != nullptr && block != nullptr) {
    for (int idx = 0; idx < phi->GetNumIncomings(); ++idx) {
      if (phi->GetIncomingBlock(idx) == block) {
        return idx;
      }
    }
  }
  return -1;
}
|
|
|
|
// Returns true when every recorded use of `value` belongs to `expected_user`.
// Returns false if either pointer is null. Note that a value with no uses
// at all also yields true.
bool IsUsedOnlyBy(Value* value, User* expected_user) {
  if (value == nullptr || expected_user == nullptr) {
    return false;
  }
  bool all_match = true;
  for (const auto& use : value->GetUses()) {
    if (use.GetUser() != expected_user) {
      all_match = false;
      break;
    }
  }
  return all_match;
}
|
|
|
|
struct ConditionalAccumulation {
|
|
Value* base = nullptr;
|
|
Value* delta = nullptr;
|
|
Opcode opcode = Opcode::Add;
|
|
};
|
|
|
|
// Checks whether `phi` merges an i32 value `base` (from `pred`) with
// `update` (from `update_block`), where `update` is one of
// `base + delta`, `delta + base` or `base - delta` and has no user other
// than `phi`.
//
// @param phi           Two-input i32 phi in the join block.
// @param pred          Block supplying the unmodified value.
// @param update_block  Block that computes `update`.
// @param update        The add/sub instruction feeding the phi.
// @param match         Output; filled only on success.
// @return true if the pattern matched and `*match` was written.
bool MatchConditionalAccumulation(PhiInst* phi, BasicBlock* pred,
                                  BasicBlock* update_block,
                                  BinaryInst* update,
                                  ConditionalAccumulation* match) {
  const bool have_inputs = phi != nullptr && pred != nullptr &&
                           update_block != nullptr && update != nullptr &&
                           match != nullptr;
  if (!have_inputs) {
    return false;
  }
  if (phi->GetNumIncomings() != 2 || !phi->GetType()->IsInt32() ||
      !update->GetType()->IsInt32()) {
    return false;
  }

  const int from_pred = IncomingIndexFor(phi, pred);
  const int from_update = IncomingIndexFor(phi, update_block);
  if (from_pred < 0 || from_update < 0) {
    return false;
  }

  auto* base = phi->GetIncomingValue(from_pred);
  const bool shape_ok = phi->GetIncomingValue(from_update) == update &&
                        base != nullptr && base->GetType()->IsInt32() &&
                        IsUsedOnlyBy(update, phi);
  if (!shape_ok) {
    return false;
  }

  const auto is_int32 = [](auto* operand) {
    return operand != nullptr && operand->GetType()->IsInt32();
  };
  auto* lhs = update->GetLhs();
  auto* rhs = update->GetRhs();

  switch (update->GetOpcode()) {
    case Opcode::Add:
      // Addition is commutative, so `base` may sit on either side.
      if (lhs == base && is_int32(rhs)) {
        *match = {base, rhs, Opcode::Add};
        return true;
      }
      if (rhs == base && is_int32(lhs)) {
        *match = {base, lhs, Opcode::Add};
        return true;
      }
      return false;
    case Opcode::Sub:
      // Only `base - delta` is an accumulation; `delta - base` is not.
      if (lhs == base && is_int32(rhs)) {
        *match = {base, rhs, Opcode::Sub};
        return true;
      }
      return false;
    default:
      return false;
  }
}
|
|
|
|
// Tries to if-convert the conditional accumulation "triangle" rooted at
// `pred`:
//
//   pred:    ...
//            condbr %c, %update, %join    ; or: condbr %c, %join, %update
//   update:  %u = add %base, %delta       ; or %delta + %base, %base - %delta
//            br %join
//   join:    %p = phi [%base, pred], [%u, update]   ; join's only phi
//
// The branch is replaced by straight-line code appended to `pred`:
//
//   %ifconv.zext  = zext %c to i32
//   %ifconv.mask  = 0 - %ifconv.zext         ; update on the true edge
//                   (%ifconv.zext - 1)       ; update on the false edge
//   %ifconv.delta = %delta & %ifconv.mask    ; %delta if update runs, else 0
//   %ifconv.acc   = %base +/- %ifconv.delta
//
// On success, %p is replaced by %ifconv.acc and erased. `pred` then branches
// unconditionally to `join`, and unreachable blocks (the old update arm) are
// removed from `function`.
// Returns true iff the function was modified.
bool TryConvertConditionalAccumulation(Function& function, BasicBlock* pred) {
  auto* branch = dyncast<CondBrInst>(GetTerminator(pred));
  if (branch == nullptr || branch->GetCondition() == nullptr ||
      !branch->GetCondition()->GetType()->IsInt1()) {
    return false;
  }

  auto* then_block = branch->GetThenBlock();
  auto* else_block = branch->GetElseBlock();
  if (then_block == nullptr || else_block == nullptr ||
      then_block == else_block) {
    return false;
  }

  // An "update arm" is a block entered only from `pred` that continues
  // unconditionally to `join` and nowhere else.
  const auto is_update_arm = [pred](BasicBlock* arm, BasicBlock* join) {
    if (arm->GetPredecessors().size() != 1 ||
        arm->GetPredecessors().front() != pred ||
        arm->GetSuccessors().size() != 1 ||
        arm->GetSuccessors().front() != join) {
      return false;
    }
    auto* term = dyncast<UncondBrInst>(GetTerminator(arm));
    return term != nullptr && term->GetDest() == join;
  };

  // Accept the update arm on either edge of the branch.
  BasicBlock* update_block = nullptr;
  BasicBlock* join = nullptr;
  bool update_on_true_edge = true;
  if (is_update_arm(then_block, else_block)) {
    update_block = then_block;
    join = else_block;
  } else if (is_update_arm(else_block, then_block)) {
    update_block = else_block;
    join = then_block;
    update_on_true_edge = false;
  } else {
    return false;
  }

  // With a self-loop (`join == pred`) the phi would live in `pred` itself.
  // Replacing it with a value computed at the end of `pred` would break SSA
  // dominance for any earlier use inside the block.
  if (join == pred) {
    return false;
  }

  Instruction* only_inst = nullptr;
  if (!HasOnlyOneNonTerminator(update_block, &only_inst)) {
    return false;
  }
  auto* update = dyncast<BinaryInst>(only_inst);
  if (update == nullptr || (update->GetOpcode() != Opcode::Add &&
                            update->GetOpcode() != Opcode::Sub)) {
    return false;
  }

  auto* phi = GetSinglePhi(join);
  ConditionalAccumulation accum;
  if (!MatchConditionalAccumulation(phi, pred, update_block, update, &accum)) {
    return false;
  }

  // All new instructions go, in order, just before `pred`'s terminator.
  // Their operands are available there for three reasons. The condition
  // feeds that terminator. `base` is the phi's incoming value from `pred`.
  // `delta` is used in `update_block`, whose only predecessor is `pred` and
  // which defines nothing besides `update`.
  const std::size_t insert_pos = GetTerminatorIndex(pred);
  auto* enabled = pred->Insert<ZextInst>(insert_pos, branch->GetCondition(),
                                         Type::GetInt32Type(), nullptr,
                                         "%ifconv.zext");
  // The mask is all ones exactly when the original update would have run:
  // 0 - zext(c) on the true edge, zext(c) - 1 on the false edge.
  auto* mask =
      update_on_true_edge
          ? pred->Insert<BinaryInst>(insert_pos + 1, Opcode::Sub,
                                     Type::GetInt32Type(), ConstInt(0),
                                     enabled, nullptr, "%ifconv.mask")
          : pred->Insert<BinaryInst>(insert_pos + 1, Opcode::Sub,
                                     Type::GetInt32Type(), enabled,
                                     ConstInt(1), nullptr, "%ifconv.mask");
  auto* masked_delta = pred->Insert<BinaryInst>(
      insert_pos + 2, Opcode::And, Type::GetInt32Type(), accum.delta, mask,
      nullptr, "%ifconv.delta");
  auto* replacement = pred->Insert<BinaryInst>(
      insert_pos + 3, accum.opcode, Type::GetInt32Type(), accum.base,
      masked_delta, nullptr, "%ifconv.acc");

  phi->ReplaceAllUsesWith(replacement);
  join->EraseInstruction(phi);

  // `pred` now falls through to `join`. Detach the dead update arm and let
  // the generic cleanup delete it.
  passutils::ReplaceTerminatorWithBr(pred, join);
  pred->RemoveSuccessor(update_block);
  update_block->RemovePredecessor(pred);
  passutils::RemoveUnreachableBlocks(function);
  return true;
}
|
|
|
|
// Repeatedly applies conditional-accumulation if-conversion to `function`
// until no reachable block can be converted. External (declaration-only)
// functions are skipped.
// Returns true if any conversion happened.
bool RunIfConversionOnFunction(Function& function) {
  if (function.IsExternal()) {
    return false;
  }

  bool changed = false;
  // A successful conversion deletes blocks and invalidates the collected
  // block list, so rescan from scratch after each change until a fixpoint.
  for (;;) {
    bool converted = false;
    const auto worklist = passutils::CollectReachableBlocks(function);
    for (auto* candidate : worklist) {
      if (TryConvertConditionalAccumulation(function, candidate)) {
        converted = true;
        break;
      }
    }
    if (!converted) {
      break;
    }
    changed = true;
  }
  return changed;
}
|
|
|
|
} // namespace
|
|
|
|
// Module-level entry point: runs if-conversion over every non-null function
// in `module`. Returns true if any function was modified.
bool RunIfConversion(Module& module) {
  bool any_changed = false;
  for (const auto& fn : module.GetFunctions()) {
    if (fn == nullptr) {
      continue;
    }
    if (RunIfConversionOnFunction(*fn)) {
      any_changed = true;
    }
  }
  return any_changed;
}
|
|
|
|
} // namespace ir
|