parent
69f2cdf11a
commit
184d5c7cb5
@ -1,4 +1,135 @@
|
||||
// CFG 简化:
|
||||
// - 删除不可达块、合并空块、简化分支等
|
||||
// - 改善 IR 结构,便于后续优化与后端生成
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
Instruction* Terminator(BasicBlock& block) {
|
||||
auto& insts = block.GetMutableInstructions();
|
||||
if (insts.empty()) return nullptr;
|
||||
auto* inst = insts.back().get();
|
||||
return inst && inst->IsTerminator() ? inst : nullptr;
|
||||
}
|
||||
|
||||
void AddCFGEdge(BasicBlock* from, BasicBlock* to) {
|
||||
if (!from || !to) return;
|
||||
from->AddSuccessor(to);
|
||||
to->AddPredecessor(from);
|
||||
}
|
||||
|
||||
void RebuildCFG(Function& func) {
|
||||
for (const auto& block : func.GetBlocks()) {
|
||||
if (!block) continue;
|
||||
block->ClearPredecessors();
|
||||
block->ClearSuccessors();
|
||||
}
|
||||
for (const auto& block : func.GetBlocks()) {
|
||||
if (!block) continue;
|
||||
auto* term = Terminator(*block);
|
||||
if (auto* br = dynamic_cast<BranchInst*>(term)) {
|
||||
AddCFGEdge(block.get(), br->GetDest());
|
||||
} else if (auto* cbr = dynamic_cast<CondBrInst*>(term)) {
|
||||
AddCFGEdge(block.get(), cbr->GetTrueDest());
|
||||
AddCFGEdge(block.get(), cbr->GetFalseDest());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool SimplifyBranches(Function& func) {
|
||||
bool changed = false;
|
||||
for (const auto& block : func.GetBlocks()) {
|
||||
if (!block) continue;
|
||||
auto* cbr = dynamic_cast<CondBrInst*>(Terminator(*block));
|
||||
if (!cbr) continue;
|
||||
BasicBlock* dest = nullptr;
|
||||
if (cbr->GetTrueDest() == cbr->GetFalseDest()) {
|
||||
dest = cbr->GetTrueDest();
|
||||
} else if (auto* cond = dynamic_cast<ConstantInt*>(cbr->GetCond())) {
|
||||
dest = cond->GetValue() != 0 ? cbr->GetTrueDest() : cbr->GetFalseDest();
|
||||
}
|
||||
if (dest) {
|
||||
block->ReplaceTerminator(std::make_unique<BranchInst>(dest));
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
std::unordered_set<BasicBlock*> ReachableBlocks(Function& func) {
|
||||
std::unordered_set<BasicBlock*> reachable;
|
||||
std::queue<BasicBlock*> work;
|
||||
if (auto* entry = func.GetEntry()) {
|
||||
reachable.insert(entry);
|
||||
work.push(entry);
|
||||
}
|
||||
while (!work.empty()) {
|
||||
auto* block = work.front();
|
||||
work.pop();
|
||||
for (auto* succ : block->GetSuccessors()) {
|
||||
if (succ && reachable.insert(succ).second) {
|
||||
work.push(succ);
|
||||
}
|
||||
}
|
||||
}
|
||||
return reachable;
|
||||
}
|
||||
|
||||
bool RemoveUnreachable(Function& func) {
|
||||
auto reachable = ReachableBlocks(func);
|
||||
bool changed = false;
|
||||
|
||||
for (const auto& block : func.GetBlocks()) {
|
||||
if (!block || reachable.count(block.get()) == 0) continue;
|
||||
for (const auto& inst : block->GetInstructions()) {
|
||||
auto* phi = dynamic_cast<PhiInst*>(inst.get());
|
||||
if (!phi) continue;
|
||||
for (const auto& other : func.GetBlocks()) {
|
||||
if (other && reachable.count(other.get()) == 0) {
|
||||
phi->RemoveIncomingFrom(other.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto& blocks = func.GetMutableBlocks();
|
||||
for (auto it = blocks.begin(); it != blocks.end();) {
|
||||
auto* block = it->get();
|
||||
if (!block || reachable.count(block) != 0) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
auto& insts = block->GetMutableInstructions();
|
||||
while (!insts.empty()) {
|
||||
block->EraseInstruction(insts.back().get());
|
||||
}
|
||||
it = blocks.erase(it);
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunCFGSimplify(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& func : module.GetFunctions()) {
|
||||
if (!func || func->IsDeclaration()) continue;
|
||||
RebuildCFG(*func);
|
||||
bool local_changed = SimplifyBranches(*func);
|
||||
if (local_changed) {
|
||||
RebuildCFG(*func);
|
||||
}
|
||||
local_changed |= RemoveUnreachable(*func);
|
||||
if (local_changed) {
|
||||
RebuildCFG(*func);
|
||||
}
|
||||
changed |= local_changed;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
|
||||
@ -1,4 +1,239 @@
|
||||
// IR 常量折叠:
|
||||
// - 折叠可判定的常量表达式
|
||||
// - 简化常量控制流分支(按实现范围裁剪)
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
ConstantInt* AsConstInt(Value* value) {
|
||||
return dynamic_cast<ConstantInt*>(value);
|
||||
}
|
||||
|
||||
ConstantFloat* AsConstFloat(Value* value) {
|
||||
return dynamic_cast<ConstantFloat*>(value);
|
||||
}
|
||||
|
||||
bool IsZero(Value* value) {
|
||||
if (auto* i = AsConstInt(value)) return i->GetValue() == 0;
|
||||
if (auto* f = AsConstFloat(value)) return f->GetValue() == 0.0f;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsOne(Value* value) {
|
||||
if (auto* i = AsConstInt(value)) return i->GetValue() == 1;
|
||||
if (auto* f = AsConstFloat(value)) return f->GetValue() == 1.0f;
|
||||
return false;
|
||||
}
|
||||
|
||||
ConstantValue* FoldBinary(BinaryInst& inst, Context& ctx) {
|
||||
auto* li = AsConstInt(inst.GetLhs());
|
||||
auto* ri = AsConstInt(inst.GetRhs());
|
||||
if (li && ri) {
|
||||
int lhs = li->GetValue();
|
||||
int rhs = ri->GetValue();
|
||||
switch (inst.GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
return ctx.GetConstInt(lhs + rhs);
|
||||
case Opcode::Sub:
|
||||
return ctx.GetConstInt(lhs - rhs);
|
||||
case Opcode::Mul:
|
||||
return ctx.GetConstInt(lhs * rhs);
|
||||
case Opcode::SDiv:
|
||||
if (rhs != 0) return ctx.GetConstInt(lhs / rhs);
|
||||
break;
|
||||
case Opcode::SRem:
|
||||
if (rhs != 0) return ctx.GetConstInt(lhs % rhs);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto* lf = AsConstFloat(inst.GetLhs());
|
||||
auto* rf = AsConstFloat(inst.GetRhs());
|
||||
if (lf && rf) {
|
||||
float lhs = lf->GetValue();
|
||||
float rhs = rf->GetValue();
|
||||
switch (inst.GetOpcode()) {
|
||||
case Opcode::FAdd:
|
||||
return ctx.GetConstFloat(lhs + rhs);
|
||||
case Opcode::FSub:
|
||||
return ctx.GetConstFloat(lhs - rhs);
|
||||
case Opcode::FMul:
|
||||
return ctx.GetConstFloat(lhs * rhs);
|
||||
case Opcode::FDiv:
|
||||
return ctx.GetConstFloat(lhs / rhs);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Value* SimplifyBinary(BinaryInst& inst) {
|
||||
auto op = inst.GetOpcode();
|
||||
auto* lhs = inst.GetLhs();
|
||||
auto* rhs = inst.GetRhs();
|
||||
switch (op) {
|
||||
case Opcode::Add:
|
||||
case Opcode::FAdd:
|
||||
if (IsZero(rhs)) return lhs;
|
||||
if (IsZero(lhs)) return rhs;
|
||||
break;
|
||||
case Opcode::Sub:
|
||||
case Opcode::FSub:
|
||||
if (IsZero(rhs)) return lhs;
|
||||
break;
|
||||
case Opcode::Mul:
|
||||
if (IsOne(rhs)) return lhs;
|
||||
if (IsOne(lhs)) return rhs;
|
||||
if (IsZero(rhs)) return rhs;
|
||||
if (IsZero(lhs)) return lhs;
|
||||
break;
|
||||
case Opcode::FMul:
|
||||
if (IsOne(rhs)) return lhs;
|
||||
if (IsOne(lhs)) return rhs;
|
||||
break;
|
||||
case Opcode::SDiv:
|
||||
case Opcode::FDiv:
|
||||
if (IsOne(rhs)) return lhs;
|
||||
break;
|
||||
case Opcode::SRem:
|
||||
if (IsOne(rhs)) return rhs;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ConstantInt* FoldICmp(ICmpInst& inst, Context& ctx) {
|
||||
auto* lhs = AsConstInt(inst.GetLhs());
|
||||
auto* rhs = AsConstInt(inst.GetRhs());
|
||||
if (!lhs || !rhs) return nullptr;
|
||||
int l = lhs->GetValue();
|
||||
int r = rhs->GetValue();
|
||||
bool result = false;
|
||||
switch (inst.GetPredicate()) {
|
||||
case ICmpPredicate::Eq:
|
||||
result = l == r;
|
||||
break;
|
||||
case ICmpPredicate::Ne:
|
||||
result = l != r;
|
||||
break;
|
||||
case ICmpPredicate::Slt:
|
||||
result = l < r;
|
||||
break;
|
||||
case ICmpPredicate::Sle:
|
||||
result = l <= r;
|
||||
break;
|
||||
case ICmpPredicate::Sgt:
|
||||
result = l > r;
|
||||
break;
|
||||
case ICmpPredicate::Sge:
|
||||
result = l >= r;
|
||||
break;
|
||||
}
|
||||
return ctx.GetConstBool(result);
|
||||
}
|
||||
|
||||
ConstantInt* FoldFCmp(FCmpInst& inst, Context& ctx) {
|
||||
auto* lhs = AsConstFloat(inst.GetLhs());
|
||||
auto* rhs = AsConstFloat(inst.GetRhs());
|
||||
if (!lhs || !rhs) return nullptr;
|
||||
float l = lhs->GetValue();
|
||||
float r = rhs->GetValue();
|
||||
bool ordered = !std::isnan(l) && !std::isnan(r);
|
||||
bool result = false;
|
||||
switch (inst.GetPredicate()) {
|
||||
case FCmpPredicate::Oeq:
|
||||
result = ordered && l == r;
|
||||
break;
|
||||
case FCmpPredicate::One:
|
||||
result = ordered && l != r;
|
||||
break;
|
||||
case FCmpPredicate::Olt:
|
||||
result = ordered && l < r;
|
||||
break;
|
||||
case FCmpPredicate::Ole:
|
||||
result = ordered && l <= r;
|
||||
break;
|
||||
case FCmpPredicate::Ogt:
|
||||
result = ordered && l > r;
|
||||
break;
|
||||
case FCmpPredicate::Oge:
|
||||
result = ordered && l >= r;
|
||||
break;
|
||||
}
|
||||
return ctx.GetConstBool(result);
|
||||
}
|
||||
|
||||
ConstantValue* FoldCast(CastInst& inst, Context& ctx) {
|
||||
switch (inst.GetOpcode()) {
|
||||
case Opcode::SIToFP:
|
||||
if (auto* c = AsConstInt(inst.GetValue())) {
|
||||
return ctx.GetConstFloat(static_cast<float>(c->GetValue()));
|
||||
}
|
||||
break;
|
||||
case Opcode::FPToSI:
|
||||
if (auto* c = AsConstFloat(inst.GetValue())) {
|
||||
return ctx.GetConstInt(static_cast<int>(c->GetValue()));
|
||||
}
|
||||
break;
|
||||
case Opcode::ZExt:
|
||||
if (auto* c = AsConstInt(inst.GetValue())) {
|
||||
return ctx.GetConstInt(c->GetValue() != 0 ? 1 : 0);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Value* SimplifyPhi(PhiInst& phi) {
|
||||
const auto& values = phi.GetIncomingValues();
|
||||
if (values.empty()) return nullptr;
|
||||
auto* first = values.front();
|
||||
for (auto* value : values) {
|
||||
if (value != first) return nullptr;
|
||||
}
|
||||
return first;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunConstFold(Module& module) {
|
||||
bool changed = false;
|
||||
auto& ctx = module.GetContext();
|
||||
for (const auto& func : module.GetFunctions()) {
|
||||
if (!func || func->IsDeclaration()) continue;
|
||||
for (const auto& block : func->GetBlocks()) {
|
||||
if (!block) continue;
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
Value* replacement = nullptr;
|
||||
if (auto* bin = dynamic_cast<BinaryInst*>(inst)) {
|
||||
replacement = FoldBinary(*bin, ctx);
|
||||
if (!replacement) replacement = SimplifyBinary(*bin);
|
||||
} else if (auto* icmp = dynamic_cast<ICmpInst*>(inst)) {
|
||||
replacement = FoldICmp(*icmp, ctx);
|
||||
} else if (auto* fcmp = dynamic_cast<FCmpInst*>(inst)) {
|
||||
replacement = FoldFCmp(*fcmp, ctx);
|
||||
} else if (auto* cast = dynamic_cast<CastInst*>(inst)) {
|
||||
replacement = FoldCast(*cast, ctx);
|
||||
} else if (auto* phi = dynamic_cast<PhiInst*>(inst)) {
|
||||
replacement = SimplifyPhi(*phi);
|
||||
}
|
||||
if (replacement && replacement != inst) {
|
||||
inst->ReplaceAllUsesWith(replacement);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
|
||||
@ -1 +1,23 @@
|
||||
// IR Pass 管理骨架。
|
||||
#include "ir/IR.h"
|
||||
|
||||
namespace ir {
|
||||
|
||||
bool RunScalarOptimizationPipeline(Module& module) {
|
||||
bool changed = false;
|
||||
changed |= RunMem2Reg(module);
|
||||
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
bool iter_changed = false;
|
||||
iter_changed |= RunConstFold(module);
|
||||
iter_changed |= RunConstProp(module);
|
||||
iter_changed |= RunCSE(module);
|
||||
iter_changed |= RunDCE(module);
|
||||
iter_changed |= RunCFGSimplify(module);
|
||||
changed |= iter_changed;
|
||||
if (!iter_changed) break;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
Loading…
Reference in new issue