feat(mir): 线性扫描寄存器分配初始实现(WIP,--regalloc=linear 可用)

- Wimmer & Mössenböck (2005) 优化区间分割算法
- 685 行,支持 GP/FP 寄存器池
- 目前通过简单用例,循环函数有寄存器映射 bug(25_while_if 无限循环)
- 默认仍使用图着色,线性扫描可通过 CLI 切换
lzk
lzkk 4 days ago
parent a9ebfdc0e0
commit 28ad162de4

@ -12,6 +12,7 @@ struct CLIOptions {
bool show_help = false;
bool optimize = false; // -O 或 -O1
int opt_level = 0; // 优化级别: 0, 1, 2, 3
std::string regalloc = "graphcoloring"; // 寄存器分配器: graphcoloring 或 linear
};
CLIOptions ParseCLI(int argc, char** argv);

@ -0,0 +1,694 @@
#include "mir/MIR.h"
#include <algorithm>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "utils/Log.h"
namespace mir
{
namespace
{
// ---- AArch64 allocatable registers ------------------------------------
// GP registers handed out by the allocator, as hardware register numbers:
//   x8 (indirect result), x9-x12 (temporaries), x15, x16-x17 (IP0/IP1),
//   x19-x28 (callee-saved).
// Left out: x0-x7 (argument passing), x13-x14 (kept as temporaries to avoid
// conflicts around calls), x18 (platform register), x29-x31 (reserved).
// NOTE(review): x8-x17 are caller-saved under AAPCS64 and x16/x17 may be
// clobbered by linker veneers; confirm that values live across a call are
// protected elsewhere in the backend.
static constexpr int GP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 15, 16, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28};
// Number of GP entries, derived from the table so the two cannot drift apart.
static constexpr int K_GP = static_cast<int>(sizeof(GP_ALLOCATABLE) / sizeof(GP_ALLOCATABLE[0]));
// FP registers handed out by the allocator: s8-s31.
static constexpr int FP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
// Number of FP entries, derived from the table.
static constexpr int K_FP = static_cast<int>(sizeof(FP_ALLOCATABLE) / sizeof(FP_ALLOCATABLE[0]));
// Register number -> PhysReg.
// Picks the bank from `vc` (S0 for Float, X0 for Ptr, W0 for everything else)
// and returns the enumerator `num` places after that bank's first register.
static PhysReg NumberToPhysReg(int num, VRegClass vc)
{
    PhysReg bank_base = PhysReg::W0;
    switch (vc)
    {
    case VRegClass::Float:
        bank_base = PhysReg::S0;
        break;
    case VRegClass::Ptr:
        bank_base = PhysReg::X0;
        break;
    default:
        break;
    }
    return static_cast<PhysReg>(static_cast<int>(bank_base) + num);
}
// Allocatable-array index -> PhysReg.
// Float vregs index FP_ALLOCATABLE, every other class indexes GP_ALLOCATABLE;
// the register number found there is then mapped into the bank of `vc`.
// `idx` is not range-checked.
static PhysReg AllocIdxToPhysReg(int idx, VRegClass vc)
{
    const bool is_float = (vc == VRegClass::Float);
    const int reg_number = is_float ? FP_ALLOCATABLE[idx] : GP_ALLOCATABLE[idx];
    return NumberToPhysReg(reg_number, vc);
}
// ---- Utility functions -------------------------------------------------
// Reports whether `opcode` belongs to the set of opcodes this allocator
// treats as producing a result. For such instructions the rewriter regards
// operand 0 as the definition.
static bool HasVRegDef(Opcode opcode)
{
    bool writes_result = false;
    switch (opcode)
    {
    // Moves, loads and address materialisation.
    case Opcode::MovImm:
    case Opcode::MovReg:
    case Opcode::LoadStack:
    case Opcode::LoadGlobal:
    case Opcode::LoadGlobalAddr:
    case Opcode::LoadStackAddr:
    case Opcode::LoadMem:
    // Integer arithmetic.
    case Opcode::AddRR:
    case Opcode::SubRR:
    case Opcode::AddImm:
    case Opcode::SubImm:
    case Opcode::MulRR:
    case Opcode::DivRR:
    case Opcode::ModRR:
    case Opcode::NegRR:
    case Opcode::Smull:
    case Opcode::Msub:
    // Bitwise operations and shifts.
    case Opcode::AndRR:
    case Opcode::OrRR:
    case Opcode::XorRR:
    case Opcode::ShlRR:
    case Opcode::ShrRR:
    case Opcode::AsrRR:
    case Opcode::Asr64RR:
    // Extensions and conditional results.
    case Opcode::Uxtw:
    case Opcode::Sxtw:
    case Opcode::CSet:
    case Opcode::Csel:
    // Floating point arithmetic and conversions.
    case Opcode::FAddRR:
    case Opcode::FSubRR:
    case Opcode::FMulRR:
    case Opcode::FDivRR:
    case Opcode::Scvtf:
    case Opcode::FCvtzs:
    case Opcode::FMovWS:
    // Calls.
    case Opcode::Call:
        writes_result = true;
        break;
    default:
        break;
    }
    return writes_result;
}
// ---- Core data structures ----------------------------------------------
// One entry of the allocator's active list: a live interval that currently
// holds a register, together with the register it holds.
struct ActiveInterval
{
LiveInterval *interval; // the interval holding the register (not owned by this struct)
int phys_reg; // index into the allocatable-register array, not a hardware register number
};
// A live segment of one vreg: a span of instruction positions plus where the
// value is kept during that span.
struct VRegRange
{
int start; // first instruction position covered (global instruction index)
int end; // last instruction position covered; lookups treat it as inclusive
int reg_idx; // index into the allocatable-register array, -1 means spilled
};
// Save point: at the given position a vreg has to be spilled from its
// register to the stack.
struct SavePoint
{
int pos; // instruction position at which the store is emitted
int vreg; // vreg being spilled
int reg_idx; // allocatable-array index of the register that holds the value
int spill_slot; // frame index of the stack slot that receives the value
};
// ---- 分配器 ----------------------------------------------------------
// 从活跃列表中淘汰 end < pos 的区间
static void ExpireOldIntervals(std::vector<ActiveInterval> &active,
std::vector<bool> &reg_free,
int pos)
{
for (auto &a : active)
{
if (a.interval->end < pos)
reg_free[a.phys_reg] = true;
}
active.erase(
std::remove_if(active.begin(), active.end(),
[pos](const ActiveInterval &a)
{ return a.interval->end < pos; }),
active.end());
}
// Returns the lowest index whose entry in `reg_free` is true, or -1 when no
// entry is free (including when the table is empty).
static int FindFreeReg(const std::vector<bool> &reg_free)
{
    const auto first_free = std::find(reg_free.begin(), reg_free.end(), true);
    if (first_free == reg_free.end())
        return -1;
    return static_cast<int>(first_free - reg_free.begin());
}
// 返回活跃列表中 end 最大者的索引
static int SelectSpill(const std::vector<ActiveInterval> &active)
{
int farthest = -1;
int farthest_end = -1;
for (size_t i = 0; i < active.size(); ++i)
{
if (active[i].interval->end > farthest_end)
{
farthest_end = active[i].interval->end;
farthest = static_cast<int>(i);
}
}
return farthest;
}
// Returns the spill slot recorded for `vreg` in `vreg_to_slot`. When none is
// recorded yet, a frame index is created on `func` (8 bytes for Ptr vregs,
// 4 bytes for every other class), remembered in the map and returned.
static int GetOrCreateSpillSlot(MachineFunction &func, int vreg,
                                std::unordered_map<int, int> &vreg_to_slot)
{
    const auto known = vreg_to_slot.find(vreg);
    if (known == vreg_to_slot.end())
    {
        const bool is_pointer = (func.GetVRegClass(vreg) == VRegClass::Ptr);
        const int new_slot = func.CreateFrameIndex(is_pointer ? 8 : 4);
        vreg_to_slot.emplace(vreg, new_slot);
        return new_slot;
    }
    return known->second;
}
// ---- Forward declaration -----------------------------------------------
// Defined further down in this file; RunLinearScan calls it after the
// allocation decisions have been made.
static void RewriteWithAllocation(
MachineFunction &func,
const std::vector<std::vector<VRegRange>> &vreg_ranges,
const std::unordered_map<int, int> &vreg_to_slot,
std::vector<SavePoint> &save_points);
// ---- Main allocation pass: linear scan with whole-interval spilling ----
//
// Walks the live intervals of `func` in order of increasing start position
// and gives each vreg either one allocatable register or a stack slot for its
// entire interval, then hands the result to RewriteWithAllocation.
//
// When no register is free, the active interval of the same register class
// that ends last is compared with the current one: if it ends later it loses
// its register to the current interval, otherwise the current interval is
// spilled.
//
// Differences from the first draft, all of them correctness fixes:
//  * GP and FP intervals live in separate active lists. With one shared list,
//    expiring or evicting an interval of the other class indexed the wrong
//    free table (possibly out of bounds) and never released its register.
//  * An evicted interval is spilled for its whole lifetime. The draft kept
//    the victim in the stolen register until cur.end while the current
//    interval already owned that register from cur.start, so two live values
//    shared one register. True interval splitting also needs reload/resolve
//    moves on control-flow edges, which the rewriter does not emit.
//  * Intervals whose vreg id is outside [0, GetNumVRegs()) are skipped instead
//    of indexing the per-vreg tables out of bounds.
//  * The sort is stable so the result does not depend on how the standard
//    library orders intervals that start at the same position.
//
// NOTE(review): only the first interval seen for a vreg is allocated; later
// intervals of the same vreg are skipped. This presumes ComputeInstLiveness
// yields one interval per vreg - confirm.
static void RunLinearScan(MachineFunction &func)
{
    auto intervals = ComputeInstLiveness(func);
    if (intervals.empty())
        return;
    const int num_vregs = func.GetNumVRegs();

    // Process intervals by increasing start position.
    std::stable_sort(intervals.begin(), intervals.end(),
                     [](const LiveInterval &a, const LiveInterval &b)
                     { return a.start < b.start; });

    // Allocation result, indexed by vreg.
    std::vector<std::vector<VRegRange>> vreg_ranges(num_vregs);
    std::vector<bool> vreg_has_range(num_vregs, false);
    std::unordered_map<int, int> vreg_to_slot; // vreg -> spill slot
    // Whole-interval spilling needs no mid-interval stores, so this stays
    // empty; it is still passed on because the rewriter's interface takes it.
    std::vector<SavePoint> save_points;

    // Per-class free tables and active lists. The two pools must never be
    // mixed: phys_reg is an index into the table of its own class only.
    std::vector<bool> gp_free(K_GP, true);
    std::vector<bool> fp_free(K_FP, true);
    std::vector<ActiveInterval> gp_active;
    std::vector<ActiveInterval> fp_active;

    // The active lists store pointers to these elements; the vector is never
    // resized below, so the pointers stay valid.
    std::vector<LiveInterval> queue = std::move(intervals);

    for (LiveInterval &cur : queue)
    {
        // Ignore intervals that do not name a vreg of this function.
        if (cur.vreg < 0 || cur.vreg >= num_vregs)
            continue;
        // Each vreg is allocated once.
        if (vreg_has_range[cur.vreg])
            continue;

        const bool is_fp = (cur.vreg_class == VRegClass::Float);
        std::vector<bool> &reg_free = is_fp ? fp_free : gp_free;
        std::vector<ActiveInterval> &active = is_fp ? fp_active : gp_active;

        // 1. Release registers of intervals that ended before cur starts.
        ExpireOldIntervals(active, reg_free, cur.start);

        // 2. Prefer a free register.
        int reg_idx = FindFreeReg(reg_free);
        if (reg_idx < 0)
        {
            // 3. Nothing free: consider the active interval that ends last.
            const int victim_idx = SelectSpill(active);
            const bool steal = (victim_idx >= 0) &&
                               (active[victim_idx].interval->end > cur.end);
            if (!steal)
            {
                // 4b. cur itself reaches furthest (or nothing is active):
                // it lives on the stack for its whole interval.
                const int slot = GetOrCreateSpillSlot(func, cur.vreg, vreg_to_slot);
                vreg_ranges[cur.vreg].push_back({cur.start, cur.end, -1});
                vreg_has_range[cur.vreg] = true;
                cur.spilled = true;
                cur.spill_slot = slot;
                continue;
            }

            // 4a. The victim ends later than cur: it gives up its register
            // and is spilled for its whole lifetime. Nothing has been
            // rewritten yet, so turning its earlier positions into stack
            // accesses is still possible and keeps every def/use consistent.
            LiveInterval *victim = active[victim_idx].interval;
            reg_idx = active[victim_idx].phys_reg;
            const int victim_slot = GetOrCreateSpillSlot(func, victim->vreg, vreg_to_slot);
            for (VRegRange &range : vreg_ranges[victim->vreg])
                range.reg_idx = -1;
            victim->spilled = true;
            victim->spill_slot = victim_slot;
            active.erase(active.begin() + victim_idx);
        }

        // cur takes the register (free or stolen) for its whole interval.
        reg_free[reg_idx] = false;
        active.push_back({&cur, reg_idx});
        vreg_ranges[cur.vreg].push_back({cur.start, cur.end, reg_idx});
        vreg_has_range[cur.vreg] = true;
    }

    // ---- Rewrite instructions ------------------------------------------
    RewriteWithAllocation(func, vreg_ranges, vreg_to_slot, save_points);
}
// ---- Scratch register selection ----------------------------------------
// Chooses a GP register number to use as a scratch register for `inst`.
//
// A candidate is rejected when it already appears as a physical register
// operand of `inst` or when it is one of the values of `pos_regs` (vreg ->
// hardware register number that must not be overwritten here). x14 is tried
// first because it is not in the allocatable list; after that the allocatable
// GP registers are tried in table order.
//
// Returns the chosen register number. When every candidate is rejected the
// first allocatable register is returned anyway, so the caller may end up
// clobbering a live value in that case.
//
// Fix over the first draft: operand registers were numbered relative to W0
// only, so an X-bank operand (e.g. x14 used as a pointer) was never seen as a
// conflict. Operands are now numbered within their own bank.
static int PickGPScratchReg(const MachineInstr &inst,
                            const std::unordered_map<int, int> &pos_regs)
{
    // Hardware number of a W or X register, -1 for anything else.
    const auto gp_number = [](PhysReg reg) -> int
    {
        if (reg >= PhysReg::W0 && reg <= PhysReg::W30)
            return static_cast<int>(reg) - static_cast<int>(PhysReg::W0);
        if (reg >= PhysReg::X0 && reg <= PhysReg::X30)
            return static_cast<int>(reg) - static_cast<int>(PhysReg::X0);
        return -1;
    };

    // True when `number` is referenced by the instruction or already reserved.
    const auto is_taken = [&](int number) -> bool
    {
        for (const auto &op : inst.GetOperands())
        {
            if (op.GetKind() == Operand::Kind::Reg && gp_number(op.GetReg()) == number)
                return true;
        }
        for (const auto &kv : pos_regs)
        {
            if (kv.second == number)
                return true;
        }
        return false;
    };

    // x14 first: it is never handed out by the allocator.
    if (!is_taken(14))
        return 14;

    // Otherwise the first allocatable register that does not conflict.
    for (int candidate : GP_ALLOCATABLE)
    {
        if (!is_taken(candidate))
            return candidate;
    }

    // Everything conflicts; fall back to the first allocatable register.
    return GP_ALLOCATABLE[0];
}
// Chooses an FP register number to use as a scratch register for `inst`.
// Walks FP_ALLOCATABLE in table order and returns the first number that is
// neither a physical register operand of `inst` (numbered relative to S0) nor
// one of the values stored in `pos_regs`. When every entry is rejected the
// first table entry is returned.
static int PickFPScratchReg(const MachineInstr &inst,
                            const std::unordered_map<int, int> &pos_regs)
{
    const int s0 = static_cast<int>(PhysReg::S0);

    const auto used_by_inst = [&](int number) -> bool
    {
        for (const auto &operand : inst.GetOperands())
        {
            if (operand.GetKind() != Operand::Kind::Reg)
                continue;
            if (static_cast<int>(operand.GetReg()) - s0 == number)
                return true;
        }
        return false;
    };

    const auto reserved = [&](int number) -> bool
    {
        for (const auto &entry : pos_regs)
        {
            if (entry.second == number)
                return true;
        }
        return false;
    };

    for (int candidate : FP_ALLOCATABLE)
    {
        if (used_by_inst(candidate))
            continue;
        if (reserved(candidate))
            continue;
        return candidate;
    }
    return FP_ALLOCATABLE[0];
}
// ---- Save point ordering -----------------------------------------------
// Orders `save_points` by ascending instruction position so the rewriter can
// consume them front to back.
static void SortSavePoints(std::vector<SavePoint> &save_points)
{
    const auto earlier = [](const SavePoint &lhs, const SavePoint &rhs)
    {
        return lhs.pos < rhs.pos;
    };
    std::sort(save_points.begin(), save_points.end(), earlier);
}
// ---- RewriteWithAllocation -------------------------------------------
static void RewriteWithAllocation(
MachineFunction &func,
const std::vector<std::vector<VRegRange>> &vreg_ranges,
const std::unordered_map<int, int> &vreg_to_slot,
std::vector<SavePoint> &save_points)
{
SortSavePoints(save_points);
size_t next_save = 0;
// 全局指令位置计数器(基于原始指令顺序)
int global_pos = 0;
for (auto &block : func.GetBlocks())
{
std::vector<MachineInstr> new_insts;
for (auto &inst : block->GetInstructions())
{
auto opcode = inst.GetOpcode();
auto &ops = inst.GetOperands();
// ---- 保存点:在此位置前保存被驱逐 vreg 的值 ----
while (next_save < save_points.size() &&
save_points[next_save].pos <= global_pos)
{
const auto &sp = save_points[next_save];
VRegClass vc = func.GetVRegClass(sp.vreg);
PhysReg pr = AllocIdxToPhysReg(sp.reg_idx, vc);
new_insts.push_back(
MachineInstr(Opcode::StoreStack,
{Operand::Reg(pr), Operand::FrameIndex(sp.spill_slot)}));
++next_save;
}
// ---- 确定当前位置 def/use 的 vreg 对应哪个范围 ----
// 构建 "当前位置已在使用中的寄存器" 集合(用于 scratch 选择)
std::unordered_map<int, int> pos_regs; // vreg -> reg_idx at this position
std::unordered_map<int, int> vreg_range_idx; // vreg -> range index
bool has_def = HasVRegDef(opcode);
int def_vreg = -1;
for (size_t i = 0; i < ops.size(); ++i)
{
if (ops[i].GetKind() != Operand::Kind::VReg)
continue;
// 跳过 def 位置上已经被处理过的
if (has_def && i == 0)
{
def_vreg = ops[i].GetVRegId();
continue;
}
int v = ops[i].GetVRegId();
if (v < 0 || v >= static_cast<int>(vreg_ranges.size()))
continue;
// 寻找覆盖当前位置的范围
int reg_idx = -1;
for (size_t ri = 0; ri < vreg_ranges[v].size(); ++ri)
{
const auto &rng = vreg_ranges[v][ri];
if (rng.start <= global_pos && global_pos <= rng.end)
{
reg_idx = rng.reg_idx;
break;
}
}
if (reg_idx >= 0)
pos_regs[v] = reg_idx;
}
// 也处理 def vreg
if (def_vreg >= 0 && def_vreg < static_cast<int>(vreg_ranges.size()))
{
int reg_idx = -1;
for (size_t ri = 0; ri < vreg_ranges[def_vreg].size(); ++ri)
{
const auto &rng = vreg_ranges[def_vreg][ri];
if (rng.start <= global_pos && global_pos <= rng.end)
{
reg_idx = rng.reg_idx;
break;
}
}
if (reg_idx >= 0)
pos_regs[def_vreg] = reg_idx;
}
// ---- 处理溢出 uses插入 LoadStack ----
// 收集所有溢出 use vreg在当前范围中 reg_idx == -1
std::unordered_set<int> spilled_uses;
for (size_t i = 0; i < ops.size(); ++i)
{
if (ops[i].GetKind() != Operand::Kind::VReg)
continue;
if (has_def && i == 0)
continue;
int v = ops[i].GetVRegId();
if (v < 0 || v >= static_cast<int>(vreg_ranges.size()))
continue;
// 检查范围:如果覆盖当前位置的范围 reg_idx == -1则需加载
bool needs_load = false;
for (const auto &rng : vreg_ranges[v])
{
if (rng.start <= global_pos && global_pos <= rng.end)
{
if (rng.reg_idx == -1)
needs_load = true;
break;
}
}
if (needs_load && !spilled_uses.count(v))
spilled_uses.insert(v);
}
for (int v : spilled_uses)
{
auto slot_it = vreg_to_slot.find(v);
if (slot_it == vreg_to_slot.end())
continue;
int slot = slot_it->second;
VRegClass vc = func.GetVRegClass(v);
int scratch = (vc == VRegClass::Float)
? PickFPScratchReg(inst, pos_regs)
: PickGPScratchReg(inst, pos_regs);
PhysReg load_reg = NumberToPhysReg(scratch, vc);
new_insts.push_back(
MachineInstr(Opcode::LoadStack,
{Operand::Reg(load_reg), Operand::FrameIndex(slot)}));
// 将该 vreg 在此处映射到此 scratch 寄存器
pos_regs[v] = scratch;
// 替换指令中的该 vreg 操作数
for (auto &op : ops)
{
if (op.GetKind() == Operand::Kind::VReg && op.GetVRegId() == v)
{
const_cast<Operand &>(op) = Operand::Reg(load_reg);
}
}
}
// ---- 替换所有 VReg 操作数为 PhysReg ----
for (auto &op : ops)
{
if (op.GetKind() != Operand::Kind::VReg)
continue;
int v = op.GetVRegId();
VRegClass vc = func.GetVRegClass(v);
if (v < 0 || v >= static_cast<int>(vreg_ranges.size()))
{
// vreg 超出范围(临时 vreg用 scratch 替换
int fallback = (vc == VRegClass::Float)
? PickFPScratchReg(inst, pos_regs)
: PickGPScratchReg(inst, pos_regs);
const_cast<Operand &>(op) = Operand::Reg(NumberToPhysReg(fallback, vc));
continue;
}
// 找到当前位置对应的 reg
int reg_idx = -1;
for (const auto &rng : vreg_ranges[v])
{
if (rng.start <= global_pos && global_pos <= rng.end)
{
reg_idx = rng.reg_idx;
break;
}
}
if (reg_idx >= 0)
{
// 有寄存器:直接替换
const_cast<Operand &>(op) = Operand::Reg(AllocIdxToPhysReg(reg_idx, vc));
}
else
{
// 溢出或无范围覆盖:用 scratch 替换
auto slot_it = vreg_to_slot.find(v);
int scratch = (vc == VRegClass::Float)
? PickFPScratchReg(inst, pos_regs)
: PickGPScratchReg(inst, pos_regs);
const_cast<Operand &>(op) = Operand::Reg(NumberToPhysReg(scratch, vc));
if (slot_it == vreg_to_slot.end())
{
// 无 slot 也无寄存器,记录 scratch不 store因为没有 slot
}
else
{
pos_regs[v] = scratch;
}
}
}
// ---- 压入指令 ----
new_insts.push_back(std::move(const_cast<MachineInstr &>(inst)));
// ---- 处理溢出 def插入 StoreStack ----
if (def_vreg >= 0 && def_vreg < static_cast<int>(vreg_ranges.size()))
{
// 检查 def vreg 在此位置是否溢出
bool needs_store = false;
for (const auto &rng : vreg_ranges[def_vreg])
{
if (rng.start <= global_pos && global_pos <= rng.end)
{
if (rng.reg_idx == -1)
needs_store = true;
break;
}
}
if (needs_store)
{
auto slot_it = vreg_to_slot.find(def_vreg);
if (slot_it != vreg_to_slot.end())
{
// 从刚压入的指令中找到结果寄存器
const auto &last_inst = new_insts.back();
PhysReg result_reg = PhysReg::W0;
VRegClass vc = func.GetVRegClass(def_vreg);
for (const auto &op : last_inst.GetOperands())
{
if (op.GetKind() == Operand::Kind::Reg)
{
PhysReg r = op.GetReg();
bool is_gp = (r >= PhysReg::W0 && r <= PhysReg::W30) ||
(r >= PhysReg::X0 && r <= PhysReg::X30);
bool is_fp = (r >= PhysReg::S0 && r <= PhysReg::S31);
if ((vc == VRegClass::Float && is_fp) ||
(vc != VRegClass::Float && is_gp))
{
result_reg = r;
break;
}
}
}
new_insts.push_back(
MachineInstr(Opcode::StoreStack,
{Operand::Reg(result_reg), Operand::FrameIndex(slot_it->second)}));
}
}
}
++global_pos;
}
block->GetInstructions() = std::move(new_insts);
}
}
} // anonymous namespace
} // namespace mir
// ---- 公开 API -----------------------------------------------------------
namespace mir
{
// Runs linear-scan register allocation on one function. Functions without
// any virtual register are left untouched.
void RunLinearScanRegAlloc(MachineFunction &func)
{
    const bool has_vregs = (func.GetNumVRegs() != 0);
    if (has_vregs)
        RunLinearScan(func);
}
// Runs linear-scan register allocation on every function of `module`,
// skipping entries that test false (null).
void RunLinearScanRegAlloc(MachineModule &module)
{
    for (auto &function : module.GetFunctions())
    {
        if (!function)
            continue;
        RunLinearScanRegAlloc(*function);
    }
}
} // namespace mir
Loading…
Cancel
Save