You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nudt-compiler-cpp/src/ir/passes/IfConversion.cpp

271 lines
9.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// IfConversion: 将简单 if-else diamond 转换为算术 select
// - 扫描 CondBr→T→Br→M 且 F==M 的 diamond 模式
// - 安全检查: T 必须只有单一前驱 B, 仅允许纯算术指令 (禁 Div/Mod/浮点)
// - 将 phi 转换为 fv + (tv-fv)*zext(cond)
// - 配合 CFGSimplify 清理空块,使循环体变为单 BB → 可被 LoopUnroll 展开
#include "ir/IR.h"
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// Peels at most two nested `ne (zext X), 0` layers off a branch condition and
// returns the innermost value reached. Peeling stops at the first layer that
// is not exactly: a BinaryInst with opcode Ne, a ConstantInt 0 as its RHS, and
// a ZExt CastInst as its LHS. In that case the value reached so far is returned.
static Value* UnwrapCondition(Value* cond) {
  Value* current = cond;
  int layers_left = 2;
  while (layers_left-- > 0) {
    auto* cmp = dynamic_cast<BinaryInst*>(current);
    if (cmp == nullptr || cmp->GetOpcode() != Opcode::Ne) {
      return current;
    }
    auto* zero = dynamic_cast<ConstantInt*>(cmp->GetRhs());
    if (zero == nullptr || zero->GetValue() != 0) {
      return current;
    }
    auto* widen = dynamic_cast<CastInst*>(cmp->GetLhs());
    if (widen == nullptr || widen->GetOpcode() != Opcode::ZExt) {
      return current;
    }
    // Step through the zero-extension to the value it widened.
    current = widen->GetOperandValue();
  }
  return current;
}
// Returns the destination of bb's last instruction when that instruction is an
// unconditional BranchInst. Returns nullptr for an empty block or for any
// other kind of final instruction.
static BasicBlock* GetOnlyBrTarget(BasicBlock* bb) {
  const auto& body = bb->GetInstructions();
  if (body.empty()) {
    return nullptr;
  }
  if (auto* jump = dynamic_cast<BranchInst*>(body.back().get())) {
    return jump->GetTarget();
  }
  return nullptr;
}
// Collects every block in `all_blocks` whose last instruction is a BranchInst
// or CondBranchInst that transfers control to `bb`.
//
// Each predecessor is reported once, even when both arms of a CondBranchInst
// target `bb`. Self-edges ARE included: a block that branches to itself is
// listed as its own predecessor. Callers depend on this to see that a
// self-looping block has more than one incoming edge. Examples:
//   - MergeSinglePredBlocks must not fold such a block into its other
//     predecessor.
//   - TryConvertOneDiamond must not empty a self-looping T.
static std::vector<BasicBlock*> ComputePredecessors(
    BasicBlock* bb, const std::vector<std::unique_ptr<BasicBlock>>& all_blocks) {
  std::vector<BasicBlock*> preds;
  for (const auto& other : all_blocks) {
    const auto& insts = other->GetInstructions();
    if (insts.empty()) continue;
    Instruction* term = insts.back().get();
    bool targets_bb = false;
    if (auto* br = dynamic_cast<BranchInst*>(term)) {
      targets_bb = br->GetTarget() == bb;
    } else if (auto* cbr = dynamic_cast<CondBranchInst*>(term)) {
      targets_bb = cbr->GetTrueTarget() == bb || cbr->GetFalseTarget() == bb;
    }
    if (targets_bb) preds.push_back(other.get());
  }
  return preds;
}
// Returns true when every instruction in bb has an opcode from a fixed
// whitelist:
//   - Add/Sub/Mul/And/Or
//   - Eq/Ne/Lt/Le/Gt/Ge
//   - ZExt
//   - the unconditional Br
// Any other opcode (Div/Mod among them) makes the block ineligible for being
// executed unconditionally.
// NOTE(review): only opcodes are checked, not operand types. If floating-point
// arithmetic shares these opcodes it would be accepted here. Confirm against
// the Opcode definitions.
static bool IsSimpleBlock(BasicBlock* bb) {
  auto is_allowed = [](Opcode op) {
    switch (op) {
      case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
      case Opcode::And: case Opcode::Or:
      case Opcode::Eq: case Opcode::Ne: case Opcode::Lt:
      case Opcode::Le: case Opcode::Gt: case Opcode::Ge:
      case Opcode::ZExt:
      case Opcode::Br:
        return true;
      default:
        return false;
    }
  };
  for (const auto& inst : bb->GetInstructions()) {
    if (!is_allowed(inst->GetOpcode())) {
      return false;
    }
  }
  return true;
}
// Returns the value `phi` receives along the edge from `bb`, or nullptr if no
// incoming pair names `bb`. Phi operands are laid out as consecutive
// (value, block) pairs.
static Value* GetPhiValueFrom(PhiInst* phi, BasicBlock* bb) {
  size_t slot = 0;
  while (slot < phi->GetNumOperands()) {
    auto* from = dynamic_cast<BasicBlock*>(phi->GetOperand(slot + 1));
    if (from == bb) {
      return phi->GetOperand(slot);
    }
    slot += 2;
  }
  return nullptr;
}
// Drops every (value, block) pair in `phi` whose block is `bb`. The operand
// list is cleared and rebuilt only if at least one pair was actually removed.
// Otherwise `phi` is left untouched.
static void RemovePhiEntriesFrom(PhiInst* phi, BasicBlock* bb) {
  const size_t total = phi->GetNumOperands();
  std::vector<std::pair<Value*, Value*>> survivors;
  for (size_t slot = 0; slot < total; slot += 2) {
    Value* incoming_block = phi->GetOperand(slot + 1);
    if (dynamic_cast<BasicBlock*>(incoming_block) == bb) {
      continue;
    }
    survivors.push_back({phi->GetOperand(slot), incoming_block});
  }
  const bool nothing_removed = survivors.size() * 2 == total;
  if (nothing_removed) {
    return;
  }
  phi->ClearOperands();
  for (const auto& entry : survivors) {
    phi->AddOperand(entry.first);
    phi->AddOperand(entry.second);
  }
}
// Makes `phi` receive `val` along the edge from `bb`. If a pair naming `bb`
// exists, the value of the first such pair is overwritten. Otherwise a new
// (val, bb) pair is appended.
static void SetPhiEntry(PhiInst* phi, BasicBlock* bb, Value* val) {
  const size_t total = phi->GetNumOperands();
  size_t slot = 0;
  while (slot < total &&
         dynamic_cast<BasicBlock*>(phi->GetOperand(slot + 1)) != bb) {
    slot += 2;
  }
  if (slot < total) {
    phi->SetOperand(slot, val);
  } else {
    phi->AddOperand(val);
    phi->AddOperand(bb);
  }
}
// Attempts to if-convert the triangle
//     B: ...; condbr cond, T, M
//     T: <simple insts>; br M
//     M: phi ...
// into straight-line code in B:
//   * T's non-branch instructions are moved into B, so they now execute
//     unconditionally. IsSimpleBlock restricts them to a side-effect-free
//     whitelist.
//   * Each phi in M with an incoming value from T receives, from B, the value
//         val_f + (val_t - val_f) * zext(cond_i1)
//     computed at the end of B. Its T entry is then removed. When
//     val_t == val_f, val_f is used directly.
//   * B's conditional branch is replaced by `br M`. T is left empty.
//
// Parameters:
//   B          block whose last instruction is the conditional branch
//   T          true target of that branch
//   M          false target of that branch (the join block)
//   cond_i1    branch condition, zero-extended to i32 for the select
//   ctx        supplies fresh temporary names
//   all_blocks blocks of the enclosing function, used to find T's predecessors
// Returns true if the code was rewritten. Every rejection happens before any
// mutation, so a false return leaves the IR unchanged.
static bool TryConvertOneDiamond(BasicBlock* B, BasicBlock* T, BasicBlock* M,
                                 Value* cond_i1, Context& ctx,
                                 const std::vector<std::unique_ptr<BasicBlock>>& all_blocks) {
  // Degenerate shapes: both arms reach the same block, or the true arm is B.
  if (T == M || T == B) return false;
  if (!IsSimpleBlock(T)) return false;
  if (GetOnlyBrTarget(T) != M) return false;
  auto t_preds = ComputePredecessors(T, all_blocks);
  if (t_preds.size() != 1 || t_preds[0] != B) return false;

  struct PhiEntry { PhiInst* phi; Value* val_t; Value* val_f; };
  std::vector<PhiEntry> to_convert;
  for (const auto& inst : M->GetInstructions()) {
    auto* phi = dynamic_cast<PhiInst*>(inst.get());
    if (!phi) break;  // phis are grouped at the top of the block
    Value* val_t = GetPhiValueFrom(phi, T);
    if (!val_t) continue;  // phi has no entry for the T edge
    Value* val_f = GetPhiValueFrom(phi, B);
    // Every phi with a T entry must be rewritten. Otherwise it would keep an
    // incoming edge from T after T is emptied. Give up if the B entry is
    // missing (inconsistent phi).
    if (!val_f) return false;
    // The select is built from i32 Sub/Mul/Add, so it is only valid for i32
    // phis. Identical incoming values need no arithmetic at all.
    if (val_t != val_f && !phi->GetType()->IsInt32()) return false;
    to_convert.push_back({phi, val_t, val_f});
  }
  if (to_convert.empty()) return false;

  // Drop B's conditional branch. B temporarily has no terminator.
  // NOTE(review): the code below relies on InsertInstructionBeforeTerminator
  // appending when the block has no terminator. Confirm in BasicBlock.
  B->TakeInstruction(B->GetInstructions().back().get());

  // Move T's computation into B.
  std::vector<Instruction*> t_to_move;
  for (const auto& inst : T->GetInstructions())
    if (inst->GetOpcode() != Opcode::Br) t_to_move.push_back(inst.get());
  for (auto* inst : t_to_move)
    B->InsertInstructionBeforeTerminator(T->TakeInstruction(inst));
  // Remove T's trailing `br M`. T is now empty and has no predecessors.
  if (!T->GetInstructions().empty())
    T->TakeInstruction(T->GetInstructions().back().get());

  // zext(cond) is built once, on first need, and shared by all phis.
  Value* cond_i32 = nullptr;
  for (auto& [phi, val_t, val_f] : to_convert) {
    Value* incoming = val_f;
    if (val_t != val_f) {
      if (!cond_i32)
        cond_i32 = B->Append<CastInst>(Opcode::ZExt, Type::GetInt32Type(), cond_i1, ctx.NextTemp());
      auto* diff = B->Append<BinaryInst>(Opcode::Sub, Type::GetInt32Type(), val_t, val_f, ctx.NextTemp());
      auto* masked = B->Append<BinaryInst>(Opcode::Mul, Type::GetInt32Type(), diff, cond_i32, ctx.NextTemp());
      incoming = B->Append<BinaryInst>(Opcode::Add, Type::GetInt32Type(), val_f, masked, ctx.NextTemp());
    }
    RemovePhiEntriesFrom(phi, T);
    SetPhiEntry(phi, B, incoming);
  }
  B->Append<BranchInst>(Type::GetVoidType(), M);
  return true;
}
// Repeatedly looks for a block of `func` ending in a CondBranchInst. For each
// one it asks TryConvertOneDiamond to rewrite it:
//   - the true arm is passed as T and the false arm as the join block M;
//   - the condition is passed after UnwrapCondition has stripped any
//     `ne (zext X), 0` layers.
// After every successful rewrite the scan restarts from the first block. The
// function returns once a full scan converts nothing.
static void IfConvertFunction(Function* func, Context& ctx) {
  const auto& blocks = func->GetBlocks();
  auto convert_first = [&]() -> bool {
    for (const auto& block : blocks) {
      const auto& body = block->GetInstructions();
      if (body.empty()) {
        continue;
      }
      auto* branch = dynamic_cast<CondBranchInst*>(body.back().get());
      if (branch == nullptr) {
        continue;
      }
      BasicBlock* taken = branch->GetTrueTarget();
      BasicBlock* join = branch->GetFalseTarget();
      Value* predicate = UnwrapCondition(branch->GetCond());
      if (TryConvertOneDiamond(block.get(), taken, join, predicate, ctx, blocks)) {
        return true;
      }
    }
    return false;
  };
  while (convert_first()) {
    // Keep going until no conversion applies.
  }
}
// Removes trivially redundant phis from every block of `func`.
//
// A phi is redundant when all of its incoming values are one and the same
// value V, ignoring operands that are the phi itself (e.g. an unchanged value
// carried around a loop back edge). Its uses are rewritten to V and it is
// erased. A phi with no operands, or whose operands are all itself, has no
// replacement value and is left alone. Before this guard, the all-self case was
// "replaced" with itself and then erased, leaving dangling uses.
//
// The scan repeats until nothing changes, because replacing one phi can make
// another phi trivial.
static void CleanupRedundantPhis(Function* func) {
  bool changed = true;
  while (changed) {
    changed = false;
    for (const auto& bb : func->GetBlocks()) {
      auto& insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(bb->GetInstructions());
      for (size_t i = 0; i < insts.size();) {
        auto* phi = dynamic_cast<PhiInst*>(insts[i].get());
        if (!phi) break;  // phis are grouped at the top of the block
        Value* unique_val = nullptr;
        bool all_same = true;
        for (size_t j = 0; j < phi->GetNumOperands(); j += 2) {
          Value* v = phi->GetOperand(j);
          if (v == phi) continue;  // self-reference contributes no new value
          if (!unique_val) {
            unique_val = v;
          } else if (unique_val != v) {
            all_same = false;
            break;
          }
        }
        if (all_same && unique_val) {
          // Drop the phi's own operand uses first (including any use of itself),
          // so RAUW only touches the phi's external users.
          phi->ClearOperands();
          phi->ReplaceAllUsesWith(unique_val);
          insts.erase(insts.begin() + i);
          changed = true;
          continue;
        }
        ++i;
      }
    }
  }
}
static void MergeSinglePredBlocks(Function* func) {
auto& blocks = const_cast<std::vector<std::unique_ptr<BasicBlock>>&>(func->GetBlocks());
bool changed = true;
while (changed) {
changed = false;
for (auto& bb_ptr : blocks) {
BasicBlock* bb = bb_ptr.get();
if (bb == func->GetEntry()) continue;
bool has_phi = false;
for (const auto& inst : bb->GetInstructions()) {
if (dynamic_cast<PhiInst*>(inst.get())) { has_phi = true; break; }
}
if (has_phi) continue;
auto preds = ComputePredecessors(bb, blocks);
if (preds.size() != 1) continue;
BasicBlock* pred = preds[0];
if (pred == bb) continue;
const auto& pred_insts = pred->GetInstructions();
if (pred_insts.empty()) continue;
auto* br = dynamic_cast<BranchInst*>(pred_insts.back().get());
if (!br || br->GetTarget() != bb) continue;
pred->TakeInstruction(pred_insts.back().get());
auto& bb_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(bb->GetInstructions());
std::vector<Instruction*> to_move;
for (auto& inst : bb_insts)
to_move.push_back(inst.get());
for (auto* inst : to_move) {
auto taken = bb->TakeInstruction(inst);
pred->InsertInstructionBeforeTerminator(std::move(taken));
}
for (auto& other : blocks) {
if (other.get() == bb) continue;
auto& o_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(other->GetInstructions());
for (auto& inst : o_insts) {
auto* phi = dynamic_cast<PhiInst*>(inst.get());
if (!phi) break;
for (size_t i = 0; i < phi->GetNumOperands(); i += 2) {
if (dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1)) == bb)
phi->SetOperand(i + 1, pred);
}
}
}
changed = true;
break;
}
}
}
} // namespace
// Pass entry point. Each function with a body (IsExternal() is false) goes
// through three steps in order:
//   1. if-conversion of simple triangles;
//   2. removal of phis left with identical incoming values;
//   3. merging of single-predecessor blocks into their predecessor.
void RunIfConversion(Module& module) {
  for (auto& fn : module.GetFunctions()) {
    Function* f = fn.get();
    if (!f->IsExternal()) {
      IfConvertFunction(f, module.GetContext());
      CleanupRedundantPhis(f);
      MergeSinglePredBlocks(f);
    }
  }
}
} // namespace ir