You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/ir/passes/CFGSimplify.cpp

266 lines
8.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// CFG 简化:
// - 删除不可达块、合并空块、简化分支等
// - 改善 IR 结构,便于后续优化与后端生成
#include "ir/IR.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// Collects every basic block that can be reached from the entry block of
// `func` by following branch targets.
//
// Returns an empty set when the function has no entry block. Within a block,
// instructions are scanned in order and scanning stops at the first
// instruction that reports IsTerminator(), so anything after it is ignored.
std::unordered_set<BasicBlock*> FindReachableBlocks(Function* func) {
  std::unordered_set<BasicBlock*> seen;
  auto* const root = func->GetEntry();
  if (root == nullptr) {
    return seen;
  }

  // Stack of blocks whose successors still have to be examined.
  std::vector<BasicBlock*> pending;

  // Records `block` as reachable and schedules it, but only the first time it
  // is encountered.
  const auto visit = [&seen, &pending](BasicBlock* block) {
    if (seen.insert(block).second) {
      pending.push_back(block);
    }
  };

  visit(root);
  while (!pending.empty()) {
    auto* const current = pending.back();
    pending.pop_back();

    for (const auto& owned : current->GetInstructions()) {
      auto* inst = owned.get();
      if (auto* jump = dynamic_cast<BranchInst*>(inst)) {
        // Unconditional branch: a single successor.
        visit(jump->GetTarget());
      } else if (auto* cond_jump = dynamic_cast<CondBranchInst*>(inst)) {
        // Conditional branch: both the true and the false successor.
        auto* on_true = cond_jump->GetTrueTarget();
        auto* on_false = cond_jump->GetFalseTarget();
        visit(on_true);
        visit(on_false);
      }
      // Nothing after a terminator contributes successors.
      if (inst->IsTerminator()) {
        break;
      }
    }
  }
  return seen;
}
// Returns the last instruction of `bb`, or nullptr when the block holds no
// instructions. The instruction is returned as-is: this helper does not check
// IsTerminator(), so the caller has to narrow the result to the branch kind it
// is interested in.
Instruction* GetTerminator(BasicBlock* bb) {
  const auto& body = bb->GetInstructions();
  return body.empty() ? nullptr : body.back().get();
}
// Removes `removed_pred` from the incoming lists of the PHI nodes at the top
// of `target_block`. Meant to be called once `removed_pred` no longer branches
// to `target_block`.
//
// Only the leading run of PHI instructions is examined; the scan stops at the
// first non-PHI instruction. For each PHI, the outcome depends on how many
// (value, block) pairs remain after entries for `removed_pred` are filtered
// out:
//   - none:        the PHI is erased (its uses are not rewritten);
//   - exactly one: every use of the PHI is redirected to that value and the
//                  PHI is erased (this also happens when nothing was
//                  filtered out);
//   - two or more: the operand list is rebuilt from the remaining pairs, but
//                  only if at least one entry was actually filtered out.
void CleanupPhiReferences(BasicBlock* target_block, BasicBlock* removed_pred) {
  using InstList = std::vector<std::unique_ptr<Instruction>>;
  // The const_cast gives mutable access to the instruction list so that PHI
  // nodes can be erased from it below.
  auto& body = const_cast<InstList&>(target_block->GetInstructions());

  std::vector<PhiInst*> doomed;
  for (auto& slot : body) {
    auto* node = dynamic_cast<PhiInst*>(slot.get());
    if (node == nullptr) {
      break;
    }

    // Operands are read as alternating (value, incoming block) entries.
    std::vector<std::pair<Value*, BasicBlock*>> survivors;
    for (size_t idx = 0; idx < node->GetNumOperands(); idx += 2) {
      auto* incoming = node->GetOperand(idx);
      auto* from = static_cast<BasicBlock*>(node->GetOperand(idx + 1));
      if (from == removed_pred) {
        continue;
      }
      survivors.push_back({incoming, from});
    }

    const size_t kept = survivors.size();
    if (kept <= 1) {
      if (kept == 1) {
        node->ReplaceAllUsesWith(survivors.front().first);
      }
      doomed.push_back(node);
      continue;
    }
    if (kept * 2 == node->GetNumOperands()) {
      continue;  // Nothing was filtered out; leave the PHI untouched.
    }

    // Some entries were dropped: rebuild the operand list from what is left.
    node->ClearOperands();
    for (auto& entry : survivors) {
      node->AddOperand(entry.first);
      node->AddOperand(entry.second);
    }
  }

  if (doomed.empty()) {
    return;
  }
  // Erase every PHI that was marked above.
  const auto is_doomed = [&doomed](const std::unique_ptr<Instruction>& slot) {
    return std::find(doomed.begin(), doomed.end(), slot.get()) != doomed.end();
  };
  body.erase(std::remove_if(body.begin(), body.end(), is_doomed), body.end());
}
} // namespace
// Simplifies the control-flow graph of every non-external function in
// `module`.
//
// For each function the following steps are repeated until a full round
// removes no block and folds no branch:
//   1. Compute the set of blocks reachable from the entry block.
//   2. In every reachable block, drop PHI incoming entries whose predecessor
//      is unreachable. A PHI left with exactly one entry is replaced by that
//      value and deleted; a PHI left with no entry is deleted.
//   3. Erase the unreachable blocks from the function.
//   4. Replace each conditional branch whose condition is a ConstantInt with
//      an unconditional branch to the taken target, and remove the block from
//      the PHI nodes of the target that is no longer taken.
// After the loop, in every block that reports exactly one predecessor, each
// PHI with at least two operands is replaced by its first operand and all PHI
// instructions of that block are erased.
void RunCFGSimplify(Module& module) {
  using InstList = std::vector<std::unique_ptr<Instruction>>;
  using BlockList = std::vector<std::unique_ptr<BasicBlock>>;

  for (auto& func_ptr : module.GetFunctions()) {
    auto* func = func_ptr.get();
    if (func->IsExternal()) continue;

    // The const_cast gives mutable access to the block list so that
    // unreachable blocks can be erased from it.
    auto& blocks = const_cast<BlockList&>(func->GetBlocks());

    bool changed = true;
    while (changed) {
      changed = false;

      // Step 1: split the blocks into reachable and unreachable ones.
      auto reachable = FindReachableBlocks(func);
      std::unordered_set<BasicBlock*> unreachable;
      for (auto& bb : blocks) {
        if (reachable.find(bb.get()) == reachable.end()) {
          unreachable.insert(bb.get());
        }
      }

      // Step 2: purge unreachable predecessors from the PHIs of reachable
      // blocks.
      for (auto& bb_ptr : blocks) {
        auto* bb = bb_ptr.get();
        if (unreachable.find(bb) != unreachable.end()) continue;

        std::vector<std::pair<PhiInst*, Value*>> phi_replacements;
        std::vector<PhiInst*> phi_to_delete;
        for (auto& inst_ptr : bb->GetInstructions()) {
          auto* phi = dynamic_cast<PhiInst*>(inst_ptr.get());
          if (!phi) break;  // Only the leading run of PHIs is examined.

          // Keep the (value, block) pairs whose predecessor is still
          // reachable. Operands are read as alternating value/block entries.
          std::vector<std::pair<Value*, BasicBlock*>> valid_pairs;
          bool has_unreachable = false;
          for (size_t i = 0; i < phi->GetNumOperands(); i += 2) {
            auto* val = phi->GetOperand(i);
            auto* pred = static_cast<BasicBlock*>(phi->GetOperand(i + 1));
            if (unreachable.find(pred) == unreachable.end()) {
              valid_pairs.push_back({val, pred});
            } else {
              has_unreachable = true;
            }
          }

          if (valid_pairs.size() == 1) {
            // A single incoming value: the PHI is just that value.
            phi_replacements.push_back({phi, valid_pairs[0].first});
          } else if (valid_pairs.empty()) {
            phi_to_delete.push_back(phi);
          } else if (has_unreachable && valid_pairs.size() >= 2) {
            // Some predecessors are unreachable: keep only the live entries.
            phi->ClearOperands();
            for (auto& [val, pred] : valid_pairs) {
              phi->AddOperand(val);
              phi->AddOperand(pred);
            }
          }
        }

        // Uses are rewritten only after the scan over the block is finished;
        // the replaced PHIs are then erased together with the empty ones.
        for (auto& [phi, val] : phi_replacements) {
          phi->ReplaceAllUsesWith(val);
          phi_to_delete.push_back(phi);
        }
        auto& insts = const_cast<InstList&>(bb->GetInstructions());
        auto new_end = std::remove_if(
            insts.begin(), insts.end(),
            [&phi_to_delete](const std::unique_ptr<Instruction>& inst_ptr) {
              return std::find(phi_to_delete.begin(), phi_to_delete.end(),
                               inst_ptr.get()) != phi_to_delete.end();
            });
        insts.erase(new_end, insts.end());
      }

      // Step 3: erase the unreachable blocks.
      size_t old_size = blocks.size();
      blocks.erase(
          std::remove_if(blocks.begin(), blocks.end(),
                         [&reachable](const std::unique_ptr<BasicBlock>& bb_ptr) {
                           return reachable.find(bb_ptr.get()) ==
                                  reachable.end();
                         }),
          blocks.end());
      if (blocks.size() != old_size) {
        changed = true;
      }

      // Step 4: fold conditional branches on a constant condition.
      for (auto& bb_ptr : blocks) {
        auto* bb = bb_ptr.get();
        auto* term = GetTerminator(bb);
        if (!term) continue;

        auto* condbr = dynamic_cast<CondBranchInst*>(term);
        if (!condbr) continue;
        auto* const_int = dynamic_cast<ConstantInt*>(condbr->GetCond());
        if (!const_int) continue;

        // Read everything needed from the branch before it is removed.
        auto* true_target = condbr->GetTrueTarget();
        auto* false_target = condbr->GetFalseTarget();
        const bool takes_true = const_int->GetValue() != 0;
        auto* live_target = takes_true ? true_target : false_target;
        auto* dead_target = takes_true ? false_target : true_target;

        bb->RemoveInstruction(condbr);
        bb->Append<BranchInst>(Type::GetVoidType(), live_target);
        changed = true;

        // `bb` no longer jumps to dead_target, so dead_target's PHIs must
        // forget it. When both edges led to the same block, `bb` is still a
        // predecessor of that block after folding and its PHI entries must be
        // kept.
        if (dead_target != live_target) {
          CleanupPhiReferences(dead_target, bb);
        }
      }
    }

    // Final pass: PHIs in a block with a single predecessor are redundant.
    for (auto& bb_ptr : blocks) {
      auto* bb = bb_ptr.get();
      const auto& preds = bb->GetPredecessors();
      if (preds.size() != 1) continue;

      std::vector<std::pair<PhiInst*, Value*>> phi_replacements;
      for (auto& inst_ptr : bb->GetInstructions()) {
        auto* phi = dynamic_cast<PhiInst*>(inst_ptr.get());
        if (!phi) break;
        if (phi->GetNumOperands() >= 2) {
          // NOTE(review): takes operand 0 without checking that its incoming
          // block is the remaining predecessor — confirm this always holds.
          Value* incoming_val = phi->GetOperand(0);
          phi_replacements.push_back({phi, incoming_val});
        }
      }
      for (auto& [phi, val] : phi_replacements) {
        phi->ReplaceAllUsesWith(val);
      }

      // Every PHI instruction in the block is erased, including any that had
      // fewer than two operands and were therefore not replaced above.
      auto& insts = const_cast<InstList&>(bb->GetInstructions());
      auto new_end = std::remove_if(
          insts.begin(), insts.end(),
          [](const std::unique_ptr<Instruction>& inst_ptr) {
            return dynamic_cast<const PhiInst*>(inst_ptr.get()) != nullptr;
          });
      insts.erase(new_end, insts.end());
    }
  }
}
} // namespace ir