From 28ad162de421ad5cfc390289fa2955d014d4bb31 Mon Sep 17 00:00:00 2001
From: lzkk <956449176@qq.com>
Date: Tue, 26 May 2026 19:10:56 +0800
Subject: [PATCH] =?UTF-8?q?feat(mir):=20=E7=BA=BF=E6=80=A7=E6=89=AB?=
 =?UTF-8?q?=E6=8F=8F=E5=AF=84=E5=AD=98=E5=99=A8=E5=88=86=E9=85=8D=E5=88=9D?=
 =?UTF-8?q?=E5=A7=8B=E5=AE=9E=E7=8E=B0=EF=BC=88WIP=EF=BC=8C--regalloc=3Dli?=
 =?UTF-8?q?near=20=E5=8F=AF=E7=94=A8=EF=BC=89?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Wimmer & Mössenböck (2005) 优化区间分割算法
- 685 行,支持 GP/FP 寄存器池
- 目前通过简单用例,循环函数有寄存器映射 bug(25_while_if 无限循环)
- 默认仍使用图着色,线性扫描可通过 CLI 切换
---
Review notes for this revision (ignored by git am):
- restored the line structure, template arguments and system includes that
  were lost in the extracted copy of this patch (includes inferred from use)
- GP and FP intervals now have separate active lists; previously an FP
  interval could free, or be evicted from, a GP register index
- an evicted interval now gives up its register at the start of the stealing
  interval instead of at its end (two live vregs shared one register); this
  targets the register-mapping bug listed above but has not been re-run
  against 25_while_if
- x13/x14 and s30/s31 are reserved as spill scratch registers (s30/s31 were
  allocatable before)
- scratch selection now compares register numbers, not allocatable indices
- known gaps are marked NOTE(review) in the source

 src/include/utils/CLI.h     |   1 +
 src/mir/LinearScanAlloc.cpp | 591 ++++++++++++++++++++++++++++++++++++
 2 files changed, 592 insertions(+)
 create mode 100644 src/mir/LinearScanAlloc.cpp

diff --git a/src/include/utils/CLI.h b/src/include/utils/CLI.h
index a06106b6..2b16577b 100644
--- a/src/include/utils/CLI.h
+++ b/src/include/utils/CLI.h
@@ -12,6 +12,7 @@ struct CLIOptions {
     bool show_help = false;
     bool optimize = false; // -O 或 -O1
     int opt_level = 0; // 优化级别: 0, 1, 2, 3
+    std::string regalloc = "graphcoloring"; // register allocator: "graphcoloring" or "linear"
 };
 
 CLIOptions ParseCLI(int argc, char** argv);
