You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

196 lines
6.2 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// Mem2Reg (SSA construction):
// - Promotes the alloca/load/store of local variables into SSA form
// - Inserts PHI nodes and rewrites uses; all promotable allocas are handled in a single renaming pass
#include "ir/IR.h"
#include <functional>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
/// Decides whether `alloca` can be rewritten into SSA values.
///
/// An alloca is promotable when it is not an array and every one of its uses
/// is either a load reading from it or a store writing *into* it.
///
/// @param alloca Stack slot to inspect.
/// @return true if the slot may be promoted, false otherwise.
bool IsPromotable(AllocaInst* alloca) {
  if (alloca->IsArray()) return false;
  for (const auto& use : alloca->GetUses()) {
    auto* user = use.GetUser();
    if (auto* load = dynamic_cast<LoadInst*>(user)) {
      if (load->GetPtr() != alloca) return false;
    } else if (auto* store = dynamic_cast<StoreInst*>(user)) {
      // Only stores INTO the slot are allowed. A store OF the slot's address
      // (even one whose destination is the slot itself) lets the address
      // escape, so the slot must stay in memory.
      if (store->GetPtr() != alloca || store->GetValue() == alloca) {
        return false;
      }
    } else {
      // Any other kind of user needs the slot's address to exist.
      return false;
    }
  }
  return true;
}
/// Appends the stores and loads that use `alloca` to the output vectors.
///
/// A store is recorded only for the use whose operand index is 1; every load
/// that uses the alloca is recorded. Other users are ignored.
///
/// @param alloca Stack slot whose use list is scanned.
/// @param stores Receives the matching stores (appended, not cleared).
/// @param loads  Receives the matching loads (appended, not cleared).
void CollectStoresAndLoads(AllocaInst* alloca,
                           std::vector<StoreInst*>& stores,
                           std::vector<LoadInst*>& loads) {
  for (const auto& use : alloca->GetUses()) {
    auto* user = use.GetUser();

    if (auto* as_store = dynamic_cast<StoreInst*>(user)) {
      const bool uses_operand_one = use.GetOperandIndex() == 1;
      if (uses_operand_one) {
        stores.push_back(as_store);
      }
      continue;
    }

    if (auto* as_load = dynamic_cast<LoadInst*>(user)) {
      loads.push_back(as_load);
    }
  }
}
/// Computes the iterated dominance frontier of `def_blocks`.
///
/// Starting from the defining blocks, the dominance frontier of every block
/// on the worklist is added to the result, and each newly discovered frontier
/// block is itself queued until no new blocks appear.
///
/// @param def_blocks Blocks that contain a definition.
/// @param dt         Dominator tree queried through GetDominanceFrontier.
/// @return The set of blocks in the iterated dominance frontier.
std::set<BasicBlock*> ComputeIDF(const std::set<BasicBlock*>& def_blocks,
                                 DominatorTree& dt) {
  std::set<BasicBlock*> frontier_closure;
  // Blocks that have been queued at some point; seeded with the definitions
  // so that a defining block is never queued a second time.
  std::set<BasicBlock*> enqueued = def_blocks;
  std::vector<BasicBlock*> pending(def_blocks.begin(), def_blocks.end());

  while (!pending.empty()) {
    BasicBlock* current = pending.back();
    pending.pop_back();

    for (auto* candidate : dt.GetDominanceFrontier(current)) {
      const bool newly_in_frontier = frontier_closure.insert(candidate).second;
      if (!newly_in_frontier) continue;

      const bool first_time_enqueued = enqueued.insert(candidate).second;
      if (first_time_enqueued) {
        pending.push_back(candidate);
      }
    }
  }

  return frontier_closure;
}
/// Per-alloca bookkeeping used while promoting a stack slot to SSA form.
struct AllocaInfo {
  /// The stack slot being promoted.
  AllocaInst* alloca{nullptr};
  /// Stores recorded for the slot.
  std::vector<StoreInst*> stores;
  /// Loads recorded for the slot.
  std::vector<LoadInst*> loads;
  /// PHI node created for this slot, keyed by the block that holds it.
  std::unordered_map<BasicBlock*, PhiInst*> phis;
  /// Stack of values for the slot during renaming; the back is the current one.
  std::vector<Value*> value_stack;
  /// Placeholder value used when no store has been seen yet.
  Value* undef_val{nullptr};
};
} // namespace
/// Promotes every promotable alloca in `func` to SSA form.
///
/// Steps:
///   1. Collect the allocas accepted by IsPromotable.
///   2. For each one, gather its loads/stores and prepend a PHI node to every
///      block in the iterated dominance frontier of its storing blocks.
///   3. Walk the dominator tree once from the entry block, keeping one value
///      stack per alloca: stores push their value, loads are replaced by the
///      value on top, and PHIs in successor blocks receive the value on top
///      as their incoming value for the current block.
///   4. Detach and remove the promoted loads, stores and allocas.
///
/// A read that no store reaches yields a zero constant (int 0 or float 0.0f).
///
/// @param func Function rewritten in place.
/// @param ctx  Context that supplies the zero constants.
/// @return true if at least one alloca was promoted; false if nothing was
///         promotable, in which case the function is not modified.
bool RunMem2Reg(Function& func, Context& ctx) {
  // Collect every alloca that can be promoted.
  std::vector<AllocaInst*> promotable;
  for (auto& bb : func.GetBlocks()) {
    for (auto& inst : bb->GetInstructions()) {
      if (auto* alloca = dynamic_cast<AllocaInst*>(inst.get())) {
        if (IsPromotable(alloca)) promotable.push_back(alloca);
      }
    }
  }
  if (promotable.empty()) return false;

  // Dominance information is only needed once there is something to promote,
  // so it is computed after the early return above.
  DominatorTree dt;
  dt.Compute(func);

  // Build the bookkeeping record for every promotable alloca.
  std::vector<AllocaInfo> infos(promotable.size());
  std::unordered_map<StoreInst*, size_t> store_to_info;
  std::unordered_map<LoadInst*, size_t> load_to_info;
  for (size_t i = 0; i < promotable.size(); ++i) {
    auto* alloca = promotable[i];
    auto& info = infos[i];
    info.alloca = alloca;
    CollectStoresAndLoads(alloca, info.stores, info.loads);

    std::set<BasicBlock*> def_blocks;
    for (auto* s : info.stores) def_blocks.insert(s->GetParent());

    // Only float32 and int32 slots are distinguished here.
    const bool is_float = alloca->GetType()->IsPtrFloat32();
    auto val_type = is_float ? Type::GetFloat32Type() : Type::GetInt32Type();
    info.undef_val = is_float ? static_cast<Value*>(ctx.GetConstFloat(0.0f))
                              : static_cast<Value*>(ctx.GetConstInt(0));
    // The bottom of the stack is the placeholder, so back() is always valid.
    info.value_stack.push_back(info.undef_val);

    // Insert PHI nodes on the iterated dominance frontier.
    auto df_plus = ComputeIDF(def_blocks, dt);
    for (auto* bb : df_plus) {
      auto* phi = bb->Prepend<PhiInst>(val_type, "");
      info.phis[bb] = phi;
    }

    // Fast lookup from an instruction to the record of the alloca it touches.
    for (auto* s : info.stores) store_to_info[s] = i;
    for (auto* l : info.loads) load_to_info[l] = i;
  }

  // Single renaming pass: a DFS over the dominator tree that handles all
  // allocas at the same time.
  std::function<void(BasicBlock*)> rename = [&](BasicBlock* bb) {
    // Remember every stack height so it can be restored when leaving `bb`,
    // and make this block's PHIs the current value of their alloca.
    std::vector<size_t> saved_sizes(infos.size());
    for (size_t i = 0; i < infos.size(); ++i) {
      saved_sizes[i] = infos[i].value_stack.size();
      auto phi_it = infos[i].phis.find(bb);
      if (phi_it != infos[i].phis.end()) {
        infos[i].value_stack.push_back(phi_it->second);
      }
    }

    // Process the instructions of the block in order.
    for (auto& inst_up : bb->GetInstructions()) {
      auto* inst = inst_up.get();
      // PHI nodes were already pushed onto the stacks above.
      if (dynamic_cast<PhiInst*>(inst)) continue;
      if (auto* store = dynamic_cast<StoreInst*>(inst)) {
        auto it = store_to_info.find(store);
        if (it != store_to_info.end()) {
          infos[it->second].value_stack.push_back(store->GetValue());
        }
      } else if (auto* load = dynamic_cast<LoadInst*>(inst)) {
        auto it = load_to_info.find(load);
        if (it != load_to_info.end()) {
          load->ReplaceAllUsesWith(infos[it->second].value_stack.back());
        }
      }
    }

    // Fill in the incoming values of the PHI nodes in successor blocks.
    for (auto* succ : bb->GetSuccessors()) {
      for (size_t i = 0; i < infos.size(); ++i) {
        auto phi_it = infos[i].phis.find(succ);
        if (phi_it != infos[i].phis.end()) {
          phi_it->second->AddIncoming(infos[i].value_stack.back(), bb);
        }
      }
    }

    // Recurse into the dominator-tree children.
    for (auto* child : dt.GetChildren(bb)) rename(child);

    // Restore the stacks to their height on entry.
    for (size_t i = 0; i < infos.size(); ++i) {
      infos[i].value_stack.resize(saved_sizes[i]);
    }
  };
  rename(func.GetEntry());

  // Remove the promoted loads, stores and allocas. The use-def links must be
  // severed before removal; otherwise other values' use lists would keep
  // dangling pointers.
  for (auto& info : infos) {
    for (auto* ld : info.loads) {
      // A load in a block the walk above never reached still has its original
      // users; redirect them to the placeholder so none of them refers to a
      // removed instruction. For loads already rewritten this does nothing.
      // NOTE(review): PHIs with a predecessor the walk never reached get no
      // incoming value for that edge - confirm whether that can occur here.
      ld->ReplaceAllUsesWith(info.undef_val);
      ld->SetOperand(0, nullptr);  // drop the reference to the alloca
      ld->RemoveFromParent();
    }
    for (auto* st : info.stores) {
      st->SetOperand(0, nullptr);  // drop the reference to the stored value
      st->SetOperand(1, nullptr);  // drop the reference to the alloca
      st->RemoveFromParent();
    }
    info.alloca->RemoveFromParent();
  }

  // Purge the marked instructions in bulk: one O(n) sweep per block instead
  // of an O(n) erase per removed instruction.
  for (auto& bb : func.GetBlocks()) {
    bb->SweepDeadInstructions();
  }
  return true;
}
} // namespace ir