You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

240 lines
6.6 KiB

#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <cstddef>
#include <vector>
namespace ir {
namespace {
// Returns the last instruction of `block` if it is a terminator.
// Yields nullptr for a null block, an empty block, a null trailing entry,
// or a block whose last instruction is not a terminator.
Instruction* GetTerminator(BasicBlock* block) {
  if (block == nullptr) {
    return nullptr;
  }
  const auto& insts = block->GetInstructions();
  if (insts.empty()) {
    return nullptr;
  }
  auto* last = insts.back().get();
  if (last == nullptr || !last->IsTerminator()) {
    return nullptr;
  }
  return last;
}
// Index of the last instruction in `block` (0 for an empty block).
// Inserting at this index places new code immediately before the final
// instruction. This does NOT verify that the final instruction is a
// terminator; callers are expected to have checked that already.
// Precondition: `block` is non-null (it is dereferenced unconditionally).
std::size_t GetTerminatorIndex(BasicBlock* block) {
  const auto count = block->GetInstructions().size();
  if (count == 0) {
    return 0;
  }
  return count - 1;
}
// Creates a fresh i32 constant holding `value`.
// NOTE(review): allocated with a naked `new` on every call; presumably the
// IR takes ownership through its use lists. Confirm, or switch to a
// uniqued constant factory if the IR provides one.
ConstantInt* ConstInt(int value) {
  auto* constant = new ConstantInt(Type::GetInt32Type(), value);
  return constant;
}
// Returns the phi at the head of `block` when the leading run of phi
// instructions contains exactly one phi. Returns nullptr for a null block,
// a block that does not start with a phi, or a block with two or more
// leading phis. Scanning stops at the first non-phi instruction.
PhiInst* GetSinglePhi(BasicBlock* block) {
  if (block == nullptr) {
    return nullptr;
  }
  PhiInst* first = nullptr;
  for (const auto& owned : block->GetInstructions()) {
    auto* candidate = dyncast<PhiInst>(owned.get());
    if (candidate == nullptr) {
      // End of the phi prefix.
      return first;
    }
    if (first != nullptr) {
      // A second phi: not a single-phi block.
      return nullptr;
    }
    first = candidate;
  }
  return first;
}
// Returns true when `block` holds exactly one non-null, non-terminator
// instruction. On the non-early-exit paths (zero or one such instruction),
// stores that instruction (or nullptr) into `*out` when `out` is non-null.
// When a second candidate is found, returns false without touching `*out`.
bool HasOnlyOneNonTerminator(BasicBlock* block, Instruction** out) {
  if (block == nullptr) {
    return false;
  }
  Instruction* sole = nullptr;
  for (const auto& owned : block->GetInstructions()) {
    auto* inst = owned.get();
    const bool counts = inst != nullptr && !inst->IsTerminator();
    if (!counts) {
      continue;
    }
    if (sole != nullptr) {
      return false;
    }
    sole = inst;
  }
  if (out != nullptr) {
    *out = sole;
  }
  return sole != nullptr;
}
// Position of the first incoming edge of `phi` whose block is `block`,
// or -1 if either argument is null or no such edge exists.
int IncomingIndexFor(PhiInst* phi, BasicBlock* block) {
  const bool have_both = phi != nullptr && block != nullptr;
  if (!have_both) {
    return -1;
  }
  const auto count = phi->GetNumIncomings();
  for (int idx = 0; idx < count; ++idx) {
    if (phi->GetIncomingBlock(idx) == block) {
      return idx;
    }
  }
  return -1;
}
// Returns true when every recorded use of `value` belongs to
// `expected_user`. Note this is vacuously true when `value` has no uses.
// Returns false if either argument is null.
bool IsUsedOnlyBy(Value* value, User* expected_user) {
  const bool have_both = value != nullptr && expected_user != nullptr;
  if (!have_both) {
    return false;
  }
  bool all_match = true;
  for (const auto& use : value->GetUses()) {
    if (use.GetUser() != expected_user) {
      all_match = false;
      break;
    }
  }
  return all_match;
}
// Result of matching `base <opcode> delta` on a conditionally executed path.
// Member order is significant: it is filled via aggregate initialization
// in the order {base, delta, opcode}.
struct ConditionalAccumulation {
  // Accumulator value when the update path is skipped (the phi's incoming
  // value from the branching block).
  Value* base{nullptr};
  // Operand combined with `base` on the update path.
  Value* delta{nullptr};
  // The matched update operation (Add or Sub).
  Opcode opcode{Opcode::Add};
};
// Recognizes a two-input i32 phi of the form
//   phi [base, pred], [update, update_block]
// where `update` is `base + delta`, `delta + base`, or `base - delta`, and
// `update` has no users other than the phi.
// On success writes {base, delta, opcode} into `*match` and returns true.
// `*match` is left untouched on failure.
bool MatchConditionalAccumulation(PhiInst* phi, BasicBlock* pred,
                                  BasicBlock* update_block,
                                  BinaryInst* update,
                                  ConditionalAccumulation* match) {
  // Argument and type preconditions (checked in the same order as a single
  // short-circuit chain would).
  if (phi == nullptr || pred == nullptr || update_block == nullptr ||
      update == nullptr || match == nullptr) {
    return false;
  }
  if (phi->GetNumIncomings() != 2) {
    return false;
  }
  if (!phi->GetType()->IsInt32() || !update->GetType()->IsInt32()) {
    return false;
  }

  // The phi must have one edge from each of the two blocks.
  const int from_pred = IncomingIndexFor(phi, pred);
  const int from_update = IncomingIndexFor(phi, update_block);
  if (from_pred < 0 || from_update < 0) {
    return false;
  }

  auto* base = phi->GetIncomingValue(from_pred);
  if (phi->GetIncomingValue(from_update) != update) {
    return false;
  }
  if (base == nullptr || !base->GetType()->IsInt32()) {
    return false;
  }
  // The update must be dead once the phi is replaced.
  if (!IsUsedOnlyBy(update, phi)) {
    return false;
  }

  auto* lhs = update->GetLhs();
  auto* rhs = update->GetRhs();
  switch (update->GetOpcode()) {
    case Opcode::Add:
      // Addition is commutative: `base` may be either operand.
      if (lhs == base && rhs != nullptr && rhs->GetType()->IsInt32()) {
        *match = {base, rhs, Opcode::Add};
        return true;
      }
      if (rhs == base && lhs != nullptr && lhs->GetType()->IsInt32()) {
        *match = {base, lhs, Opcode::Add};
        return true;
      }
      return false;
    case Opcode::Sub:
      // Only `base - delta` is an accumulation; `delta - base` is not.
      if (lhs == base && rhs != nullptr && rhs->GetType()->IsInt32()) {
        *match = {base, rhs, Opcode::Sub};
        return true;
      }
      return false;
    default:
      return false;
  }
}
bool TryConvertConditionalAccumulation(Function& function, BasicBlock* pred) {
auto* branch = dyncast<CondBrInst>(GetTerminator(pred));
if (branch == nullptr || branch->GetCondition() == nullptr ||
!branch->GetCondition()->GetType()->IsInt1()) {
return false;
}
auto* update_block = branch->GetThenBlock();
auto* join = branch->GetElseBlock();
if (update_block == nullptr || join == nullptr || update_block == join ||
update_block->GetPredecessors().size() != 1 ||
update_block->GetPredecessors().front() != pred ||
update_block->GetSuccessors().size() != 1 ||
update_block->GetSuccessors().front() != join) {
return false;
}
auto* update_term = dyncast<UncondBrInst>(GetTerminator(update_block));
if (update_term == nullptr || update_term->GetDest() != join) {
return false;
}
Instruction* only_inst = nullptr;
if (!HasOnlyOneNonTerminator(update_block, &only_inst)) {
return false;
}
auto* update = dyncast<BinaryInst>(only_inst);
if (update == nullptr ||
(update->GetOpcode() != Opcode::Add && update->GetOpcode() != Opcode::Sub)) {
return false;
}
auto* phi = GetSinglePhi(join);
ConditionalAccumulation accum;
if (!MatchConditionalAccumulation(phi, pred, update_block, update, &accum)) {
return false;
}
const std::size_t insert_pos = GetTerminatorIndex(pred);
auto* enabled = pred->Insert<ZextInst>(insert_pos, branch->GetCondition(),
Type::GetInt32Type(), nullptr,
"%ifconv.zext");
auto* mask = pred->Insert<BinaryInst>(insert_pos + 1, Opcode::Sub,
Type::GetInt32Type(), ConstInt(0),
enabled, nullptr, "%ifconv.mask");
auto* masked_delta = pred->Insert<BinaryInst>(
insert_pos + 2, Opcode::And, Type::GetInt32Type(), accum.delta, mask,
nullptr, "%ifconv.delta");
auto* replacement = pred->Insert<BinaryInst>(
insert_pos + 3, accum.opcode, Type::GetInt32Type(), accum.base,
masked_delta, nullptr, "%ifconv.acc");
phi->ReplaceAllUsesWith(replacement);
join->EraseInstruction(phi);
passutils::ReplaceTerminatorWithBr(pred, join);
pred->RemoveSuccessor(update_block);
update_block->RemovePredecessor(pred);
passutils::RemoveUnreachableBlocks(function);
return true;
}
// Applies TryConvertConditionalAccumulation to `function` until no block
// matches any more. After each successful rewrite the scan restarts from a
// freshly collected reachable-block list, because the rewrite may remove
// blocks. External (declaration-only) functions are skipped.
// Returns true if anything changed.
bool RunIfConversionOnFunction(Function& function) {
  if (function.IsExternal()) {
    return false;
  }
  bool any_change = false;
  for (;;) {
    auto reachable = passutils::CollectReachableBlocks(function);
    bool converted = false;
    for (auto* candidate : reachable) {
      if (TryConvertConditionalAccumulation(function, candidate)) {
        converted = true;
        break;
      }
    }
    if (!converted) {
      break;
    }
    any_change = true;
  }
  return any_change;
}
} // namespace
// Module-level entry point: runs conditional-accumulation if-conversion on
// every non-null function in `module`. Returns true if any function changed.
bool RunIfConversion(Module& module) {
  bool changed = false;
  for (const auto& fn : module.GetFunctions()) {
    if (fn == nullptr) {
      continue;
    }
    if (RunIfConversionOnFunction(*fn)) {
      changed = true;
    }
  }
  return changed;
}
} // namespace ir