diff --git a/src/mir/LinearScanAlloc.cpp b/src/mir/LinearScanAlloc.cpp
new file mode 100644
index 00000000..be4ac6ab
--- /dev/null
+++ b/src/mir/LinearScanAlloc.cpp
@@ -0,0 +1,591 @@
+#include "mir/MIR.h"
+
+#include <algorithm>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "utils/Log.h"
+
+namespace mir
+{
+    namespace
+    {
+
+        // ---- AArch64 allocatable registers -----------------------------------
+
+        // GP allocatable: x8 (indirect result), x9-x12, x15, x16-x17 (IP0/IP1),
+        // x19-x28 (callee-saved).
+        // Excluded: x0-x7 (argument passing), x13-x14 (reserved as spill
+        // scratch registers, see GP_SCRATCH), x18 (platform register),
+        // x29-x31 (reserved).
+        static const int GP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 15, 16, 17, 19,
+                                             20, 21, 22, 23, 24, 25, 26, 27, 28};
+        static const int K_GP =
+            static_cast<int>(sizeof(GP_ALLOCATABLE) / sizeof(GP_ALLOCATABLE[0]));
+
+        // FP allocatable: s8-s29. s30/s31 are held back as spill scratch
+        // registers (see FP_SCRATCH) so a reload never overwrites a live value.
+        static const int FP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+                                             19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29};
+        static const int K_FP =
+            static_cast<int>(sizeof(FP_ALLOCATABLE) / sizeof(FP_ALLOCATABLE[0]));
+
+        // Registers never handed to a live interval; used only to reload/spill.
+        // NOTE(review): assumes no other backend stage relies on x13/x14/s30/s31
+        // surviving across a spilled instruction - confirm against the emitter.
+        static const int GP_SCRATCH[] = {14, 13};
+        static const int FP_SCRATCH[] = {30, 31};
+
+        // Register number -> PhysReg of the given class (S for Float, X for Ptr,
+        // W otherwise). Relies on each PhysReg bank being numbered contiguously.
+        static PhysReg NumberToPhysReg(int num, VRegClass vc)
+        {
+            if (vc == VRegClass::Float)
+                return static_cast<PhysReg>(static_cast<int>(PhysReg::S0) + num);
+            if (vc == VRegClass::Ptr)
+                return static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + num);
+            return static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + num);
+        }
+
+        // Allocatable-array index -> architectural register number.
+        static int AllocIdxToNumber(int idx, VRegClass vc)
+        {
+            return (vc == VRegClass::Float) ? FP_ALLOCATABLE[idx] : GP_ALLOCATABLE[idx];
+        }
+
+        // Allocatable-array index -> PhysReg.
+        static PhysReg AllocIdxToPhysReg(int idx, VRegClass vc)
+        {
+            return NumberToPhysReg(AllocIdxToNumber(idx, vc), vc);
+        }
+
+        // Number (0-30) of a W/X register, or -1 when `r` is not a GP register.
+        static int GPRegNumber(PhysReg r)
+        {
+            if (r >= PhysReg::W0 && r <= PhysReg::W30)
+                return static_cast<int>(r) - static_cast<int>(PhysReg::W0);
+            if (r >= PhysReg::X0 && r <= PhysReg::X30)
+                return static_cast<int>(r) - static_cast<int>(PhysReg::X0);
+            return -1;
+        }
+
+        // Number (0-31) of an S register, or -1 when `r` is not an FP register.
+        static int FPRegNumber(PhysReg r)
+        {
+            if (r >= PhysReg::S0 && r <= PhysReg::S31)
+                return static_cast<int>(r) - static_cast<int>(PhysReg::S0);
+            return -1;
+        }
+
+        // ---- Helpers ---------------------------------------------------------
+
+        // Opcodes for which the rewriter treats operand 0 as the defined register.
+        static bool HasVRegDef(Opcode opcode)
+        {
+            switch (opcode)
+            {
+            case Opcode::MovImm:
+            case Opcode::LoadStack:
+            case Opcode::LoadGlobal:
+            case Opcode::LoadGlobalAddr:
+            case Opcode::LoadStackAddr:
+            case Opcode::LoadMem:
+            case Opcode::AddRR:
+            case Opcode::SubRR:
+            case Opcode::AddImm:
+            case Opcode::SubImm:
+            case Opcode::MulRR:
+            case Opcode::DivRR:
+            case Opcode::ModRR:
+            case Opcode::AndRR:
+            case Opcode::OrRR:
+            case Opcode::XorRR:
+            case Opcode::ShlRR:
+            case Opcode::ShrRR:
+            case Opcode::AsrRR:
+            case Opcode::Asr64RR:
+            case Opcode::Uxtw:
+            case Opcode::Sxtw:
+            case Opcode::CSet:
+            case Opcode::Csel:
+            case Opcode::Smull:
+            case Opcode::Msub:
+            case Opcode::NegRR:
+            case Opcode::FAddRR:
+            case Opcode::FSubRR:
+            case Opcode::FMulRR:
+            case Opcode::FDivRR:
+            case Opcode::Scvtf:
+            case Opcode::FCvtzs:
+            case Opcode::FMovWS:
+            case Opcode::MovReg:
+            case Opcode::Call:
+                return true;
+            default:
+                return false;
+            }
+        }
+
+        // ---- Core data structures --------------------------------------------
+
+        // A live interval currently holding a register.
+        struct ActiveInterval
+        {
+            LiveInterval *interval;
+            int phys_reg; // index into the class's allocatable array
+        };
+
+        // One segment of a vreg's lifetime and where the value lives during it.
+        struct VRegRange
+        {
+            int start;   // first instruction position (function-wide index)
+            int end;     // last instruction position, inclusive
+            int reg_idx; // allocatable-array index, -1 when spilled
+        };
+
+        // Position at which a vreg must be written from its register to its slot.
+        struct SavePoint
+        {
+            int pos;     // instruction position
+            int vreg;    // evicted vreg
+            int reg_idx; // allocatable-array index of the register being vacated
+            int spill_slot;
+        };
+
+        // ---- Allocator -------------------------------------------------------
+
+        // Drops every interval ending before `pos` from `active` and marks its
+        // register free. `active` and `reg_free` must describe the same register
+        // class, because phys_reg indexes that class's allocatable array.
+        static void ExpireOldIntervals(std::vector<ActiveInterval> &active,
+                                       std::vector<bool> &reg_free,
+                                       int pos)
+        {
+            const auto expired = [pos](const ActiveInterval &a)
+            { return a.interval->end < pos; };
+
+            for (const auto &a : active)
+            {
+                if (expired(a))
+                    reg_free[a.phys_reg] = true;
+            }
+            active.erase(std::remove_if(active.begin(), active.end(), expired),
+                         active.end());
+        }
+
+        // Lowest free index in `reg_free`, or -1 when every register is taken.
+        static int FindFreeReg(const std::vector<bool> &reg_free)
+        {
+            for (size_t i = 0; i < reg_free.size(); ++i)
+                if (reg_free[i])
+                    return static_cast<int>(i);
+            return -1;
+        }
+
+        // Index into `active` of the interval with the largest end, or -1 when
+        // `active` is empty (positions are assumed to be non-negative).
+        static int SelectSpill(const std::vector<ActiveInterval> &active)
+        {
+            int farthest = -1;
+            int farthest_end = -1;
+            for (size_t i = 0; i < active.size(); ++i)
+            {
+                if (active[i].interval->end > farthest_end)
+                {
+                    farthest_end = active[i].interval->end;
+                    farthest = static_cast<int>(i);
+                }
+            }
+            return farthest;
+        }
+
+        // Returns the frame index backing `vreg`, creating it on first request
+        // (8 bytes for Ptr, 4 bytes otherwise).
+        static int GetOrCreateSpillSlot(MachineFunction &func, int vreg,
+                                        std::unordered_map<int, int> &vreg_to_slot)
+        {
+            auto it = vreg_to_slot.find(vreg);
+            if (it != vreg_to_slot.end())
+                return it->second;
+            int size = (func.GetVRegClass(vreg) == VRegClass::Ptr) ? 8 : 4;
+            int slot = func.CreateFrameIndex(size);
+            vreg_to_slot[vreg] = slot;
+            return slot;
+        }
+
+        // Takes the register away from active[victim_idx] starting at `pos`:
+        // the victim keeps the register up to pos - 1, is recorded as spilled
+        // from `pos` to its original end, and a save point at `pos` writes the
+        // value to its slot before the register is reused. Returns the freed
+        // allocatable index.
+        static int EvictAt(MachineFunction &func,
+                           std::vector<ActiveInterval> &active, int victim_idx, int pos,
+                           std::vector<std::vector<VRegRange>> &vreg_ranges,
+                           std::unordered_map<int, int> &vreg_to_slot,
+                           std::vector<SavePoint> &save_points)
+        {
+            const int freed_reg = active[victim_idx].phys_reg;
+            const int victim_vreg = active[victim_idx].interval->vreg;
+
+            auto &ranges = vreg_ranges[victim_vreg];
+            if (!ranges.empty() && ranges.back().reg_idx == freed_reg)
+            {
+                // The register range must stop before `pos`; leaving it open
+                // any longer would map the victim and the new owner to the same
+                // register while both are live.
+                const int orig_end = ranges.back().end;
+                ranges.back().end = pos - 1;
+                ranges.push_back({pos, orig_end, -1});
+
+                const int slot = GetOrCreateSpillSlot(func, victim_vreg, vreg_to_slot);
+                save_points.push_back({pos, victim_vreg, freed_reg, slot});
+            }
+
+            active.erase(active.begin() + victim_idx);
+            return freed_reg;
+        }
+
+        // ---- Forward declaration ---------------------------------------------
+
+        static void RewriteWithAllocation(
+            MachineFunction &func,
+            const std::vector<std::vector<VRegRange>> &vreg_ranges,
+            const std::unordered_map<int, int> &vreg_to_slot,
+            std::vector<SavePoint> &save_points);
+
+        // ---- Main algorithm: linear scan with interval splitting --------------
+        // (after Wimmer & Mössenböck, 2005)
+        //
+        // Walks the intervals in order of increasing start. Each one receives a
+        // free register of its class if there is one; otherwise the active
+        // interval of that class ending last is evicted when it outlives the
+        // current interval, and the current interval is spilled when it does not.
+        //
+        // NOTE(review): Call is not treated specially here, so an interval that
+        // spans a call can still be assigned a caller-saved register.
+        // NOTE(review): ranges are linear positions; no moves are inserted on
+        // control-flow edges, so a vreg split inside a loop may be expected in a
+        // register at the loop head while its latest value is in the slot.
+        static void RunLinearScan(MachineFunction &func)
+        {
+            auto intervals = ComputeInstLiveness(func);
+            if (intervals.empty())
+                return;
+
+            const int num_vregs = func.GetNumVRegs();
+
+            // Process intervals by ascending start position.
+            std::sort(intervals.begin(), intervals.end(),
+                      [](const LiveInterval &a, const LiveInterval &b)
+                      { return a.start < b.start; });
+
+            // Allocation result.
+            std::vector<std::vector<VRegRange>> vreg_ranges(num_vregs);
+            std::vector<bool> vreg_has_range(num_vregs, false);
+            std::unordered_map<int, int> vreg_to_slot; // vreg -> spill slot
+            std::vector<SavePoint> save_points;
+
+            // Free tables and active lists are kept per register class so that an
+            // FP interval can never free, or be evicted from, a GP register.
+            std::vector<bool> gp_free(K_GP, true);
+            std::vector<bool> fp_free(K_FP, true);
+            std::vector<ActiveInterval> gp_active;
+            std::vector<ActiveInterval> fp_active;
+
+            for (LiveInterval &cur : intervals)
+            {
+                // Ignore ids outside the vreg table and vregs already placed.
+                if (cur.vreg < 0 || cur.vreg >= num_vregs || vreg_has_range[cur.vreg])
+                    continue;
+                vreg_has_range[cur.vreg] = true;
+
+                const bool is_fp = (cur.vreg_class == VRegClass::Float);
+                std::vector<bool> &reg_free = is_fp ? fp_free : gp_free;
+                std::vector<ActiveInterval> &active = is_fp ? fp_active : gp_active;
+
+                // 1. Retire intervals of this class that ended before cur starts.
+                ExpireOldIntervals(active, reg_free, cur.start);
+
+                // 2. Take a free register, or 3. steal one from the active interval
+                //    that ends last, provided it outlives cur.
+                int reg = FindFreeReg(reg_free);
+                if (reg < 0)
+                {
+                    const int victim_idx = SelectSpill(active);
+                    if (victim_idx >= 0 && active[victim_idx].interval->end > cur.end)
+                        reg = EvictAt(func, active, victim_idx, cur.start,
+                                      vreg_ranges, vreg_to_slot, save_points);
+                }
+
+                if (reg >= 0)
+                {
+                    reg_free[reg] = false;
+                    active.push_back({&cur, reg});
+                    vreg_ranges[cur.vreg].push_back({cur.start, cur.end, reg});
+                }
+                else
+                {
+                    // 4. Nothing to steal: cur lives in its slot for its whole range.
+                    cur.spilled = true;
+                    cur.spill_slot = GetOrCreateSpillSlot(func, cur.vreg, vreg_to_slot);
+                    vreg_ranges[cur.vreg].push_back({cur.start, cur.end, -1});
+                }
+            }
+
+            // ---- Rewrite instructions ----
+            RewriteWithAllocation(func, vreg_ranges, vreg_to_slot, save_points);
+        }
+
+        // ---- Scratch register selection --------------------------------------
+
+        // True when register number `num` of the requested bank is named by a
+        // physical operand of `inst` or is a value in `pos_regs`.
+        static bool IsRegBusy(const MachineInstr &inst,
+                              const std::unordered_map<int, int> &pos_regs,
+                              int num, bool fp)
+        {
+            for (const auto &op : inst.GetOperands())
+            {
+                if (op.GetKind() != Operand::Kind::Reg)
+                    continue;
+                const int n = fp ? FPRegNumber(op.GetReg()) : GPRegNumber(op.GetReg());
+                if (n == num)
+                    return true;
+            }
+            for (const auto &kv : pos_regs)
+            {
+                if (kv.second == num)
+                    return true;
+            }
+            return false;
+        }
+
+        // Picks a GP register number for a spill reload/store around `inst`.
+        // `pos_regs` maps vreg -> GP register NUMBER (not allocatable index)
+        // already claimed by this instruction. The reserved scratch registers
+        // are tried first; the allocatable list is only a last resort and may
+        // then name a register that holds an unrelated live value.
+        static int PickGPScratchReg(const MachineInstr &inst,
+                                    const std::unordered_map<int, int> &pos_regs)
+        {
+            for (int r : GP_SCRATCH)
+                if (!IsRegBusy(inst, pos_regs, r, false))
+                    return r;
+            for (int r : GP_ALLOCATABLE)
+                if (!IsRegBusy(inst, pos_regs, r, false))
+                    return r;
+            return GP_ALLOCATABLE[0];
+        }
+
+        // FP counterpart of PickGPScratchReg; `pos_regs` holds S register numbers.
+        static int PickFPScratchReg(const MachineInstr &inst,
+                                    const std::unordered_map<int, int> &pos_regs)
+        {
+            for (int r : FP_SCRATCH)
+                if (!IsRegBusy(inst, pos_regs, r, true))
+                    return r;
+            for (int r : FP_ALLOCATABLE)
+                if (!IsRegBusy(inst, pos_regs, r, true))
+                    return r;
+            return FP_ALLOCATABLE[0];
+        }
+
+        // ---- Save point ordering ---------------------------------------------
+
+        static void SortSavePoints(std::vector<SavePoint> &save_points)
+        {
+            std::sort(save_points.begin(), save_points.end(),
+                      [](const SavePoint &a, const SavePoint &b)
+                      { return a.pos < b.pos; });
+        }
+
+        // First range in `ranges` that contains `pos`, or nullptr when none does.
+        static const VRegRange *FindRangeAt(const std::vector<VRegRange> &ranges, int pos)
+        {
+            for (const auto &rng : ranges)
+            {
+                if (rng.start <= pos && pos <= rng.end)
+                    return &rng;
+            }
+            return nullptr;
+        }
+
+        // ---- RewriteWithAllocation -------------------------------------------
+
+        // Replaces every VReg operand with a physical register and inserts the
+        // spill code the allocation requires:
+        //   * a StoreStack at each save point, before the instruction there;
+        //   * a LoadStack into a scratch register before an instruction that
+        //     reads a vreg whose covering range is spilled;
+        //   * a StoreStack after an instruction that defines such a vreg.
+        // Positions count instructions across all blocks in their existing order
+        // and are assumed to match the numbering used by ComputeInstLiveness.
+        static void RewriteWithAllocation(
+            MachineFunction &func,
+            const std::vector<std::vector<VRegRange>> &vreg_ranges,
+            const std::unordered_map<int, int> &vreg_to_slot,
+            std::vector<SavePoint> &save_points)
+        {
+            SortSavePoints(save_points);
+            size_t next_save = 0;
+            const int num_vregs = static_cast<int>(vreg_ranges.size());
+
+            // Function-wide instruction position (original instruction order).
+            int global_pos = 0;
+
+            for (auto &block : func.GetBlocks())
+            {
+                std::vector<MachineInstr> new_insts;
+
+                for (auto &inst : block->GetInstructions())
+                {
+                    auto &ops = inst.GetOperands();
+
+                    // ---- Save points due at or before this position ----
+                    while (next_save < save_points.size() &&
+                           save_points[next_save].pos <= global_pos)
+                    {
+                        const SavePoint &sp = save_points[next_save];
+                        const VRegClass vc = func.GetVRegClass(sp.vreg);
+                        new_insts.push_back(
+                            MachineInstr(Opcode::StoreStack,
+                                         {Operand::Reg(AllocIdxToPhysReg(sp.reg_idx, vc)),
+                                          Operand::FrameIndex(sp.spill_slot)}));
+                        ++next_save;
+                    }
+
+                    // Operand 0 is the definition when the opcode defines a vreg.
+                    const bool has_def = HasVRegDef(inst.GetOpcode()) && !ops.empty() &&
+                                         ops[0].GetKind() == Operand::Kind::VReg;
+                    const int def_vreg = has_def ? ops[0].GetVRegId() : -1;
+
+                    // Register chosen for each vreg of this instruction, plus the
+                    // register NUMBERS claimed per bank (input to scratch selection).
+                    std::unordered_map<int, PhysReg> assigned;
+                    std::unordered_map<int, int> gp_regs;
+                    std::unordered_map<int, int> fp_regs;
+
+                    // ---- Pass 1: vregs that live in a register at this position ----
+                    for (const auto &op : ops)
+                    {
+                        if (op.GetKind() != Operand::Kind::VReg)
+                            continue;
+                        const int v = op.GetVRegId();
+                        if (v < 0 || v >= num_vregs || assigned.count(v))
+                            continue;
+                        const VRegRange *rng = FindRangeAt(vreg_ranges[v], global_pos);
+                        if (rng == nullptr || rng->reg_idx < 0)
+                            continue;
+                        const VRegClass vc = func.GetVRegClass(v);
+                        assigned.emplace(v, AllocIdxToPhysReg(rng->reg_idx, vc));
+                        auto &claimed = (vc == VRegClass::Float) ? fp_regs : gp_regs;
+                        claimed[v] = AllocIdxToNumber(rng->reg_idx, vc);
+                    }
+
+                    // ---- Pass 2: spilled, uncovered or unknown vregs use a scratch ----
+                    for (size_t i = 0; i < ops.size(); ++i)
+                    {
+                        if (ops[i].GetKind() != Operand::Kind::VReg)
+                            continue;
+                        const int v = ops[i].GetVRegId();
+                        if (assigned.count(v))
+                            continue;
+
+                        const VRegClass vc = func.GetVRegClass(v);
+                        const bool is_fp = (vc == VRegClass::Float);
+                        auto &claimed = is_fp ? fp_regs : gp_regs;
+                        const int scratch = is_fp ? PickFPScratchReg(inst, claimed)
+                                                  : PickGPScratchReg(inst, claimed);
+                        const PhysReg scratch_reg = NumberToPhysReg(scratch, vc);
+                        assigned.emplace(v, scratch_reg);
+                        claimed[v] = scratch;
+
+                        // Reload only when the value is read here, is spilled at
+                        // this position and actually has a slot.
+                        bool is_read = false;
+                        for (size_t j = 0; j < ops.size(); ++j)
+                        {
+                            if (has_def && j == 0)
+                                continue;
+                            if (ops[j].GetKind() == Operand::Kind::VReg &&
+                                ops[j].GetVRegId() == v)
+                                is_read = true;
+                        }
+                        if (!is_read || v < 0 || v >= num_vregs)
+                            continue;
+                        const VRegRange *rng = FindRangeAt(vreg_ranges[v], global_pos);
+                        const auto slot_it = vreg_to_slot.find(v);
+                        if (rng != nullptr && rng->reg_idx == -1 &&
+                            slot_it != vreg_to_slot.end())
+                        {
+                            new_insts.push_back(
+                                MachineInstr(Opcode::LoadStack,
+                                             {Operand::Reg(scratch_reg),
+                                              Operand::FrameIndex(slot_it->second)}));
+                        }
+                    }
+
+                    // ---- Substitute the chosen registers ----
+                    for (const auto &op : ops)
+                    {
+                        if (op.GetKind() != Operand::Kind::VReg)
+                            continue;
+                        const PhysReg reg = assigned.at(op.GetVRegId());
+                        const_cast<Operand &>(op) = Operand::Reg(reg);
+                    }
+
+                    // ---- Emit the instruction itself ----
+                    new_insts.push_back(std::move(const_cast<MachineInstr &>(inst)));
+
+                    // ---- Spilled definition: write the result back to its slot ----
+                    if (def_vreg >= 0 && def_vreg < num_vregs)
+                    {
+                        const VRegRange *rng = FindRangeAt(vreg_ranges[def_vreg], global_pos);
+                        const auto slot_it = vreg_to_slot.find(def_vreg);
+                        if (rng != nullptr && rng->reg_idx == -1 &&
+                            slot_it != vreg_to_slot.end())
+                        {
+                            new_insts.push_back(
+                                MachineInstr(Opcode::StoreStack,
+                                             {Operand::Reg(assigned.at(def_vreg)),
+                                              Operand::FrameIndex(slot_it->second)}));
+                        }
+                    }
+
+                    ++global_pos;
+                }
+
+                block->GetInstructions() = std::move(new_insts);
+            }
+        }
+
+    } // anonymous namespace
+} // namespace mir
+
+// ---- Public API ------------------------------------------------------------
+
+namespace mir
+{
+
+    // Runs linear-scan register allocation on one function; a function without
+    // vregs is left untouched.
+    void RunLinearScanRegAlloc(MachineFunction &func)
+    {
+        if (func.GetNumVRegs() == 0)
+            return;
+        RunLinearScan(func);
+    }
+
+    // Runs linear-scan register allocation on every non-null function of `module`.
+    void RunLinearScanRegAlloc(MachineModule &module)
+    {
+        for (auto &function : module.GetFunctions())
+        {
+            if (function)
+                RunLinearScanRegAlloc(*function);
+        }
+    }
+
+} // namespace mir