线性扫描改图着色

ftt 6 days ago
parent e73f7cc871
commit d20639d4ba

@ -6,6 +6,8 @@
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include <queue>
#include <cmath>
#include "utils/Log.h"
@ -26,36 +28,30 @@ struct LiveInterval {
: vreg(v), start(s), end(e), reg_class(rc) {}
};
// ========== 可分配物理寄存器池 ==========
// 仅使用 callee-saved 寄存器W19-W28 (10个), X19-X28 (10个), S8-S13 (6个)
// 原因caller-saved 寄存器 (W0-W18, X0-X18, S0-S7) 不能跨函数调用存活,
// 而寄存器分配器未实现调用点 spill。使用 caller-saved 寄存器会导致跨调用值被
// 被调用者破坏。
// W14, W15 / X14, X15 / S14, S15 保留为 spill scratch。
const PhysReg kGPR32Pool[] = {
PhysReg::W19, PhysReg::W20, PhysReg::W21, PhysReg::W22,
PhysReg::W23, PhysReg::W24, PhysReg::W25, PhysReg::W26,
PhysReg::W27, PhysReg::W28,
};
constexpr int kNumGPR32 = sizeof(kGPR32Pool) / sizeof(kGPR32Pool[0]);
// 整数 64位仅 callee-savedX19-X28
const PhysReg kGPR64Pool[] = {
// ========== 物理寄存器池 ==========
// GPR: X19-X28 / W19-W28 (10个物理寄存器Xn和Wn是同一寄存器的不同视图)
// 注意Int32和Int64共享这10个物理GPR
const PhysReg kGPRPool[] = {
PhysReg::X19, PhysReg::X20, PhysReg::X21, PhysReg::X22,
PhysReg::X23, PhysReg::X24, PhysReg::X25, PhysReg::X26,
PhysReg::X27, PhysReg::X28,
};
constexpr int kNumGPR64 = sizeof(kGPR64Pool) / sizeof(kGPR64Pool[0]);
constexpr int kNumGPR = sizeof(kGPRPool) / sizeof(kGPRPool[0]);
// 获取对应的W寄存器
PhysReg ToWReg(PhysReg xreg) {
int idx = static_cast<int>(xreg) - static_cast<int>(PhysReg::X0);
return static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + idx);
}
// 浮点 32位: S8-S13 (callee-saved). S14-S15 保留为 spill scratch
// S16+ 是 caller-saved不能放入通用池外部调用会隐式 clobber
// 浮点寄存器池
const PhysReg kFPR32Pool[] = {
PhysReg::S8, PhysReg::S9, PhysReg::S10, PhysReg::S11,
PhysReg::S12, PhysReg::S13,
};
constexpr int kNumFPR32 = sizeof(kFPR32Pool) / sizeof(kFPR32Pool[0]);
// Spill scratch registers (每个类型 2 个,避免多 spilled-vreg 冲突)
// Spill scratch registers
const PhysReg kSpillScratchInt32[] = { PhysReg::W15, PhysReg::W14 };
const PhysReg kSpillScratchInt64[] = { PhysReg::X15, PhysReg::X14 };
const PhysReg kSpillScratchFloat[] = { PhysReg::S15, PhysReg::S14 };
@ -69,7 +65,6 @@ PhysReg GetSpillScratch(VRegClass rc, int idx) {
return kSpillScratchInt32[0];
}
// 判断是否为 callee-saved 寄存器
bool IsCalleeSaved(PhysReg reg) {
switch (reg) {
case PhysReg::W19: case PhysReg::W20: case PhysReg::W21: case PhysReg::W22:
@ -89,18 +84,23 @@ bool IsCalleeSaved(PhysReg reg) {
}
}
// 获取寄存器编号(用于 Wn/Xn 互斥检查)
int GetRegIndex(PhysReg reg) {
if (reg >= PhysReg::W0 && reg <= PhysReg::W30)
return static_cast<int>(reg) - static_cast<int>(PhysReg::W0);
if (reg >= PhysReg::X0 && reg <= PhysReg::X30)
return static_cast<int>(reg) - static_cast<int>(PhysReg::X0);
if (reg >= PhysReg::S0 && reg <= PhysReg::S31)
return 32 + static_cast<int>(reg) - static_cast<int>(PhysReg::S0);
// 获取GPR的统一编号0-9对应X19/W19到X28/W28
int GetGPRIndex(PhysReg reg) {
if (reg >= PhysReg::W19 && reg <= PhysReg::W28)
return static_cast<int>(reg) - static_cast<int>(PhysReg::W19);
if (reg >= PhysReg::X19 && reg <= PhysReg::X28)
return static_cast<int>(reg) - static_cast<int>(PhysReg::X19);
return -1;
}
// 获取FPR编号
int GetFPRIndex(PhysReg reg) {
if (reg >= PhysReg::S8 && reg <= PhysReg::S13)
return static_cast<int>(reg) - static_cast<int>(PhysReg::S8);
return -1;
}
// ========== 推断 vreg 类型(优先使用 Lowering 存储的类型) ==========
// ========== 推断 vreg 类型 ==========
VRegClass InferVRegClass(int vreg, MachineFunction& function) {
if (function.HasVRegType(vreg)) {
switch (function.GetVRegType(vreg)) {
@ -109,7 +109,7 @@ VRegClass InferVRegClass(int vreg, MachineFunction& function) {
case MachineFunction::VRegType::kInt32: return VRegClass::kInt32;
}
}
return VRegClass::kInt32; // 默认(不应到达,因为 Lowering 覆盖所有 vreg
return VRegClass::kInt32;
}
// ========== 指令编号 ==========
@ -128,17 +128,19 @@ void NumberInstructions(MachineFunction& function,
}
}
// ========== 构建活跃区间 ==========
std::vector<LiveInterval> ComputeLiveIntervals(MachineFunction& function) {
// ========== 计算活跃区间和数据流信息 ==========
std::vector<LiveInterval> ComputeLiveIntervals(
MachineFunction& function,
std::unordered_map<MachineBasicBlock*, std::set<int>>& liveIn,
std::unordered_map<MachineBasicBlock*, std::set<int>>& liveOut,
std::unordered_map<MachineInstr*, int>& instrToIdx,
std::vector<MachineInstr*>& idxToInstr) {
const auto& blocks = function.GetBasicBlocks();
if (blocks.empty()) return {};
// 编号
std::unordered_map<MachineInstr*, int> instrToIdx;
std::vector<MachineInstr*> idxToInstr;
std::map<int, MachineBasicBlock*> blockBoundary;
NumberInstructions(function, instrToIdx, idxToInstr, blockBoundary);
int total = static_cast<int>(idxToInstr.size());
// 收集所有 vreg
std::set<int> allVRegs;
@ -147,10 +149,9 @@ std::vector<LiveInterval> ComputeLiveIntervals(MachineFunction& function) {
for (int u : inst->GetUses()) allVRegs.insert(u);
}
// 每个 vreg 的活跃位置集合
// 每个 vreg 的活跃位置
std::unordered_map<int, std::set<int>> vregPositions;
// 基本块的 use/def 集合
struct BlockInfo {
std::set<int> use;
std::set<int> def;
@ -161,7 +162,6 @@ std::vector<LiveInterval> ComputeLiveIntervals(MachineFunction& function) {
for (const auto& bb : blocks) {
auto& info = blockInfo[bb.get()];
// 找到块的首尾指令序号
auto& insts = bb->GetInstructions();
if (!insts.empty()) {
info.startIdx = instrToIdx[&insts.front()];
@ -186,8 +186,7 @@ std::vector<LiveInterval> ComputeLiveIntervals(MachineFunction& function) {
}
}
// 数据流分析: liveIn/liveOut
std::unordered_map<MachineBasicBlock*, std::set<int>> liveIn, liveOut;
// 数据流分析
bool changed = true;
while (changed) {
changed = false;
@ -195,7 +194,6 @@ std::vector<LiveInterval> ComputeLiveIntervals(MachineFunction& function) {
MachineBasicBlock* bb = it->get();
auto& info = blockInfo[bb];
// liveOut = union of successors' liveIn
std::set<int> newLiveOut;
for (auto* succ : bb->GetSuccessors()) {
for (int v : liveIn[succ]) newLiveOut.insert(v);
@ -205,7 +203,6 @@ std::vector<LiveInterval> ComputeLiveIntervals(MachineFunction& function) {
changed = true;
}
// liveIn = use (liveOut - def)
std::set<int> newLiveIn = info.use;
for (int v : liveOut[bb]) {
if (info.def.count(v) == 0) newLiveIn.insert(v);
@ -218,11 +215,6 @@ std::vector<LiveInterval> ComputeLiveIntervals(MachineFunction& function) {
}
// 生成 LiveInterval
// 同时使用 liveIn 和 liveOut 扩展 end。仅从 liveOut 扩展不够:
// 考虑值在循环中 liveIn 于块 B在 B 中被使用)但其线性最后使用
// 出现在较早块的情况。若回边之后的块中定义了新 vreg 并分配到
// 同一物理寄存器,该寄存器会在循环入口被覆盖(如 graphColoring
// 中 &i 地址被 sxtw 覆盖)。将 end 扩展到 B.endIdx 可防止此问题。
std::vector<LiveInterval> intervals;
for (int vreg : allVRegs) {
auto it = vregPositions.find(vreg);
@ -245,141 +237,295 @@ std::vector<LiveInterval> ComputeLiveIntervals(MachineFunction& function) {
intervals.emplace_back(vreg, start, end, rc);
}
std::sort(intervals.begin(), intervals.end(),
[](const LiveInterval& a, const LiveInterval& b) {
return a.start < b.start;
});
return intervals;
}
// ========== 寄存器池选择 ==========
const PhysReg* GetRegPool(VRegClass rc, int& count) {
switch (rc) {
case VRegClass::kInt32: count = kNumGPR32; return kGPR32Pool;
case VRegClass::kInt64: count = kNumGPR64; return kGPR64Pool;
case VRegClass::kFloat32: count = kNumFPR32; return kFPR32Pool;
// ========== 图着色核心数据结构 ==========
struct IGNode {
int vreg;
VRegClass reg_class;
std::set<int> neighbors;
int degree;
bool removed;
bool is_spill_candidate;
double spill_cost;
};
struct StackEntry {
int vreg;
bool is_spill_candidate;
};
// 干涉图:按"寄存器类别组"构建
// GPR组: Int32 + Int64 (共享物理寄存器)
// FPR组: Float32
struct InterferenceGraph {
std::unordered_map<int, IGNode> nodes;
std::set<int> remaining;
int k; // 可用颜色数
bool is_gpr; // true=GPR组(Int32+Int64), false=FPR组
};
// ========== 计算使用频率 ==========
std::unordered_map<int, int> ComputeUseCounts(MachineFunction& function) {
std::unordered_map<int, int> useCounts;
for (auto& bb : function.GetBasicBlocks()) {
for (auto& inst : bb->GetInstructions()) {
for (int u : inst.GetUses()) useCounts[u]++;
for (int d : inst.GetDefs()) useCounts[d]++;
}
}
count = 0;
return nullptr;
return useCounts;
}
// ========== 活跃区间比较(按 end 排序,用于 active 集合) ==========
struct ByEnd {
bool operator()(const LiveInterval* a, const LiveInterval* b) const {
if (a->end != b->end) return a->end < b->end;
return a->vreg < b->vreg; // 打破平局:相同 end 按 vreg 区分
// ========== 构建干涉图 ==========
// 关键修复Int32和Int64在同一个干涉图中因为共享物理GPR
InterferenceGraph BuildInterferenceGraph(
MachineFunction& function,
const std::vector<LiveInterval>& intervals,
const std::unordered_map<MachineBasicBlock*, std::set<int>>& liveIn,
const std::unordered_map<MachineInstr*, int>& instrToIdx,
bool buildGPR) { // true=构建GPR图(Int32+Int64), false=构建FPR图
InterferenceGraph ig;
ig.is_gpr = buildGPR;
ig.k = buildGPR ? kNumGPR : kNumFPR32;
// 收集vreg
std::set<int> vregs;
for (const auto& iv : intervals) {
if (buildGPR) {
// GPR图包含Int32和Int64
if (iv.reg_class == VRegClass::kInt32 || iv.reg_class == VRegClass::kInt64) {
vregs.insert(iv.vreg);
}
} else {
// FPR图只包含Float32
if (iv.reg_class == VRegClass::kFloat32) {
vregs.insert(iv.vreg);
}
}
}
if (vregs.empty()) return ig;
// 初始化节点
auto useCounts = ComputeUseCounts(function);
for (int v : vregs) {
IGNode node;
node.vreg = v;
// 找到vreg的reg_class
for (const auto& iv : intervals) {
if (iv.vreg == v) {
node.reg_class = iv.reg_class;
break;
}
}
node.degree = 0;
node.removed = false;
node.is_spill_candidate = false;
node.spill_cost = useCounts.count(v) ? useCounts[v] : 1.0;
ig.nodes[v] = std::move(node);
ig.remaining.insert(v);
}
};
// ========== 线性扫描寄存器分配(单函数) ==========
void RunRegAllocFunc(MachineFunction& function) {
auto intervals = ComputeLiveIntervals(function);
if (intervals.empty()) return;
// vreg → 分配的物理寄存器
std::unordered_map<int, PhysReg> vregToPhys;
// 构建干涉边:反向遍历指令,正确模拟活跃集合
for (const auto& bb : function.GetBasicBlocks()) {
auto& insts = bb->GetInstructions();
if (insts.empty()) continue;
// 活跃区间集合(按 end 排序)
std::set<const LiveInterval*, ByEnd> active;
// 反向遍历从liveOut开始
std::set<int> live = liveIn.at(bb.get());
// Spill 槽vreg → FrameIndex
std::unordered_map<int, int> spillSlots;
// 注意:我们需要正向检查活跃集合,但用正确的数据流
// 更简单的方法:直接用活跃区间重叠来构建边
}
// 寄存器占用跟踪
std::set<int> occupiedRegIndices;
// 每个寄存器池的空闲/占用状态
auto allocReg = [&](const LiveInterval& interval) -> PhysReg {
int poolSize = 0;
const PhysReg* pool = GetRegPool(interval.reg_class, poolSize);
for (int i = 0; i < poolSize; ++i) {
int idx = GetRegIndex(pool[i]);
if (occupiedRegIndices.count(idx) == 0) {
occupiedRegIndices.insert(idx);
if (IsCalleeSaved(pool[i])) {
function.MarkCalleeSaved(pool[i]);
}
return pool[i];
// 用活跃区间重叠构建干涉边(更可靠)
for (auto it1 = vregs.begin(); it1 != vregs.end(); ++it1) {
auto it2 = it1;
++it2;
for (; it2 != vregs.end(); ++it2) {
int a = *it1, b = *it2;
// 找到a和b的活跃区间
const LiveInterval* ivA = nullptr;
const LiveInterval* ivB = nullptr;
for (const auto& iv : intervals) {
if (iv.vreg == a) ivA = &iv;
if (iv.vreg == b) ivB = &iv;
}
if (!ivA || !ivB) continue;
// 检查区间是否重叠(包含端点)
// 两个区间[s1,e1]和[s2,e2]重叠当且仅当:
// max(s1,s2) <= min(e1,e2)
int maxStart = std::max(ivA->start, ivB->start);
int minEnd = std::min(ivA->end, ivB->end);
if (maxStart <= minEnd) {
// 活跃区间重叠,添加干涉边
ig.nodes[a].neighbors.insert(b);
ig.nodes[b].neighbors.insert(a);
}
}
// 无法分配
return PhysReg::W0; // will trigger spill logic
};
}
auto freeReg = [&](PhysReg reg) {
occupiedRegIndices.erase(GetRegIndex(reg));
};
// 计算度数和spill cost
for (auto& [vreg, node] : ig.nodes) {
node.degree = static_cast<int>(node.neighbors.size());
if (node.degree > 0) {
node.spill_cost = node.spill_cost / node.degree;
}
}
return ig;
}
// ========== Simplify阶段 ==========
std::vector<StackEntry> Simplify(InterferenceGraph& ig) {
std::vector<StackEntry> stack;
std::queue<int> worklist;
auto isFreeReg = [&](VRegClass rc) -> bool {
int poolSize = 0;
const PhysReg* pool = GetRegPool(rc, poolSize);
for (int i = 0; i < poolSize; ++i) {
if (occupiedRegIndices.count(GetRegIndex(pool[i])) == 0)
return true;
for (int v : ig.remaining) {
if (ig.nodes[v].degree < ig.k) {
worklist.push(v);
}
return false;
};
}
// 线性扫描
for (auto& interval : intervals) {
// 1. Expire old intervals
std::vector<const LiveInterval*> toRemove;
for (auto* act : active) {
if (act->end < interval.start) {
toRemove.push_back(act);
auto it = vregToPhys.find(act->vreg);
if (it != vregToPhys.end()) {
freeReg(it->second);
while (!ig.remaining.empty()) {
if (!worklist.empty()) {
int v = worklist.front();
worklist.pop();
if (!ig.remaining.count(v) || ig.nodes[v].removed) continue;
if (ig.nodes[v].degree >= ig.k) continue;
stack.push_back({v, false});
ig.nodes[v].removed = true;
ig.remaining.erase(v);
for (int u : ig.nodes[v].neighbors) {
if (ig.remaining.count(u) && !ig.nodes[u].removed) {
ig.nodes[u].degree--;
if (ig.nodes[u].degree < ig.k) {
worklist.push(u);
}
}
}
} else {
double bestCost = 1e300;
int bestVreg = -1;
for (int v : ig.remaining) {
if (ig.nodes[v].removed) continue;
if (ig.nodes[v].spill_cost < bestCost) {
bestCost = ig.nodes[v].spill_cost;
bestVreg = v;
}
}
if (bestVreg < 0) break;
stack.push_back({bestVreg, true});
ig.nodes[bestVreg].removed = true;
ig.remaining.erase(bestVreg);
for (int u : ig.nodes[bestVreg].neighbors) {
if (ig.remaining.count(u) && !ig.nodes[u].removed) {
ig.nodes[u].degree--;
if (ig.nodes[u].degree < ig.k) {
worklist.push(u);
}
}
}
}
for (auto* act : toRemove) {
active.erase(act);
}
return stack;
}
// ========== Select阶段 ==========
// 关键修复颜色用GPR索引0-9表示分配时根据vreg类型选择Xn或Wn
std::pair<std::unordered_map<int, PhysReg>, std::set<int>> SelectColors(
const InterferenceGraph& origIg,
const std::vector<StackEntry>& stack,
const std::vector<LiveInterval>& intervals,
MachineFunction& function) {
std::unordered_map<int, PhysReg> coloring;
std::set<int> actualSpills;
// 跟踪已分配的GPR索引0-9
std::unordered_map<int, int> colorToVReg;
// 逆序弹出栈
for (int i = static_cast<int>(stack.size()) - 1; i >= 0; --i) {
const auto& entry = stack[i];
int vreg = entry.vreg;
const auto& node = origIg.nodes.at(vreg);
// 收集已着色邻居使用的颜色GPR索引
std::set<int> usedColors;
for (int neighbor : node.neighbors) {
auto it = coloring.find(neighbor);
if (it != coloring.end()) {
int colorIdx = -1;
if (origIg.is_gpr) {
colorIdx = GetGPRIndex(it->second);
} else {
colorIdx = GetFPRIndex(it->second);
}
if (colorIdx >= 0) {
usedColors.insert(colorIdx);
}
}
}
// 2. 尝试分配
if (isFreeReg(interval.reg_class)) {
PhysReg reg = allocReg(interval);
vregToPhys[interval.vreg] = reg;
active.insert(&interval);
} else {
// 3. Spill: 选择 active 中最晚结束的区间
if (active.empty()) {
// 所有寄存器都被占用Wn/Xn 别名冲突等边缘情况)
// 直接 spill 当前 interval
int slotSize = (interval.reg_class == VRegClass::kInt64) ? 8 : 4;
int slot = function.CreateSpillSlot(slotSize);
spillSlots[interval.vreg] = slot;
continue;
// 找第一个可用颜色
int chosenColor = -1;
for (int c = 0; c < origIg.k; ++c) {
if (!usedColors.count(c)) {
chosenColor = c;
break;
}
const LiveInterval* spillCand = *active.rbegin(); // 最晚结束
if (spillCand->end > interval.end) {
// Spill spillCand
PhysReg reg = vregToPhys[spillCand->vreg];
vregToPhys.erase(spillCand->vreg);
freeReg(reg);
active.erase(spillCand);
// 为其分配 spill slot
int slotSize = (spillCand->reg_class == VRegClass::kInt64) ? 8 : 4;
int slot = function.CreateSpillSlot(slotSize);
spillSlots[spillCand->vreg] = slot;
}
// 分配当前 interval
PhysReg newReg = allocReg(interval);
vregToPhys[interval.vreg] = newReg;
active.insert(&interval);
if (chosenColor >= 0) {
// 根据vreg类型选择物理寄存器
PhysReg physReg;
if (origIg.is_gpr) {
if (node.reg_class == VRegClass::kInt64) {
physReg = kGPRPool[chosenColor]; // Xn
} else {
physReg = ToWReg(kGPRPool[chosenColor]); // Wn
}
} else {
// Spill 当前 interval
int slotSize = (interval.reg_class == VRegClass::kInt64) ? 8 : 4;
int slot = function.CreateSpillSlot(slotSize);
spillSlots[interval.vreg] = slot;
// 不分配物理寄存器
physReg = kFPR32Pool[chosenColor]; // Sn
}
coloring[vreg] = physReg;
if (IsCalleeSaved(physReg)) {
function.MarkCalleeSaved(physReg);
}
} else {
actualSpills.insert(vreg);
}
}
// ========== 重写指令VReg → PhysReg + spill/reload ==========
return {coloring, actualSpills};
}
// ========== 重写指令 ==========
void RewriteInstructions(
MachineFunction& function,
const std::unordered_map<int, PhysReg>& vregToPhys,
const std::unordered_map<int, int>& spillSlots,
const std::vector<LiveInterval>& intervals) {
auto getSpillRC = [&](int vreg) -> VRegClass {
for (auto& iv : intervals) {
if (iv.vreg == vreg) return iv.reg_class;
}
return VRegClass::kInt32;
};
for (auto& bb : function.GetBasicBlocks()) {
std::vector<MachineInstr> newInsts;
auto& insts = bb->GetInstructions();
@ -389,15 +535,7 @@ void RunRegAllocFunc(MachineFunction& function) {
std::vector<int>& defs = inst.GetDefs();
std::vector<int>& uses = inst.GetUses();
// 辅助:获取 spilled vreg 的类型
auto getSpillRC = [&](int vreg) -> VRegClass {
for (auto& iv : intervals) {
if (iv.vreg == vreg) return iv.reg_class;
}
return VRegClass::kInt32;
};
// 收集此指令中需要 reload 的 spilled vreg去重
// 收集需要reload的spilled use
std::vector<int> spilledUses;
{
std::set<int> seen;
@ -408,10 +546,10 @@ void RunRegAllocFunc(MachineFunction& function) {
}
}
// === 插入 use 前的 reload(每个 spilled vreg 用不同 scratch ===
// 插入reload
for (size_t si = 0; si < spilledUses.size(); ++si) {
int vreg = spilledUses[si];
int slot = spillSlots[vreg];
int slot = spillSlots.at(vreg);
PhysReg loadReg;
auto it = vregToPhys.find(vreg);
if (it != vregToPhys.end()) {
@ -423,8 +561,7 @@ void RunRegAllocFunc(MachineFunction& function) {
std::vector<Operand>{Operand::Reg(loadReg), Operand::FrameIndex(slot)});
}
// === 替换 VReg 操作数为 PhysReg ===
// 跟踪每条指令中 spilled vreg 的 scratch 索引
// 替换VReg为PhysReg
int spillUseIdx = 0;
for (auto& op : ops) {
if (op.GetKind() == Operand::Kind::VReg) {
@ -433,18 +570,11 @@ void RunRegAllocFunc(MachineFunction& function) {
if (it != vregToPhys.end()) {
op = Operand::Reg(it->second);
} else {
// spilled 或未分配:用 spill scratch
int idx = 0;
if (spillSlots.count(vreg)) {
for (size_t si = 0; si < spilledUses.size(); ++si) {
if (spilledUses[si] == vreg) { idx = static_cast<int>(si); break; }
}
} else {
// 防御vreg 未在 vregToPhys 或 spillSlots 中,创建临时 spill slot
VRegClass rc = getSpillRC(vreg);
int slotSize = (rc == VRegClass::kInt64) ? 8 : 4;
int slot = function.CreateSpillSlot(slotSize);
spillSlots[vreg] = slot;
}
op = Operand::Reg(GetSpillScratch(getSpillRC(vreg), idx));
spillUseIdx++;
@ -454,8 +584,7 @@ void RunRegAllocFunc(MachineFunction& function) {
newInsts.push_back(inst);
// === 插入 def 后的 store用于 spilled vreg ===
// 收集此指令中 spilled def vreg去重
// 插入def后的store
std::vector<int> spilledDefs;
{
std::set<int> seen;
@ -467,7 +596,7 @@ void RunRegAllocFunc(MachineFunction& function) {
}
for (size_t si = 0; si < spilledDefs.size(); ++si) {
int vreg = spilledDefs[si];
int slot = spillSlots[vreg];
int slot = spillSlots.at(vreg);
PhysReg storeReg;
auto it = vregToPhys.find(vreg);
if (it != vregToPhys.end()) {
@ -481,8 +610,67 @@ void RunRegAllocFunc(MachineFunction& function) {
}
insts = std::move(newInsts);
}
}
// ========== 图着色寄存器分配主函数 ==========
void RunGraphColoringRegAlloc(MachineFunction& function) {
// 1. 数据流分析和活跃区间
std::unordered_map<MachineBasicBlock*, std::set<int>> liveIn, liveOut;
std::unordered_map<MachineInstr*, int> instrToIdx;
std::vector<MachineInstr*> idxToInstr;
auto intervals = ComputeLiveIntervals(function, liveIn, liveOut, instrToIdx, idxToInstr);
if (intervals.empty()) return;
// 2. 分配结果
std::unordered_map<int, PhysReg> vregToPhys;
std::unordered_map<int, int> spillSlots;
// 3. 构建GPR干涉图Int32 + Int64共享物理寄存器
{
InterferenceGraph ig = BuildInterferenceGraph(
function, intervals, liveIn, instrToIdx, true);
if (!ig.nodes.empty()) {
auto stack = Simplify(ig);
auto [coloring, spills] = SelectColors(ig, stack, intervals, function);
for (auto& [vreg, phys] : coloring) {
vregToPhys[vreg] = phys;
}
for (int vreg : spills) {
if (spillSlots.count(vreg)) continue;
VRegClass rc = VRegClass::kInt32;
for (const auto& iv : intervals) {
if (iv.vreg == vreg) { rc = iv.reg_class; break; }
}
int slotSize = (rc == VRegClass::kInt64) ? 8 : 4;
int slot = function.CreateSpillSlot(slotSize);
spillSlots[vreg] = slot;
}
}
}
// 4. 构建FPR干涉图Float32
{
InterferenceGraph ig = BuildInterferenceGraph(
function, intervals, liveIn, instrToIdx, false);
if (!ig.nodes.empty()) {
auto stack = Simplify(ig);
auto [coloring, spills] = SelectColors(ig, stack, intervals, function);
for (auto& [vreg, phys] : coloring) {
vregToPhys[vreg] = phys;
}
for (int vreg : spills) {
if (spillSlots.count(vreg)) continue;
int slot = function.CreateSpillSlot(4);
spillSlots[vreg] = slot;
}
}
}
// 5. 重写指令
RewriteInstructions(function, vregToPhys, spillSlots, intervals);
// 清除所有指令的 def/useRA 完成后不再需要)
// 6. 清除def/use标记
for (auto& bb : function.GetBasicBlocks()) {
for (auto& inst : bb->GetInstructions()) {
inst.GetDefs().clear();
@ -491,12 +679,13 @@ void RunRegAllocFunc(MachineFunction& function) {
}
}
} // namespace
} // namespace
// ========== 模块入口 ==========
void RunRegAlloc(MachineModule& module) {
for (auto& func : module.GetFunctions()) {
RunRegAllocFunc(*func);
RunGraphColoringRegAlloc(*func);
}
}
} // namespace mir
} // namespace mir
Loading…
Cancel
Save