You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/ir/passes/Mem2Reg.cpp

337 lines
9.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

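// Mem2Reg (SSA construction):
// - Promotes the alloca/load/store of local variables into SSA form
// - Inserts PHI nodes and rewrites uses, relying on dominator-tree analyses
//
// Algorithm:
// 1. Identify promotable allocas (scalars accessed only through load/store)
// 2. Compute the dominator tree and the dominance frontiers
// 3. Insert phi nodes at the dominance frontiers
// 4. Rename variables along the dominator tree
// 5. Delete the now-redundant alloca/load/store instructions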
#include "ir/IR.h"
#include <algorithm>
#include <cassert>
#include <functional>
#include <queue>
#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace passes {
// ============ 内联支配树(与 analysis 版本相同) ============
namespace {
class DomTree {
public:
explicit DomTree(Function& func) : func_(func) { Compute(); }
BasicBlock* GetIDom(BasicBlock* bb) const {
auto it = idom_.find(bb);
return it != idom_.end() ? it->second : nullptr;
}
const std::vector<BasicBlock*>& GetDF(BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = df_.find(bb);
return it != df_.end() ? it->second : empty;
}
const std::vector<BasicBlock*>& GetChildren(BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = children_.find(bb);
return it != children_.end() ? it->second : empty;
}
const std::vector<BasicBlock*>& GetRPO() const { return rpo_; }
private:
void Compute() {
auto* entry = func_.GetEntry();
if (!entry) return;
ComputeRPO(entry);
if (rpo_.empty()) return;
for (auto* bb : rpo_) {
idom_[bb] = nullptr;
rpo_index_[bb] = 0;
}
for (size_t i = 0; i < rpo_.size(); ++i) {
rpo_index_[rpo_[i]] = i;
}
idom_[entry] = entry;
bool changed = true;
while (changed) {
changed = false;
for (auto* bb : rpo_) {
if (bb == entry) continue;
BasicBlock* new_idom = nullptr;
for (auto* pred : bb->GetPredecessors()) {
if (idom_.count(pred) && idom_[pred] != nullptr) {
if (!new_idom) {
new_idom = pred;
} else {
new_idom = Intersect(new_idom, pred);
}
}
}
if (new_idom && idom_[bb] != new_idom) {
idom_[bb] = new_idom;
changed = true;
}
}
}
for (auto* bb : rpo_) {
auto* p = GetIDom(bb);
if (p && p != bb) {
children_[p].push_back(bb);
}
}
ComputeDF();
}
void ComputeRPO(BasicBlock* entry) {
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> post_order;
std::function<void(BasicBlock*)> dfs = [&](BasicBlock* bb) {
visited.insert(bb);
for (auto* succ : bb->GetSuccessors()) {
if (!visited.count(succ)) {
dfs(succ);
}
}
post_order.push_back(bb);
};
dfs(entry);
rpo_.assign(post_order.rbegin(), post_order.rend());
}
BasicBlock* Intersect(BasicBlock* b1, BasicBlock* b2) {
while (b1 != b2) {
while (rpo_index_[b1] > rpo_index_[b2]) b1 = idom_[b1];
while (rpo_index_[b2] > rpo_index_[b1]) b2 = idom_[b2];
}
return b1;
}
void ComputeDF() {
for (auto* bb : rpo_) {
df_[bb] = {};
}
for (auto* bb : rpo_) {
if (bb->GetPredecessors().size() < 2) continue;
for (auto* pred : bb->GetPredecessors()) {
auto* runner = pred;
while (runner && runner != idom_[bb]) {
auto& df_set = df_[runner];
if (std::find(df_set.begin(), df_set.end(), bb) == df_set.end()) {
df_set.push_back(bb);
}
if (runner == idom_[runner]) break;
runner = idom_[runner];
}
}
}
}
Function& func_;
std::vector<BasicBlock*> rpo_;
std::unordered_map<BasicBlock*, size_t> rpo_index_;
std::unordered_map<BasicBlock*, BasicBlock*> idom_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> children_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> df_;
};
// Decides whether `alloca` can be promoted to an SSA register:
// - it must not be an array (IsArray() is false), and
// - every use must be a Load, or a Store that uses the alloca purely as the
//   destination pointer.
// Any use by a non-instruction, by another opcode, or as the stored value
// means the address escapes, and the alloca is rejected.
bool IsPromotable(AllocaInst* alloca) {
  if (alloca->IsArray()) return false;
  for (const auto& use : alloca->GetUses()) {
    auto* user = use.GetUser();
    if (!user) return false;
    auto* inst = dynamic_cast<Instruction*>(user);
    if (!inst) return false;

    const auto opcode = inst->GetOpcode();
    if (opcode == Opcode::Load) continue;
    if (opcode != Opcode::Store) return false;

    // A store may only name the alloca as its pointer operand. Checking the
    // pointer alone is not enough: a store whose pointer AND value are both
    // this alloca would pass that test although the address is being stored.
    auto* store = static_cast<StoreInst*>(inst);
    if (store->GetPtr() != alloca || store->GetValue() == alloca) return false;
  }
  return true;
}
} // namespace
// Promotes scalar allocas in the entry block of `func` to SSA values.
//
// For each promotable alloca with an int32 or float32 element type the pass
// inserts phi nodes on the iterated dominance frontier of the blocks that
// store to it, renames loads/stores along the dominator tree, removes the
// alloca together with its loads and stores, and finally folds the phis it
// inserted that ended up with zero or one incoming value.
//
// Returns true only if at least one alloca was actually promoted; external
// functions, functions without an entry block and functions without a
// promotable alloca of a supported type are left untouched and yield false.
bool RunMem2Reg(Function& func) {
  if (func.IsExternal()) return false;
  auto* entry = func.GetEntry();
  if (!entry) return false;

  // 1. Collect the promotable allocas of the entry block together with the
  //    type of the value they hold. Allocas of any other pointer type are
  //    dropped here so that the return value reflects real changes.
  struct Candidate {
    AllocaInst* slot;
    std::shared_ptr<Type> value_type;
  };
  std::vector<Candidate> promotable;
  for (const auto& inst : entry->GetInstructions()) {
    auto* alloca = dynamic_cast<AllocaInst*>(inst.get());
    if (!alloca || !IsPromotable(alloca)) continue;
    if (alloca->GetType()->IsPtrInt32()) {
      promotable.push_back(Candidate{alloca, Type::GetInt32Type()});
    } else if (alloca->GetType()->IsPtrFloat32()) {
      promotable.push_back(Candidate{alloca, Type::GetFloat32Type()});
    }
  }
  if (promotable.empty()) return false;

  // Dominator information is only needed once there is work to do.
  DomTree dom(func);

  // Shared across calls so that generated phi names do not repeat.
  static int phi_counter = 0;

  // Each alloca is promoted independently of the others.
  for (const auto& candidate : promotable) {
    AllocaInst* alloca = candidate.slot;
    std::shared_ptr<Type> val_type = candidate.value_type;

    // 2. Blocks that define the variable, i.e. contain a store to the alloca.
    std::unordered_set<BasicBlock*> def_blocks;
    for (const auto& use : alloca->GetUses()) {
      auto* inst = dynamic_cast<Instruction*>(use.GetUser());
      if (!inst || !inst->GetParent()) continue;
      auto* store = dynamic_cast<StoreInst*>(inst);
      if (store && store->GetPtr() == alloca) {
        def_blocks.insert(store->GetParent());
      }
    }

    // 3. Insert phis on the iterated dominance frontier of the def blocks.
    //    phi_map records the phi created for THIS alloca in each block.
    //    The worklist is seeded in reverse post-order rather than in hash-set
    //    order so that phi numbering is reproducible from run to run.
    std::unordered_map<BasicBlock*, PhiInst*> phi_map;
    std::queue<BasicBlock*> worklist;
    for (auto* bb : dom.GetRPO()) {
      if (def_blocks.count(bb)) worklist.push(bb);
    }
    while (!worklist.empty()) {
      auto* bb = worklist.front();
      worklist.pop();
      for (auto* df_bb : dom.GetDF(bb)) {
        if (phi_map.count(df_bb)) continue;
        auto* phi = df_bb->PrependPhi(val_type,
                                      "%phi." + std::to_string(phi_counter++));
        phi_map[df_bb] = phi;
        worklist.push(df_bb);  // a phi is itself a new definition
      }
    }

    // 4. Rename: depth-first walk over the dominator tree. The top of
    //    val_stack is the value of the variable at the current program point.
    std::stack<Value*> val_stack;
    std::function<void(BasicBlock*)> Rename = [&](BasicBlock* bb) {
      const size_t saved_depth = val_stack.size();

      // A phi inserted for this alloca defines the value on block entry.
      auto phi_it = phi_map.find(bb);
      if (phi_it != phi_map.end()) {
        val_stack.push(phi_it->second);
      }

      // Walk the instructions in order: stores define, loads consume.
      std::vector<Instruction*> to_remove;
      for (auto& inst_ptr : bb->GetInstructions()) {
        auto* inst = inst_ptr.get();
        if (auto* store = dynamic_cast<StoreInst*>(inst)) {
          if (store->GetPtr() == alloca) {
            val_stack.push(store->GetValue());
            to_remove.push_back(store);
          }
        } else if (auto* load = dynamic_cast<LoadInst*>(inst)) {
          if (load->GetPtr() == alloca) {
            // NOTE(review): when no definition reaches this load (a read of
            // an uninitialised variable) it is still deleted below without a
            // replacement, so its users keep referring to it - confirm how
            // such reads should be materialised (e.g. an undef/zero value).
            if (!val_stack.empty() && val_stack.top()) {
              load->ReplaceAllUsesWith(val_stack.top());
            }
            to_remove.push_back(load);
          }
        }
      }

      // Feed the value live at the end of this block into successor phis.
      for (auto* succ : bb->GetSuccessors()) {
        auto succ_phi_it = phi_map.find(succ);
        if (succ_phi_it == phi_map.end()) continue;
        Value* cur_val = val_stack.empty() ? nullptr : val_stack.top();
        if (cur_val) {
          succ_phi_it->second->AddIncoming(cur_val, bb);
        }
      }

      // Recurse into the blocks this one immediately dominates.
      for (auto* child : dom.GetChildren(bb)) {
        Rename(child);
      }

      // Drop the definitions introduced by this block.
      while (val_stack.size() > saved_depth) {
        val_stack.pop();
      }

      // The loads/stores are removed only after the whole subtree has been
      // renamed.
      for (auto* inst : to_remove) {
        bb->RemoveInstruction(inst);
      }
    };
    Rename(entry);

    // 5. The stack slot itself is no longer needed.
    entry->RemoveInstruction(alloca);

    // 6. Fold the trivial phis inserted for this alloca: a phi without
    //    incoming values is dropped, a phi with exactly one incoming value is
    //    replaced by that value. Only phi_map is inspected, so unrelated phis
    //    are not rescanned for every alloca; blocks are visited in reverse
    //    post-order.
    for (auto* bb : dom.GetRPO()) {
      auto it = phi_map.find(bb);
      if (it == phi_map.end()) continue;
      PhiInst* phi = it->second;
      const auto incoming = phi->GetNumIncoming();
      if (incoming > 1) continue;
      if (incoming == 1) {
        phi->ReplaceAllUsesWith(phi->GetIncomingValue(0));
      }
      bb->RemoveInstruction(phi);
    }
  }
  return true;
}
} // namespace passes
} // namespace ir