You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/mir/RegAlloc.cpp

691 lines
20 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "mir/MIR.h"
#include <algorithm>
#include <map>
#include <set>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include <queue>
#include <cmath>
#include "utils/Log.h"
namespace mir {
namespace {
// ========== Virtual-register classes ==========
// Register class of a virtual register; it selects the physical bank (GPR or
// FPR) and the register view (Wn / Xn / Sn) the vreg is assigned.
enum class VRegClass { kInt32, kInt64, kFloat32 };
// ========== Live intervals ==========
// Linear live range of one vreg, expressed in global instruction indices.
struct LiveInterval {
  int vreg;             // virtual register number
  int start;            // lowest index covered by the range
  int end;              // highest index covered (interference test treats it inclusively)
  VRegClass reg_class;  // class of the vreg

  LiveInterval(int vreg_id, int first, int last, VRegClass cls)
      : vreg{vreg_id}, start{first}, end{last}, reg_class{cls} {}
};
// ========== Physical register pools ==========
// Allocatable GPRs: X19-X28 / W19-W28. Xn and Wn are the 64-bit and 32-bit
// views of one physical register, so Int32 and Int64 vregs compete for the
// same 10 registers.
const PhysReg kGPRPool[] = {
    PhysReg::X19, PhysReg::X20, PhysReg::X21, PhysReg::X22, PhysReg::X23,
    PhysReg::X24, PhysReg::X25, PhysReg::X26, PhysReg::X27, PhysReg::X28,
};
constexpr int kNumGPR = static_cast<int>(sizeof kGPRPool / sizeof kGPRPool[0]);
// Returns the 32-bit W view of the 64-bit X register `xreg`.
// NOTE(review): relies on PhysReg enumerating X0.. and W0.. in the same
// relative order, so the offset from X0 equals the offset from W0.
PhysReg ToWReg(PhysReg xreg) {
  const int offset = static_cast<int>(xreg) - static_cast<int>(PhysReg::X0);
  const int wBase = static_cast<int>(PhysReg::W0);
  return static_cast<PhysReg>(wBase + offset);
}
// Allocatable FPRs for Float32 vregs: S8-S13 (six colours). S14/S15 are not
// in this pool; they are used as spill scratch registers (kSpillScratchFloat).
const PhysReg kFPR32Pool[] = {
    PhysReg::S8,  PhysReg::S9,  PhysReg::S10,
    PhysReg::S11, PhysReg::S12, PhysReg::S13,
};
constexpr int kNumFPR32 =
    static_cast<int>(sizeof kFPR32Pool / sizeof kFPR32Pool[0]);
// Spill scratch registers: two per class, used to reload/store spilled vregs
// around a single instruction. X14/X15 and W14/W15 are views of the same
// physical registers.
// NOTE(review): under AAPCS64, S14/S15 (low halves of v14/v15) are
// callee-saved, unlike X14/X15.
const PhysReg kSpillScratchInt32[] = { PhysReg::W15, PhysReg::W14 };
const PhysReg kSpillScratchInt64[] = { PhysReg::X15, PhysReg::X14 };
const PhysReg kSpillScratchFloat[] = { PhysReg::S15, PhysReg::S14 };
// Returns scratch register number `idx % 2` of class `rc`. An unrecognised
// class value yields the first Int32 scratch register.
PhysReg GetSpillScratch(VRegClass rc, int idx) {
  const int which = idx % 2;
  if (rc == VRegClass::kFloat32) return kSpillScratchFloat[which];
  if (rc == VRegClass::kInt64) return kSpillScratchInt64[which];
  if (rc == VRegClass::kInt32) return kSpillScratchInt32[which];
  return kSpillScratchInt32[0];
}
// Returns true if `reg` must be preserved across calls by the callee under
// AAPCS64, the AArch64 procedure call standard. That covers:
//   * X19-X28 and their 32-bit W views, and
//   * the low 64 bits of V8-V15, i.e. S8-S15.
// V16-V31 (S16-S31) are caller-saved in AAPCS64, so they are deliberately
// not reported. Reporting them would only make the frame code save registers
// the ABI does not require.
// X29/X30 (FP/LR) are not handled here.
bool IsCalleeSaved(PhysReg reg) {
  switch (reg) {
    case PhysReg::W19: case PhysReg::W20: case PhysReg::W21: case PhysReg::W22:
    case PhysReg::W23: case PhysReg::W24: case PhysReg::W25: case PhysReg::W26:
    case PhysReg::W27: case PhysReg::W28:
    case PhysReg::X19: case PhysReg::X20: case PhysReg::X21: case PhysReg::X22:
    case PhysReg::X23: case PhysReg::X24: case PhysReg::X25: case PhysReg::X26:
    case PhysReg::X27: case PhysReg::X28:
    case PhysReg::S8:  case PhysReg::S9:  case PhysReg::S10: case PhysReg::S11:
    case PhysReg::S12: case PhysReg::S13: case PhysReg::S14: case PhysReg::S15:
      return true;
    default:
      return false;
  }
}
// Maps an allocatable GPR, in either view (W19-W28 or X19-X28), to its colour
// index 0-9. Any other register yields -1.
int GetGPRIndex(PhysReg reg) {
  const int raw = static_cast<int>(reg);
  const bool isWView = PhysReg::W19 <= reg && reg <= PhysReg::W28;
  const bool isXView = PhysReg::X19 <= reg && reg <= PhysReg::X28;
  if (isWView) return raw - static_cast<int>(PhysReg::W19);
  if (isXView) return raw - static_cast<int>(PhysReg::X19);
  return -1;
}
// Maps an allocatable FPR (S8-S13) to its colour index 0-5. Any other
// register yields -1.
int GetFPRIndex(PhysReg reg) {
  const bool inPool = PhysReg::S8 <= reg && reg <= PhysReg::S13;
  return inPool ? static_cast<int>(reg) - static_cast<int>(PhysReg::S8) : -1;
}
// ========== VReg class inference ==========
// Translates the type that `function` records for `vreg` into a VRegClass.
// Vregs with no recorded type, or an unrecognised one, are treated as kInt32.
VRegClass InferVRegClass(int vreg, MachineFunction& function) {
  if (!function.HasVRegType(vreg)) return VRegClass::kInt32;
  const auto recorded = function.GetVRegType(vreg);
  if (recorded == MachineFunction::VRegType::kFloat32) return VRegClass::kFloat32;
  if (recorded == MachineFunction::VRegType::kInt64) return VRegClass::kInt64;
  return VRegClass::kInt32;
}
// ========== Instruction numbering ==========
// Gives every instruction a consecutive global index, walking blocks in
// layout order.
//   instrToIdx / idxToInstr: the numbering in both directions.
//   blockBoundary: index at which each block starts. An empty block has the
//                  same start index as the following block, whose entry
//                  then overwrites it.
void NumberInstructions(MachineFunction& function,
                        std::unordered_map<MachineInstr*, int>& instrToIdx,
                        std::vector<MachineInstr*>& idxToInstr,
                        std::map<int, MachineBasicBlock*>& blockBoundary) {
  int next = 0;
  for (auto& block : function.GetBasicBlocks()) {
    MachineBasicBlock* raw = block.get();
    blockBoundary[next] = raw;
    for (auto& mi : raw->GetInstructions()) {
      instrToIdx[&mi] = next;
      idxToInstr.push_back(&mi);
      next++;
    }
  }
}
// ========== Liveness analysis and live intervals ==========
// Numbers every instruction, runs backward liveness dataflow over the CFG,
// and derives one conservative linear live interval per vreg.
//
// Outputs filled in place:
//   liveIn / liveOut        - per-block sets of vregs live on entry / exit
//                             (an entry is created for every block).
//   instrToIdx / idxToInstr - global instruction numbering.
// Returns one LiveInterval per vreg that appears in any def/use list, in
// ascending vreg order. Each interval covers:
//   * all positions where the vreg is defined or used, and
//   * the whole index range of every non-empty block in which the vreg is
//     live-in or live-out.
std::vector<LiveInterval> ComputeLiveIntervals(
    MachineFunction& function,
    std::unordered_map<MachineBasicBlock*, std::set<int>>& liveIn,
    std::unordered_map<MachineBasicBlock*, std::set<int>>& liveOut,
    std::unordered_map<MachineInstr*, int>& instrToIdx,
    std::vector<MachineInstr*>& idxToInstr) {
  const auto& blocks = function.GetBasicBlocks();
  if (blocks.empty()) return {};
  std::map<int, MachineBasicBlock*> blockBoundary;
  NumberInstructions(function, instrToIdx, idxToInstr, blockBoundary);

  // Every vreg that is defined or used anywhere. std::set keeps the output
  // order deterministic.
  std::set<int> allVRegs;
  for (auto* inst : idxToInstr) {
    for (int d : inst->GetDefs()) allVRegs.insert(d);
    for (int u : inst->GetUses()) allVRegs.insert(u);
  }

  // Instruction indices at which each vreg is defined or used.
  std::unordered_map<int, std::set<int>> vregPositions;

  // Local per-block summary used by the dataflow equations.
  struct BlockInfo {
    std::set<int> use;  // upward-exposed uses (read before any local def)
    std::set<int> def;  // vregs defined anywhere in the block
    int startIdx = 0;   // index of the first instruction
    int endIdx = 0;     // one past the last instruction; == startIdx if empty
  };
  std::unordered_map<MachineBasicBlock*, BlockInfo> blockInfo;
  for (const auto& bb : blocks) {
    auto& info = blockInfo[bb.get()];
    auto& insts = bb->GetInstructions();
    if (!insts.empty()) {
      info.startIdx = instrToIdx.at(&insts.front());
      info.endIdx = instrToIdx.at(&insts.back()) + 1;
    }
    for (auto& inst : insts) {
      const int pos = instrToIdx.at(&inst);
      // Uses must be examined BEFORE this instruction's own defs, because
      // an instruction reads its operands before writing its results. In
      // `v = v + 1`, the use of v is upward-exposed unless v was defined
      // earlier in the block. Inserting the def first would hide it and
      // drop v from liveIn.
      for (int use : inst.GetUses()) {
        if (info.def.count(use) == 0) info.use.insert(use);
        vregPositions[use].insert(pos);
      }
      for (int def : inst.GetDefs()) {
        info.def.insert(def);
        vregPositions[def].insert(pos);
      }
    }
  }

  // Backward liveness, iterated to a fixed point:
  //   liveOut(b) = union of liveIn(s) over successors s
  //   liveIn(b)  = use(b) U (liveOut(b) - def(b))
  bool changed = true;
  while (changed) {
    changed = false;
    for (auto it = blocks.rbegin(); it != blocks.rend(); ++it) {
      MachineBasicBlock* bb = it->get();
      const auto& info = blockInfo[bb];
      std::set<int> newLiveOut;
      for (auto* succ : bb->GetSuccessors()) {
        const auto& succIn = liveIn[succ];
        newLiveOut.insert(succIn.begin(), succIn.end());
      }
      std::set<int> newLiveIn = info.use;
      for (int v : newLiveOut) {
        if (info.def.count(v) == 0) newLiveIn.insert(v);
      }
      if (newLiveOut != liveOut[bb]) {
        liveOut[bb] = std::move(newLiveOut);
        changed = true;
      }
      if (newLiveIn != liveIn[bb]) {
        liveIn[bb] = std::move(newLiveIn);
        changed = true;
      }
    }
  }

  // Build the intervals.
  std::vector<LiveInterval> intervals;
  intervals.reserve(allVRegs.size());
  for (int vreg : allVRegs) {
    auto posIt = vregPositions.find(vreg);
    if (posIt == vregPositions.end() || posIt->second.empty()) continue;
    int start = *posIt->second.begin();
    int end = *posIt->second.rbegin();
    for (const auto& bb : blocks) {
      const auto& info = blockInfo[bb.get()];
      if (info.startIdx == info.endIdx) continue;  // empty block
      const bool isLiveIn = liveIn[bb.get()].count(vreg) != 0;
      const bool isLiveOut = liveOut[bb.get()].count(vreg) != 0;
      if (isLiveIn || isLiveOut) {
        start = std::min(start, info.startIdx);
        end = std::max(end, info.endIdx);
      }
    }
    intervals.emplace_back(vreg, start, end, InferVRegClass(vreg, function));
  }
  return intervals;
}
// ========== Graph-colouring data structures ==========
// One node of the interference graph. Every scalar member has a default
// initializer, so a default-constructed node (for example one created by
// unordered_map::operator[]) is never read uninitialized.
struct IGNode {
  int vreg = -1;                            // vreg represented by this node
  VRegClass reg_class = VRegClass::kInt32;  // selects the Wn/Xn/Sn view when colouring
  std::set<int> neighbors;                  // vregs whose live ranges overlap this one
  int degree = 0;                           // neighbours still in the graph (decremented by Simplify)
  bool removed = false;                     // already pushed onto the simplify stack
  bool is_spill_candidate = false;
  double spill_cost = 0.0;                  // lower = preferred victim when Simplify is blocked
};
// Entry of the simplify stack. is_spill_candidate marks nodes that were
// removed while every remaining node had degree >= k.
struct StackEntry {
  int vreg = -1;
  bool is_spill_candidate = false;
};
// Interference graph for one register bank:
//   GPR bank: Int32 + Int64 vregs (they share the physical GPRs)
//   FPR bank: Float32 vregs
struct InterferenceGraph {
  std::unordered_map<int, IGNode> nodes;
  std::set<int> remaining;  // vregs not yet removed by Simplify
  int k = 0;                // number of available colours
  bool is_gpr = false;      // true = GPR bank (Int32+Int64), false = FPR bank
};
// ========== Occurrence counts ==========
// For each vreg, counts how many times it appears across all instructions'
// use lists plus def lists. BuildInterferenceGraph uses this as the raw
// spill cost.
std::unordered_map<int, int> ComputeUseCounts(MachineFunction& function) {
  std::unordered_map<int, int> counts;
  auto tally = [&counts](const auto& vregs) {
    for (int v : vregs) ++counts[v];
  };
  for (auto& block : function.GetBasicBlocks()) {
    for (auto& mi : block->GetInstructions()) {
      tally(mi.GetUses());
      tally(mi.GetDefs());
    }
  }
  return counts;
}
// ========== Interference graph construction ==========
// Builds the interference graph for one register bank.
//   buildGPR == true : Int32 and Int64 vregs share one graph, because Wn and
//                      Xn name the same physical register.
//   buildGPR == false: Float32 vregs only.
// Two vregs interfere when their live intervals overlap, endpoints included.
// Node spill cost is (#uses + #defs) / degree, or the raw count when the
// node has no neighbours; vregs with no recorded count get cost 1.0.
//
// liveIn and instrToIdx are kept for interface compatibility. Interference
// is derived purely from `intervals`, which are assumed to hold one entry per
// vreg, as ComputeLiveIntervals produces.
InterferenceGraph BuildInterferenceGraph(
    MachineFunction& function,
    const std::vector<LiveInterval>& intervals,
    [[maybe_unused]] const std::unordered_map<MachineBasicBlock*, std::set<int>>& liveIn,
    [[maybe_unused]] const std::unordered_map<MachineInstr*, int>& instrToIdx,
    bool buildGPR) {
  InterferenceGraph ig;
  ig.is_gpr = buildGPR;
  ig.k = buildGPR ? kNumGPR : kNumFPR32;

  auto inBank = [buildGPR](VRegClass rc) {
    if (buildGPR) return rc == VRegClass::kInt32 || rc == VRegClass::kInt64;
    return rc == VRegClass::kFloat32;
  };

  // vreg -> its interval, for this bank only. The lookup is built once, so
  // the pairwise loop below no longer rescans `intervals` for every pair.
  // std::map keeps the vreg order deterministic.
  std::map<int, const LiveInterval*> bank;
  for (const auto& iv : intervals) {
    if (inBank(iv.reg_class)) bank.emplace(iv.vreg, &iv);
  }
  if (bank.empty()) return ig;

  // Create the nodes.
  const auto useCounts = ComputeUseCounts(function);
  for (const auto& [v, iv] : bank) {
    IGNode node;
    node.vreg = v;
    node.reg_class = iv->reg_class;
    node.degree = 0;
    node.removed = false;
    node.is_spill_candidate = false;
    auto cnt = useCounts.find(v);
    node.spill_cost = cnt != useCounts.end() ? cnt->second : 1.0;
    ig.nodes[v] = std::move(node);
    ig.remaining.insert(v);
  }

  // Interference edges: [s1,e1] and [s2,e2] overlap iff
  // max(s1,s2) <= min(e1,e2).
  for (auto a = bank.begin(); a != bank.end(); ++a) {
    auto b = a;
    for (++b; b != bank.end(); ++b) {
      const LiveInterval& ivA = *a->second;
      const LiveInterval& ivB = *b->second;
      if (std::max(ivA.start, ivB.start) <= std::min(ivA.end, ivB.end)) {
        ig.nodes[a->first].neighbors.insert(b->first);
        ig.nodes[b->first].neighbors.insert(a->first);
      }
    }
  }

  // Degree and degree-normalised spill cost.
  for (auto& [vreg, node] : ig.nodes) {
    node.degree = static_cast<int>(node.neighbors.size());
    if (node.degree > 0) node.spill_cost /= node.degree;
  }
  return ig;
}
// ========== Simplify ==========
// Repeatedly removes a node whose current degree is < k. When no such node
// is queued, it removes the remaining node with the lowest spill_cost and
// flags it as a spill candidate (optimistic colouring).
// Returns the removal order; SelectColors consumes it in reverse.
// Mutates `ig`: degree, removed and remaining.
std::vector<StackEntry> Simplify(InterferenceGraph& ig) {
  std::vector<StackEntry> order;
  std::queue<int> lowDegree;

  // Detaches `v` from the graph, records it, and queues every still-present
  // neighbour whose degree drops below k.
  auto detach = [&](int v, bool spillCandidate) {
    order.push_back({v, spillCandidate});
    IGNode& node = ig.nodes[v];
    node.removed = true;
    ig.remaining.erase(v);
    for (int u : node.neighbors) {
      if (ig.remaining.count(u) == 0) continue;
      IGNode& other = ig.nodes[u];
      if (other.removed) continue;
      if (--other.degree < ig.k) lowDegree.push(u);
    }
  };

  for (int v : ig.remaining) {
    if (ig.nodes[v].degree < ig.k) lowDegree.push(v);
  }

  while (!ig.remaining.empty()) {
    if (lowDegree.empty()) {
      // Nothing trivially colourable is queued: pick the cheapest victim.
      int victim = -1;
      double cheapest = 1e300;
      for (int v : ig.remaining) {
        const IGNode& candidate = ig.nodes[v];
        if (!candidate.removed && candidate.spill_cost < cheapest) {
          cheapest = candidate.spill_cost;
          victim = v;
        }
      }
      if (victim < 0) break;
      detach(victim, true);
      continue;
    }
    const int v = lowDegree.front();
    lowDegree.pop();
    // Skip stale queue entries.
    if (ig.remaining.count(v) == 0 || ig.nodes[v].removed) continue;
    if (ig.nodes[v].degree >= ig.k) continue;
    detach(v, false);
  }
  return order;
}
// ========== Select阶段 ==========
// 关键修复颜色用GPR索引0-9表示分配时根据vreg类型选择Xn或Wn
std::pair<std::unordered_map<int, PhysReg>, std::set<int>> SelectColors(
const InterferenceGraph& origIg,
const std::vector<StackEntry>& stack,
const std::vector<LiveInterval>& intervals,
MachineFunction& function) {
std::unordered_map<int, PhysReg> coloring;
std::set<int> actualSpills;
// 跟踪已分配的GPR索引0-9
std::unordered_map<int, int> colorToVReg;
// 逆序弹出栈
for (int i = static_cast<int>(stack.size()) - 1; i >= 0; --i) {
const auto& entry = stack[i];
int vreg = entry.vreg;
const auto& node = origIg.nodes.at(vreg);
// 收集已着色邻居使用的颜色GPR索引
std::set<int> usedColors;
for (int neighbor : node.neighbors) {
auto it = coloring.find(neighbor);
if (it != coloring.end()) {
int colorIdx = -1;
if (origIg.is_gpr) {
colorIdx = GetGPRIndex(it->second);
} else {
colorIdx = GetFPRIndex(it->second);
}
if (colorIdx >= 0) {
usedColors.insert(colorIdx);
}
}
}
// 找第一个可用颜色
int chosenColor = -1;
for (int c = 0; c < origIg.k; ++c) {
if (!usedColors.count(c)) {
chosenColor = c;
break;
}
}
if (chosenColor >= 0) {
// 根据vreg类型选择物理寄存器
PhysReg physReg;
if (origIg.is_gpr) {
if (node.reg_class == VRegClass::kInt64) {
physReg = kGPRPool[chosenColor]; // Xn
} else {
physReg = ToWReg(kGPRPool[chosenColor]); // Wn
}
} else {
physReg = kFPR32Pool[chosenColor]; // Sn
}
coloring[vreg] = physReg;
if (IsCalleeSaved(physReg)) {
function.MarkCalleeSaved(physReg);
}
} else {
actualSpills.insert(vreg);
}
}
return {coloring, actualSpills};
}
// ========== Instruction rewriting ==========
// Replaces every VReg operand with a physical register and materialises
// spills:
//   * a coloured vreg becomes its assigned register (vregToPhys);
//   * a spilled vreg becomes a spill-scratch register. A LoadStack is
//     emitted before the instruction for each spilled use, and a StoreStack
//     after it for each spilled def.
//
// Each spilled vreg of an instruction gets ONE scratch register, used
// consistently for its reload, its operand(s) and its store. Previously a
// def that was also the second spilled use (e.g. `sub v, w, v`) was written
// to one scratch register but stored from another.
//
// Scratch indices are assigned per bank (GPR = Int32/Int64, FPR = Float32):
//   * spilled uses get distinct indices;
//   * a spilled def that is not also a use gets the lowest index not taken
//     by another def of the same bank. It may share a register with a use,
//     because operands are read before the result is written.
// NOTE: GetSpillScratch offers two registers per bank, so more than two
// spilled uses (or defs) of one bank in a single instruction still wrap
// around and collide.
//
// Scratch registers that IsCalleeSaved reports (S14/S15) are passed to
// function.MarkCalleeSaved, so using them does not silently clobber a
// caller's callee-saved state.
void RewriteInstructions(
    MachineFunction& function,
    const std::unordered_map<int, PhysReg>& vregToPhys,
    const std::unordered_map<int, int>& spillSlots,
    const std::vector<LiveInterval>& intervals) {
  // Register class per vreg. The first interval for a vreg wins; vregs
  // without an interval are treated as kInt32.
  std::unordered_map<int, VRegClass> classOf;
  for (const auto& iv : intervals) classOf.emplace(iv.vreg, iv.reg_class);
  auto classOfVReg = [&classOf](int vreg) {
    auto it = classOf.find(vreg);
    return it != classOf.end() ? it->second : VRegClass::kInt32;
  };
  // Scratch bank: 0 = GPR (Wn/Xn share registers), 1 = FPR.
  auto bankOf = [&classOfVReg](int vreg) {
    return classOfVReg(vreg) == VRegClass::kFloat32 ? 1 : 0;
  };
  // Spilled vregs in `vregs`, deduplicated, in first-occurrence order.
  auto uniqueSpilled = [&spillSlots](const std::vector<int>& vregs) {
    std::vector<int> out;
    std::set<int> seen;
    for (int vreg : vregs) {
      if (spillSlots.count(vreg) != 0 && seen.insert(vreg).second) {
        out.push_back(vreg);
      }
    }
    return out;
  };

  for (auto& bb : function.GetBasicBlocks()) {
    auto& insts = bb->GetInstructions();
    std::vector<MachineInstr> newInsts;
    newInsts.reserve(insts.size());
    for (auto& inst : insts) {
      const std::vector<int> spilledUses = uniqueSpilled(inst.GetUses());
      const std::vector<int> spilledDefs = uniqueSpilled(inst.GetDefs());

      // Assign one scratch index per spilled vreg of this instruction.
      std::unordered_map<int, int> scratchIdx;
      int nextUseIdx[2] = {0, 0};
      for (int vreg : spilledUses) {
        scratchIdx[vreg] = nextUseIdx[bankOf(vreg)]++;
      }
      std::set<std::pair<int, int>> defTaken;  // (bank, index) held by defs
      for (int vreg : spilledDefs) {
        auto it = scratchIdx.find(vreg);
        if (it != scratchIdx.end()) defTaken.insert({bankOf(vreg), it->second});
      }
      for (int vreg : spilledDefs) {
        if (scratchIdx.count(vreg) != 0) continue;  // also a use: keep its register
        const int bank = bankOf(vreg);
        int idx = 0;
        while (defTaken.count({bank, idx}) != 0) ++idx;
        defTaken.insert({bank, idx});
        scratchIdx[vreg] = idx;
      }

      // Physical register standing in for `vreg` in this instruction.
      auto regFor = [&](int vreg) -> PhysReg {
        auto phys = vregToPhys.find(vreg);
        if (phys != vregToPhys.end()) return phys->second;
        auto idx = scratchIdx.find(vreg);
        const PhysReg scratch = GetSpillScratch(
            classOfVReg(vreg), idx != scratchIdx.end() ? idx->second : 0);
        if (IsCalleeSaved(scratch)) function.MarkCalleeSaved(scratch);
        return scratch;
      };

      // Reload spilled uses.
      for (int vreg : spilledUses) {
        newInsts.emplace_back(
            Opcode::LoadStack,
            std::vector<Operand>{Operand::Reg(regFor(vreg)),
                                 Operand::FrameIndex(spillSlots.at(vreg))});
      }
      // Rewrite the instruction's VReg operands.
      for (auto& op : inst.GetOperands()) {
        if (op.GetKind() == Operand::Kind::VReg) {
          op = Operand::Reg(regFor(op.GetVReg()));
        }
      }
      // spilledDefs was computed above, so `inst` is not read after the move.
      newInsts.push_back(std::move(inst));
      // Store spilled defs from the same register the instruction wrote.
      for (int vreg : spilledDefs) {
        newInsts.emplace_back(
            Opcode::StoreStack,
            std::vector<Operand>{Operand::Reg(regFor(vreg)),
                                 Operand::FrameIndex(spillSlots.at(vreg))});
      }
    }
    insts = std::move(newInsts);
  }
}
// ========== Graph-colouring register allocation driver ==========
// Allocates registers for one function in these steps:
//   1. liveness and live intervals;
//   2. for the GPR bank, then the FPR bank: build the interference graph,
//      simplify, select colours, and create spill slots;
//   3. rewrite the instructions;
//   4. clear every instruction's def/use lists.
// Returns early, without rewriting, when there are no live intervals.
void RunGraphColoringRegAlloc(MachineFunction& function) {
  std::unordered_map<MachineBasicBlock*, std::set<int>> liveIn;
  std::unordered_map<MachineBasicBlock*, std::set<int>> liveOut;
  std::unordered_map<MachineInstr*, int> instrToIdx;
  std::vector<MachineInstr*> idxToInstr;
  const auto intervals =
      ComputeLiveIntervals(function, liveIn, liveOut, instrToIdx, idxToInstr);
  if (intervals.empty()) return;

  std::unordered_map<int, PhysReg> vregToPhys;
  std::unordered_map<int, int> spillSlots;

  // Colours one bank (gpr: Int32+Int64, otherwise Float32). Every vreg left
  // uncoloured gets its own spill slot: 8 bytes for Int64 GPR vregs,
  // 4 bytes otherwise.
  auto allocateBank = [&](bool gpr) {
    InterferenceGraph graph =
        BuildInterferenceGraph(function, intervals, liveIn, instrToIdx, gpr);
    if (graph.nodes.empty()) return;
    const auto stack = Simplify(graph);
    auto [colors, spilled] = SelectColors(graph, stack, intervals, function);
    for (const auto& [vreg, reg] : colors) vregToPhys[vreg] = reg;
    for (int vreg : spilled) {
      if (spillSlots.count(vreg) != 0) continue;
      int bytes = 4;
      if (gpr) {
        auto iv = std::find_if(
            intervals.begin(), intervals.end(),
            [vreg](const LiveInterval& x) { return x.vreg == vreg; });
        if (iv != intervals.end() && iv->reg_class == VRegClass::kInt64) bytes = 8;
      }
      spillSlots[vreg] = function.CreateSpillSlot(bytes);
    }
  };
  allocateBank(true);   // GPR bank first ...
  allocateBank(false);  // ... then FPR bank.

  RewriteInstructions(function, vregToPhys, spillSlots, intervals);

  // Def/use annotations no longer refer to anything meaningful.
  for (auto& block : function.GetBasicBlocks()) {
    for (auto& mi : block->GetInstructions()) {
      mi.GetDefs().clear();
      mi.GetUses().clear();
    }
  }
}
} // namespace
// ========== Module entry point ==========
// Runs graph-colouring register allocation on every function of `module`,
// in the order GetFunctions() yields them.
void RunRegAlloc(MachineModule& module) {
  auto&& functions = module.GetFunctions();
  for (auto it = functions.begin(); it != functions.end(); ++it) {
    RunGraphColoringRegAlloc(**it);
  }
}
} // namespace mir