From 9e8984d740c4819d017b261356fd6ee220b457dd Mon Sep 17 00:00:00 2001 From: zjx Date: Tue, 26 May 2026 14:26:49 +0800 Subject: [PATCH] =?UTF-8?q?lab5=E5=AF=84=E5=AD=98=E5=99=A8=E5=88=86?= =?UTF-8?q?=E9=85=8D=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .claude/plans/temporal-questing-tower.md | 646 +++++++++++++++ .claude/settings.local.json | 14 + include/mir/MIR.h | 52 +- scripts/run_all_tests.sh | 4 +- src/mir/AsmPrinter.cpp | 6 + src/mir/FrameLowering.cpp | 50 +- src/mir/Lowering.cpp | 43 +- src/mir/MIRBasicBlock.cpp | 13 + src/mir/MIRFunction.cpp | 13 + src/mir/MIRInstr.cpp | 20 + src/mir/RegAlloc.cpp | 980 +++++++++++++++++++++-- src/mir/passes/Peephole.cpp | 29 +- 12 files changed, 1793 insertions(+), 77 deletions(-) create mode 100644 .claude/plans/temporal-questing-tower.md create mode 100644 .claude/settings.local.json diff --git a/.claude/plans/temporal-questing-tower.md b/.claude/plans/temporal-questing-tower.md new file mode 100644 index 0000000..516a031 --- /dev/null +++ b/.claude/plans/temporal-questing-tower.md @@ -0,0 +1,646 @@ +# Lab5: 图着色寄存器分配 - 详细实施方案 + +## 总体策略选择 + +经过对代码库的深入分析,选择 **方案B:后置提升(Post-Lowering Promotion)** 作为主要策略: + +**核心思路**:保留现有 Lowering.cpp 基本不动(它已经正确工作),在其后添加一个 **栈槽提升(Stack Slot Promotion)** pass,将"栈槽中转"模式转换为"虚拟寄存器"模式,然后对虚拟寄存器做图着色分配。 + +**选择理由**: +1. Lowering.cpp 有 1117 行,每个指令模式都使用 `StoreStack/LoadStack + 固定寄存器`,全部重写风险极高 +2. 当前模式中,每个 IR Value 对应唯一一个栈槽(ValueSlotMap),这天然等价于虚拟寄存器 +3. 后置提升只需识别 `LoadStack vreg->physreg` / `StoreStack physreg->vreg`,无需修改 Lowering 逻辑 +4. 
即使提升失败(如数组地址计算),退化到原始行为即可,保证正确性 + +--- + +## 架构概览 + +``` +LowerToMIR (不变) + | +RunPeephole (不变,先做局部优化) + | +RunStackPromotion (新增:栈槽->虚拟寄存器提升) + | +RunRegAlloc (重写:图着色寄存器分配) + | +RunFrameLowering (修改:处理溢出槽 + callee-saved) + | +PrintAsm (小改:支持新的 callee-saved 保存/恢复) +``` + +--- + +## Step 1: 扩展 MIR 基础设施 - 虚拟寄存器支持 + +### 1.1 修改 `include/mir/MIR.h` + +```cpp +// 新增:虚拟寄存器 ID 类型 +using VRegId = int; // 正数,从 1 开始 + +// 新增:寄存器类(用于区分 GPR 和 FPR) +enum class RegClass { GPR, FPR }; + +// 修改 Operand 类: +class Operand { + public: + enum class Kind { Reg, VReg, Imm, FrameIndex, Symbol }; + // ^^^^^ 新增 + + static Operand Reg(PhysReg reg); + static Operand VReg(VRegId id, RegClass rc); // 新增 + static Operand Imm(int value); + static Operand FrameIndex(int index); + static Operand Symbol(std::string name); + + Kind GetKind() const { return kind_; } + PhysReg GetReg() const { return reg_; } + VRegId GetVReg() const { return vreg_id_; } // 新增 + RegClass GetRegClass() const { return reg_class_; } // 新增 + int GetImm() const { return imm_; } + int GetFrameIndex() const { return imm_; } + const std::string& GetSymbol() const { return symbol_; } + + bool IsReg() const { return kind_ == Kind::Reg; } + bool IsVReg() const { return kind_ == Kind::VReg; } // 新增 + + private: + Operand(Kind kind, PhysReg reg, int imm, std::string symbol = ""); + + Kind kind_; + PhysReg reg_ = PhysReg::W0; + VRegId vreg_id_ = 0; // 新增 + RegClass reg_class_ = RegClass::GPR; // 新增 + int imm_ = 0; + std::string symbol_; +}; +``` + +### 1.2 修改 `MachineFunction`:添加虚拟寄存器管理 + +```cpp +class MachineFunction { + public: + // ... 现有接口不变 ... + + // 新增:虚拟寄存器分配 + VRegId CreateVReg(RegClass rc); + RegClass GetVRegClass(VRegId id) const; + int GetNumVRegs() const { return next_vreg_id_ - 1; } + + // 新增:callee-saved 管理 + void SetCalleeSavedRegs(const std::vector<PhysReg>& regs); + const std::vector<PhysReg>& GetCalleeSavedRegs() const; + + private: + // ... 现有字段 ... 
+ int next_vreg_id_ = 1; + std::vector<RegClass> vreg_classes_; // index = vreg_id - 1 + std::vector<PhysReg> callee_saved_regs_; +}; +``` + +### 1.3 修改 `MachineBasicBlock`:添加 CFG 信息 + +```cpp +class MachineBasicBlock { + public: + // ... 现有接口不变 ... + + // 新增:CFG 前驱/后继 + void AddSuccessor(MachineBasicBlock* succ); + void AddPredecessor(MachineBasicBlock* pred); + const std::vector<MachineBasicBlock*>& GetSuccessors() const; + const std::vector<MachineBasicBlock*>& GetPredecessors() const; + void ClearCFG(); // 用于重建 CFG + + private: + std::string name_; + std::vector<MachineInstr> instructions_; + std::vector<MachineBasicBlock*> succs_; // 新增 + std::vector<MachineBasicBlock*> preds_; // 新增 +}; +``` + +### 1.4 修改 `src/mir/MIRInstr.cpp` + +```cpp +Operand Operand::VReg(VRegId id, RegClass rc) { + Operand op(Kind::VReg, PhysReg::W0, 0); + op.vreg_id_ = id; + op.reg_class_ = rc; + return op; +} +``` + +--- + +## Step 2: 构建 CFG(新增 pass) + +### 2.1 新增文件 `src/mir/BuildCFG.cpp` + +在 MIR 生成之后,通过分析每个基本块末尾的跳转指令构建 CFG: + +```cpp +namespace mir { + +void BuildCFG(MachineFunction& function) { + // 先清除旧的 CFG 信息(支持重建) + for (auto& bb_ptr : function.GetBlocks()) { + bb_ptr->ClearCFG(); + } + + auto& blocks = function.GetBlocks(); + for (size_t idx = 0; idx < blocks.size(); ++idx) { + auto& bb = *blocks[idx]; + auto& insts = bb.GetInstructions(); + bool has_unconditional_branch = false; + + for (const auto& inst : insts) { + Opcode op = inst.GetOpcode(); + if (op == Opcode::B || op == Opcode::Bcond || op == Opcode::FBcond || + op == Opcode::Cbnz || op == Opcode::Cbz) { + const auto& target_name = inst.GetOperands()[0].GetSymbol(); + MachineBasicBlock* target = function.FindBlock(target_name); + if (target) { + bb.AddSuccessor(target); + target->AddPredecessor(&bb); + } + if (op == Opcode::B) has_unconditional_branch = true; + } + } + + // Fall-through:如果块没有以无条件跳转/Ret 结束 + if (!has_unconditional_branch && !insts.empty()) { + Opcode last_op = insts.back().GetOpcode(); + if (last_op != Opcode::Ret && last_op != Opcode::B) { + if (idx + 1 < blocks.size()) { + auto* next_bb = blocks[idx + 1].get(); 
+ bb.AddSuccessor(next_bb); + next_bb->AddPredecessor(&bb); + } + } + } + } +} + +} // namespace mir +``` + +### 2.2 在 `MIR.h` 中声明 + +```cpp +void BuildCFG(MachineFunction& function); +``` + +--- + +## Step 3: 栈槽提升 Pass(核心创新) + +### 3.1 新增文件 `src/mir/StackPromotion.cpp` + +**核心算法**:扫描所有指令,识别"可提升栈槽"——即仅通过 `LoadStack`/`StoreStack` 访问的 4 字节标量槽(非数组、非指针、非通过 `LoadStackOffset`/`StoreStackOffset`/`LoadStackAddr` 访问的)。 + +```cpp +#include "mir/MIR.h" +#include +#include +#include + +namespace mir { +namespace { + +struct SlotUsageInfo { + bool is_promotable = true; + RegClass reg_class = RegClass::GPR; + int load_count = 0; + int store_count = 0; +}; + +std::unordered_map AnalyzeSlotUsage( + const MachineFunction& function) { + std::unordered_map info; + + for (const auto& bb_ptr : function.GetBlocks()) { + for (const auto& inst : bb_ptr->GetInstructions()) { + const auto& ops = inst.GetOperands(); + + // LoadStackOffset/StoreStackOffset/LoadStackAddr 用于数组访问,不可提升 + if (inst.GetOpcode() == Opcode::LoadStackOffset || + inst.GetOpcode() == Opcode::StoreStackOffset || + inst.GetOpcode() == Opcode::LoadStackAddr) { + if (ops.size() >= 2 && + ops[1].GetKind() == Operand::Kind::FrameIndex) { + info[ops[1].GetFrameIndex()].is_promotable = false; + } + continue; + } + + // LoadStack: ops[0]=dst_reg, ops[1]=frame_index + if (inst.GetOpcode() == Opcode::LoadStack) { + if (ops.size() >= 2 && + ops[1].GetKind() == Operand::Kind::FrameIndex) { + int slot = ops[1].GetFrameIndex(); + auto& si = info[slot]; + si.load_count++; + PhysReg dst = ops[0].GetReg(); + if (dst >= PhysReg::S0 && dst <= PhysReg::S10) { + si.reg_class = RegClass::FPR; + } else if (dst >= PhysReg::X0 && dst <= PhysReg::X11) { + // 64位加载 = 指针,不提升 + si.is_promotable = false; + } + } + continue; + } + + // StoreStack: ops[0]=src_reg, ops[1]=frame_index + if (inst.GetOpcode() == Opcode::StoreStack) { + if (ops.size() >= 2 && + ops[1].GetKind() == Operand::Kind::FrameIndex) { + int slot = ops[1].GetFrameIndex(); + auto& si = 
info[slot]; + si.store_count++; + PhysReg src = ops[0].GetReg(); + if (src >= PhysReg::S0 && src <= PhysReg::S10) { + si.reg_class = RegClass::FPR; + } else if (src >= PhysReg::X0 && src <= PhysReg::X11) { + si.is_promotable = false; + } + } + continue; + } + } + } + + // 排除大小不为 4 的槽(数组、指针等) + for (const auto& slot : function.GetFrameSlots()) { + if (slot.size != 4) { + info[slot.index].is_promotable = false; + } + } + + return info; +} + +} // namespace + +void RunStackPromotion(MachineFunction& function) { + auto slot_info = AnalyzeSlotUsage(function); + + // 为每个可提升的栈槽分配一个虚拟寄存器 + std::unordered_map slot_vreg_map; + for (auto& [slot_idx, si] : slot_info) { + if (!si.is_promotable) continue; + VRegId vreg = function.CreateVReg(si.reg_class); + slot_vreg_map[slot_idx] = vreg; + } + + if (slot_vreg_map.empty()) return; + + // 重写指令:LoadStack/StoreStack -> MovReg/FMovReg with VReg + for (auto& bb_ptr : function.GetBlocks()) { + auto& insts = bb_ptr->GetInstructions(); + std::vector new_insts; + new_insts.reserve(insts.size()); + + for (auto& inst : insts) { + const auto& ops = inst.GetOperands(); + + // LoadStack reg, [slot] -> MovReg reg, %vreg + if (inst.GetOpcode() == Opcode::LoadStack && + ops.size() >= 2 && + ops[1].GetKind() == Operand::Kind::FrameIndex) { + int slot = ops[1].GetFrameIndex(); + auto it = slot_vreg_map.find(slot); + if (it != slot_vreg_map.end()) { + RegClass rc = slot_info[slot].reg_class; + Opcode mov_op = (rc == RegClass::FPR) + ? 
Opcode::FMovReg : Opcode::MovReg; + new_insts.emplace_back(mov_op, std::vector{ + ops[0], // dst physreg (暂保留,后续被 regalloc 处理) + Operand::VReg(it->second, rc) + }); + continue; + } + } + + // StoreStack reg, [slot] -> MovReg %vreg, reg + if (inst.GetOpcode() == Opcode::StoreStack && + ops.size() >= 2 && + ops[1].GetKind() == Operand::Kind::FrameIndex) { + int slot = ops[1].GetFrameIndex(); + auto it = slot_vreg_map.find(slot); + if (it != slot_vreg_map.end()) { + RegClass rc = slot_info[slot].reg_class; + Opcode mov_op = (rc == RegClass::FPR) + ? Opcode::FMovReg : Opcode::MovReg; + new_insts.emplace_back(mov_op, std::vector{ + Operand::VReg(it->second, rc), + ops[0] // src physreg + }); + continue; + } + } + + // 其他指令保持不变 + new_insts.push_back(std::move(inst)); + } + + insts = std::move(new_insts); + } +} + +} // namespace mir +``` + +### 3.2 提升前后的 MIR 对比(示例) + +**提升前(当前输出)**: +```asm +; a = b + c +LoadStack w8, [slot_b] ; 从栈加载 b +LoadStack w9, [slot_c] ; 从栈加载 c +AddRR w8, w8, w9 ; w8 = b + c +StoreStack w8, [slot_a] ; 结果存回栈 +``` + +**提升后**: +```asm +; a = b + c +MovReg w8, %vreg2 ; w8 <- vreg_b +MovReg w9, %vreg3 ; w9 <- vreg_c +AddRR w8, w8, w9 ; w8 = b + c +MovReg %vreg1, w8 ; vreg_a <- w8 +``` + +### 3.3 为什么这样设计 + +提升后的形态保留了 Lowering 生成的"固定寄存器做中间计算"的模式,只是把 Load/Store 换成了与 vreg 的 copy。好处: +1. **正确性有保障**:即使 regalloc 全部 spill,等价于恢复原始行为 +2. **mov coalescing 自然发生**:如果 vreg1 被分配到 w8,则最后的 `MovReg w8, w8` 变为 no-op +3. 
**渐进式优化**:后续可以做更激进的提升(把 AddRR 的操作数也换成 vreg) + +--- + +## Step 4: 活跃性分析(Liveness Analysis) + +### 4.1 新增文件 `src/mir/Liveness.cpp` + +```cpp +#include "mir/MIR.h" +#include +#include +#include + +namespace mir { +namespace { + +// 判断指令的第一个操作数是否为 def(写入) +bool FirstOpIsDef(Opcode op) { + switch (op) { + case Opcode::MovImm: case Opcode::MovReg: + case Opcode::FMovImm: case Opcode::FMovReg: + case Opcode::LoadStack: case Opcode::LoadStackOffset: + case Opcode::LoadStackAddr: case Opcode::LoadIndirect: + case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr: + case Opcode::AddRI: case Opcode::SubRI: + case Opcode::AddRR: case Opcode::SubRR: + case Opcode::MulRR: case Opcode::DivRR: case Opcode::ModRR: + case Opcode::LsrRI: case Opcode::LslRI: case Opcode::LslRR: + case Opcode::FAddRR: case Opcode::FSubRR: + case Opcode::FMulRR: case Opcode::FDivRR: case Opcode::FSqrtRR: + case Opcode::SIToFP: case Opcode::FPToSI: + case Opcode::CmpRR: case Opcode::FCmpRR: + return true; + default: + return false; + } +} + +struct BlockDefUse { + std::unordered_set def; // 块中定义的 vreg + std::unordered_set use; // 块中 upward-exposed use 的 vreg +}; + +BlockDefUse ComputeBlockDefUse(const MachineBasicBlock& bb) { + BlockDefUse info; + for (const auto& inst : bb.GetInstructions()) { + const auto& ops = inst.GetOperands(); + bool first_is_def = FirstOpIsDef(inst.GetOpcode()); + + // 先收集 use(在 def 之前的引用) + for (size_t i = 0; i < ops.size(); ++i) { + if (!ops[i].IsVReg()) continue; + if (i == 0 && first_is_def) continue; // 这是 def,不是 use + VRegId v = ops[i].GetVReg(); + if (info.def.find(v) == info.def.end()) { + info.use.insert(v); // 在本块定义前被使用 + } + } + + // 再收集 def + if (first_is_def && !ops.empty() && ops[0].IsVReg()) { + info.def.insert(ops[0].GetVReg()); + } + } + return info; +} + +} // namespace + +LivenessInfo ComputeLiveness(MachineFunction& function) { + // 1. 
计算每个块的 def/use + std::unordered_map block_info; + for (auto& bb_ptr : function.GetBlocks()) { + block_info[bb_ptr.get()] = ComputeBlockDefUse(*bb_ptr); + } + + // 2. 初始化 + LivenessInfo result; + for (auto& bb_ptr : function.GetBlocks()) { + result.live_in[bb_ptr.get()] = {}; + result.live_out[bb_ptr.get()] = {}; + } + + // 3. 迭代数据流直到不动点 + // live_out[B] = U live_in[S], for S in succ(B) + // live_in[B] = use[B] U (live_out[B] - def[B]) + bool changed = true; + while (changed) { + changed = false; + auto& blocks = function.GetBlocks(); + for (int i = (int)blocks.size() - 1; i >= 0; --i) { + auto* bb = blocks[i].get(); + auto& bi = block_info[bb]; + + // live_out = union of successors' live_in + std::unordered_set new_out; + for (auto* succ : bb->GetSuccessors()) { + for (VRegId v : result.live_in[succ]) { + new_out.insert(v); + } + } + + // live_in = use U (live_out - def) + std::unordered_set new_in = bi.use; + for (VRegId v : new_out) { + if (bi.def.find(v) == bi.def.end()) { + new_in.insert(v); + } + } + + if (new_in != result.live_in[bb] || new_out != result.live_out[bb]) { + changed = true; + result.live_in[bb] = std::move(new_in); + result.live_out[bb] = std::move(new_out); + } + } + } + + return result; +} + +} // namespace mir +``` + +### 4.2 在 `MIR.h` 中声明数据结构和接口 + +```cpp +// 活跃性分析结果 +struct LivenessInfo { + std::unordered_map> live_in; + std::unordered_map> live_out; +}; + +LivenessInfo ComputeLiveness(MachineFunction& function); +``` + +--- + +## Step 5: 干涉图构建 + +### 5.1 数据结构(在 `RegAlloc.cpp` 中) + +```cpp +namespace { + +struct InterferenceGraph { + int num_vregs; + std::vector> adj; // 邻接表 + std::vector degree; // 度数 + + explicit InterferenceGraph(int n) + : num_vregs(n), adj(n + 1), degree(n + 1, 0) {} + + void AddEdge(VRegId u, VRegId v) { + if (u == v) return; + if (adj[u].insert(v).second) degree[u]++; + if (adj[v].insert(u).second) degree[v]++; + } + + void RemoveNode(VRegId v) { + for (VRegId neighbor : adj[v]) { + adj[neighbor].erase(v); + 
degree[neighbor]--; + } + adj[v].clear(); + degree[v] = 0; + } +}; + +} // namespace +``` + +### 5.2 构建算法 + +基于活跃性分析结果,逆序遍历每个基本块的指令,维护当前活跃集合: + +```cpp +namespace { + +InterferenceGraph BuildInterferenceGraph( + MachineFunction& function, + const LivenessInfo& liveness, + std::vector& crosses_call) { // 输出:vreg 是否跨越 call + int n = function.GetNumVRegs(); + InterferenceGraph graph(n); + crosses_call.assign(n + 1, false); + + for (auto& bb_ptr : function.GetBlocks()) { + auto* bb = bb_ptr.get(); + std::unordered_set live = liveness.live_out.at(bb); + + // 逆序遍历指令 + auto& insts = bb->GetInstructions(); + for (int i = (int)insts.size() - 1; i >= 0; --i) { + const auto& inst = insts[i]; + const auto& ops = inst.GetOperands(); + + // 处理 Call 指令:所有当前活跃的 vreg 都 crosses_call + if (inst.GetOpcode() == Opcode::Bl) { + for (VRegId v : live) { + crosses_call[v] = true; + } + } + + // 收集 def 和 use + bool first_is_def = FirstOpIsDef(inst.GetOpcode()); + std::vector defs, uses; + + for (size_t j = 0; j < ops.size(); ++j) { + if (!ops[j].IsVReg()) continue; + if (j == 0 && first_is_def) { + defs.push_back(ops[j].GetVReg()); + } else { + uses.push_back(ops[j].GetVReg()); + } + } + + // 对每个 def:与所有当前 live 的 vreg 建立干涉边 + // 例外:mov d, s 指令中,d 不与 s 干涉(允许合并) + bool is_move = (inst.GetOpcode() == Opcode::MovReg || + inst.GetOpcode() == Opcode::FMovReg); + std::unordered_set use_set(uses.begin(), uses.end()); + + for (VRegId d : defs) { + for (VRegId l : live) { + if (l == d) continue; + if (is_move && use_set.count(l)) continue; // mov coalescing 友好 + graph.AddEdge(d, l); + } + } + + // 更新 live 集合:移除 def,加入 use + for (VRegId d : defs) live.erase(d); + for (VRegId u : uses) live.insert(u); + } + } + + return graph; +} + +} // namespace +``` + +### 5.3 物理寄存器干涉的处理 + +提升后的 MIR 中仍然存在物理寄存器(如 `AddRR w8, w8, w9`),这些物理寄存器的生命期极短(通常在一条指令内定义并被下一条 `MovReg %vreg, w8` 消费)。 + +**关键洞察**:由于物理寄存器不参与图着色(它们已经是物理的),我们不需要在干涉图中建模它们。但在 Select 阶段,需要避免将 vreg 分配到与其"相邻"物理寄存器使用点冲突的寄存器上。 + +**简化处理**:当前设计中,vreg 
的生命期与物理寄存器的使用点基本不重叠(vreg 在 `MovReg %vreg, physreg` 处被定义,在 `MovReg physreg, %vreg` 处被使用)。因此物理寄存器约束可以通过 mov coalescing 自然解决,无需显式建模。 + +--- + +## Step 6: 图着色寄存器分配 (核心算法) + +### 6.1 重写 `src/mir/RegAlloc.cpp` + +整个文件结构如下: + +```cpp +#include "mir/MIR.h" +#include \ No newline at end of file diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..843a5de --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,14 @@ +{ + "permissions": { + "allow": [ + "Bash(mkdir -p \"C:\\\\Users\\\\郑同学\\\\.claude\\\\plans\")", + "Bash(mkdir -p \"/c/Users/郑同学/.claude/plans\")", + "Bash(mkdir -p \"/mnt/c/Users/郑同学/.claude/plans\")", + "Bash(echo \"test\")", + "Bash(rm \"/mnt/c/Users/郑同学/.claude/plans/test.md\")", + "Bash(cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=/c/mingw/mingw64/bin/g++.exe -DCMAKE_C_COMPILER=/c/mingw/mingw64/bin/gcc.exe)", + "Bash(cmake -S . -B build -G \"MinGW Makefiles\" -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=\"C:/mingw/MinGW/bin/g++.exe\" -DCMAKE_C_COMPILER=\"C:/mingw/MinGW/bin/gcc.exe\")", + "Bash(wsl *)" + ] + } +} diff --git a/include/mir/MIR.h b/include/mir/MIR.h index 39c15e2..bdd211d 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -50,6 +50,7 @@ enum class Opcode { AddRI, SubRI, AddRR, + AddRR_UXTW, // add xN, xM, wK, uxtw(零扩展W寄存器后加到X寄存器) SubRR, MulRR, DivRR, @@ -77,27 +78,49 @@ enum class Opcode { Ret, }; +// 虚拟寄存器类别 +enum class VRegClass { + GPR, // 通用寄存器 (w0-w11) + GPR64, // 64位通用寄存器 (x0-x11) + FPR, // 浮点寄存器 (s0-s10) +}; + class Operand { public: - enum class Kind { Reg, Imm, FrameIndex, Symbol }; + enum class Kind { Reg, VReg, Imm, FrameIndex, Symbol }; static Operand Reg(PhysReg reg); + static Operand VReg(int vreg_id, VRegClass rc = VRegClass::GPR); static Operand Imm(int value); static Operand FrameIndex(int index); static Operand Symbol(std::string name); Kind GetKind() const { return kind_; } + bool IsReg() const { return kind_ == Kind::Reg; } + bool 
IsVReg() const { return kind_ == Kind::VReg; } + bool IsImm() const { return kind_ == Kind::Imm; } + bool IsFrameIndex() const { return kind_ == Kind::FrameIndex; } + bool IsSymbol() const { return kind_ == Kind::Symbol; } + PhysReg GetReg() const { return reg_; } + int GetVRegId() const { return vreg_id_; } + VRegClass GetVRegClass() const { return vreg_class_; } int GetImm() const { return imm_; } int GetFrameIndex() const { return imm_; } const std::string& GetSymbol() const { return symbol_; } + // 用于寄存器分配后替换 VReg → Reg + void AssignPhysReg(PhysReg reg); + private: Operand(Kind kind, PhysReg reg, int imm, std::string symbol = ""); + Operand(int vreg_id, VRegClass rc); Kind kind_; - PhysReg reg_; - int imm_; + PhysReg reg_ = PhysReg::W0; + int vreg_id_ = -1; + VRegClass vreg_class_ = VRegClass::GPR; + int imm_ = 0; std::string symbol_; }; @@ -107,6 +130,9 @@ class MachineInstr { Opcode GetOpcode() const { return opcode_; } const std::vector& GetOperands() const { return operands_; } + std::vector& GetOperands() { return operands_; } + + void SetOperand(size_t i, Operand op); private: Opcode opcode_; @@ -130,9 +156,18 @@ class MachineBasicBlock { MachineInstr& Append(Opcode opcode, std::initializer_list operands = {}); + // CFG 支持 + void AddSuccessor(MachineBasicBlock* succ); + void AddPredecessor(MachineBasicBlock* pred); + const std::vector& GetSuccessors() const { return successors_; } + const std::vector& GetPredecessors() const { return predecessors_; } + void ClearCFG() { successors_.clear(); predecessors_.clear(); } + private: std::string name_; std::vector instructions_; + std::vector successors_; + std::vector predecessors_; }; class MachineFunction { @@ -153,15 +188,26 @@ class MachineFunction { FrameSlot& GetFrameSlot(int index); const FrameSlot& GetFrameSlot(int index) const; const std::vector& GetFrameSlots() const { return frame_slots_; } + std::vector& GetMutableFrameSlots() { return frame_slots_; } int GetFrameSize() const { return frame_size_; } 
void SetFrameSize(int size) { frame_size_ = size; } + // 虚拟寄存器管理 + int CreateVReg(VRegClass rc = VRegClass::GPR); + int GetNumVRegs() const { return next_vreg_id_; } + + // Callee-saved 寄存器管理 + void AddUsedCalleeSaved(PhysReg reg); + const std::vector& GetUsedCalleeSaved() const { return used_callee_saved_; } + private: std::string name_; std::vector> blocks_; std::vector frame_slots_; int frame_size_ = 0; + int next_vreg_id_ = 0; + std::vector used_callee_saved_; }; class MachineModule { diff --git a/scripts/run_all_tests.sh b/scripts/run_all_tests.sh index ccda08f..25a52d4 100755 --- a/scripts/run_all_tests.sh +++ b/scripts/run_all_tests.sh @@ -75,7 +75,7 @@ run_ir_test() { # performance 用例给更长的超时时间 local run_timeout=30 if [[ "$sy" == *"performance"* ]]; then - run_timeout=300 + run_timeout=1000 fi if [[ -f "$stdin_file" ]]; then timeout $run_timeout "$exe" < "$stdin_file" > "$stdout_file" 2>/dev/null @@ -154,7 +154,7 @@ run_asm_test() { # performance 用例给更长的超时时间 local run_timeout=30 if [[ "$sy" == *"performance"* ]]; then - run_timeout=300 + run_timeout=1000 fi if [[ -f "$stdin_file" ]]; then timeout $run_timeout qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file" 2>/dev/null diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 424d85f..a3f2ed9 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -281,6 +281,12 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { << PhysRegName(ops.at(1).GetReg()) << ", " << PhysRegName(ops.at(2).GetReg()) << "\n"; break; + case Opcode::AddRR_UXTW: + // add xN, xM, wK, uxtw — 零扩展W寄存器后加到X寄存器 + os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << ", uxtw\n"; + break; case Opcode::SubRR: os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << ", " diff --git a/src/mir/FrameLowering.cpp b/src/mir/FrameLowering.cpp index 242b5a9..2a901a0 
100644 --- a/src/mir/FrameLowering.cpp +++ b/src/mir/FrameLowering.cpp @@ -1,5 +1,6 @@ #include "mir/MIR.h" +#include #include #include @@ -12,20 +13,45 @@ int AlignTo(int value, int align) { return ((value + align - 1) / align) * align; } +// 获取 W 寄存器对应的 X 寄存器 +PhysReg WRegToXReg(PhysReg w) { + int idx = static_cast(w) - static_cast(PhysReg::W0); + if (idx >= 0 && idx <= 11) { + return static_cast(static_cast(PhysReg::X0) + idx); + } + return w; +} + } // namespace void RunFrameLowering(MachineFunction& function) { + // 计算栈槽偏移 int cursor = 0; for (const auto& slot : function.GetFrameSlots()) { cursor += slot.size; + function.GetFrameSlot(slot.index).offset = -cursor; } - cursor = 0; - for (const auto& slot : function.GetFrameSlots()) { - cursor += slot.size; - function.GetFrameSlot(slot.index).offset = -cursor; + // callee-saved 寄存器:为每个分配栈槽并记住索引 + const auto& callee_saved = function.GetUsedCalleeSaved(); + std::vector> callee_save_slots; // (x_reg, slot_index) + for (size_t i = 0; i < callee_saved.size(); ++i) { + PhysReg save_reg = callee_saved[i]; + PhysReg x_reg = save_reg; + if (save_reg >= PhysReg::W0 && save_reg <= PhysReg::W11) { + x_reg = WRegToXReg(save_reg); + } + // 浮点 callee-saved 直接用 s 寄存器保存(4字节) + bool is_float = (save_reg >= PhysReg::S0 && save_reg <= PhysReg::S10); + int slot_size = is_float ? 4 : 8; + int slot = function.CreateFrameIndex(slot_size); + function.GetFrameSlot(slot).offset = -(cursor + static_cast(i + 1) * 8); + callee_save_slots.emplace_back(is_float ? 
save_reg : x_reg, slot); } - function.SetFrameSize(AlignTo(cursor, 16)); + + int callee_save_size = static_cast(callee_saved.size()) * 8; + int total_frame = AlignTo(cursor + callee_save_size, 16); + function.SetFrameSize(total_frame); // 在每个基本块的开头和结尾插入 prologue/epilogue for (const auto& bb_ptr : function.GetBlocks()) { @@ -36,10 +62,24 @@ void RunFrameLowering(MachineFunction& function) { // 只在入口块插入 prologue if (bb.GetName() == "entry") { lowered.emplace_back(Opcode::Prologue); + + // 保存 callee-saved 寄存器 + for (const auto& [reg, slot] : callee_save_slots) { + lowered.emplace_back(Opcode::StoreStack, + std::vector{Operand::Reg(reg), + Operand::FrameIndex(slot)}); + } } for (const auto& inst : insts) { if (inst.GetOpcode() == Opcode::Ret) { + // 恢复 callee-saved 寄存器 + for (const auto& [reg, slot] : callee_save_slots) { + lowered.emplace_back( + Opcode::LoadStack, + std::vector{Operand::Reg(reg), + Operand::FrameIndex(slot)}); + } lowered.emplace_back(Opcode::Epilogue); } lowered.push_back(inst); diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 573eaf2..f200702 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -82,8 +82,9 @@ void EmitAddOffset(PhysReg reg, int byte_offset, MachineBasicBlock& block) { {Operand::Reg(reg), Operand::Reg(reg), Operand::Imm(byte_offset)}); return; } + // 使用 X10 统一做 64 位加法,避免 W10/X10 别名问题 block.Append(Opcode::MovImm, - {Operand::Reg(PhysReg::W10), Operand::Imm(byte_offset)}); + {Operand::Reg(PhysReg::X10), Operand::Imm(byte_offset)}); block.Append(Opcode::AddRR, {Operand::Reg(reg), Operand::Reg(reg), Operand::Reg(PhysReg::X10)}); } @@ -194,15 +195,15 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, geps.emplace(&inst, GepInfo{-1, -1 - index_slot, gv->GetName()}); if (ptr_slot >= 0) { - // 计算地址:x9 = &global_array + (index * 4) + // 计算地址:x9 = &global_array + (w10 * 4) block.Append(Opcode::LoadStack, {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); 
EmitLslBy2(PhysReg::W10, block); block.Append(Opcode::LoadGlobalAddr, {Operand::Reg(PhysReg::X9), Operand::Symbol(gv->GetName())}); - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X10)}); + Operand::Reg(PhysReg::W10)}); block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)}); } @@ -249,10 +250,10 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, block.Append(Opcode::LoadStack, {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); EmitLslBy2(PhysReg::W10, block); - // x9 = x9 + w10 - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + // x9 = x9 + uxtw(w10) + block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X10)}); + Operand::Reg(PhysReg::W10)}); block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)}); } @@ -286,15 +287,15 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, geps.emplace(&inst, GepInfo{base_it->second, -1 - index_slot, ""}); if (ptr_slot >= 0) { - // 计算地址:x9 = x29 + base_offset + (index * 4) + // 计算地址:x9 = x29 + base_offset + (w10 * 4) block.Append(Opcode::LoadStack, {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); EmitLslBy2(PhysReg::W10, block); block.Append(Opcode::LoadStackAddr, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)}); - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X10)}); + Operand::Reg(PhysReg::W10)}); block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)}); } @@ -330,7 +331,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, } else { // 变量索引:global_array[var_idx] int index_slot = -1 
- gep_info.byte_offset; - // 1. 加载 index + // 1. 加载 index(4字节 W 寄存器) block.Append(Opcode::LoadStack, {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); // 2. index * 4 @@ -338,10 +339,10 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, // 3. 获取全局数组基址 block.Append(Opcode::LoadGlobalAddr, {Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)}); - // 4. x9 + offset - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + // 4. x9 + w10, uxtw(零扩展 W 寄存器后加到 X 寄存器) + block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X10)}); + Operand::Reg(PhysReg::W10)}); // 5. 存储 block.Append(Opcode::StoreIndirect, {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); @@ -361,9 +362,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, block.Append(Opcode::LoadStackAddr, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(gep_info.base_slot)}); - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X10)}); + Operand::Reg(PhysReg::W10)}); block.Append(Opcode::StoreIndirect, {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); } @@ -431,9 +432,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, EmitLslBy2(PhysReg::W10, block); block.Append(Opcode::LoadGlobalAddr, {Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)}); - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X10)}); + Operand::Reg(PhysReg::W10)}); block.Append(Opcode::LoadIndirect, {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); } @@ -452,9 +453,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, block.Append(Opcode::LoadStackAddr, {Operand::Reg(PhysReg::X9), 
Operand::FrameIndex(gep_info.base_slot)}); - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X10)}); + Operand::Reg(PhysReg::W10)}); block.Append(Opcode::LoadIndirect, {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); } @@ -799,7 +800,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X9)}); - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(2)}); + block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X9), Operand::Imm(2)}); block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X9)}); diff --git a/src/mir/MIRBasicBlock.cpp b/src/mir/MIRBasicBlock.cpp index d42b4b3..bda52c2 100644 --- a/src/mir/MIRBasicBlock.cpp +++ b/src/mir/MIRBasicBlock.cpp @@ -1,5 +1,6 @@ #include "mir/MIR.h" +#include #include namespace mir { @@ -13,4 +14,16 @@ MachineInstr& MachineBasicBlock::Append(Opcode opcode, return instructions_.back(); } +void MachineBasicBlock::AddSuccessor(MachineBasicBlock* succ) { + if (std::find(successors_.begin(), successors_.end(), succ) == successors_.end()) { + successors_.push_back(succ); + } +} + +void MachineBasicBlock::AddPredecessor(MachineBasicBlock* pred) { + if (std::find(predecessors_.begin(), predecessors_.end(), pred) == predecessors_.end()) { + predecessors_.push_back(pred); + } +} + } // namespace mir diff --git a/src/mir/MIRFunction.cpp b/src/mir/MIRFunction.cpp index c4f6f34..4ea0b16 100644 --- a/src/mir/MIRFunction.cpp +++ b/src/mir/MIRFunction.cpp @@ -1,5 +1,6 @@ #include "mir/MIR.h" +#include #include #include @@ -47,6 +48,18 @@ const FrameSlot& MachineFunction::GetFrameSlot(int index) const { return frame_slots_[index]; } +int MachineFunction::CreateVReg(VRegClass rc) { + (void)rc; // 类别信息在 Operand 中记录 + return 
next_vreg_id_++; +} + +void MachineFunction::AddUsedCalleeSaved(PhysReg reg) { + if (std::find(used_callee_saved_.begin(), used_callee_saved_.end(), reg) == + used_callee_saved_.end()) { + used_callee_saved_.push_back(reg); + } +} + MachineFunction* MachineModule::CreateFunction(std::string name) { functions_.push_back(std::make_unique(std::move(name))); return functions_.back().get(); diff --git a/src/mir/MIRInstr.cpp b/src/mir/MIRInstr.cpp index 4047b4a..22176a3 100644 --- a/src/mir/MIRInstr.cpp +++ b/src/mir/MIRInstr.cpp @@ -1,5 +1,6 @@ #include "mir/MIR.h" +#include #include namespace mir { @@ -7,8 +8,15 @@ namespace mir { Operand::Operand(Kind kind, PhysReg reg, int imm, std::string symbol) : kind_(kind), reg_(reg), imm_(imm), symbol_(std::move(symbol)) {} +Operand::Operand(int vreg_id, VRegClass rc) + : kind_(Kind::VReg), vreg_id_(vreg_id), vreg_class_(rc) {} + Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); } +Operand Operand::VReg(int vreg_id, VRegClass rc) { + return Operand(vreg_id, rc); +} + Operand Operand::Imm(int value) { return Operand(Kind::Imm, PhysReg::W0, value); } @@ -21,7 +29,19 @@ Operand Operand::Symbol(std::string name) { return Operand(Kind::Symbol, PhysReg::W0, 0, std::move(name)); } +void Operand::AssignPhysReg(PhysReg reg) { + kind_ = Kind::Reg; + reg_ = reg; + vreg_id_ = -1; +} + MachineInstr::MachineInstr(Opcode opcode, std::vector operands) : opcode_(opcode), operands_(std::move(operands)) {} +void MachineInstr::SetOperand(size_t i, Operand op) { + if (i < operands_.size()) { + operands_[i] = std::move(op); + } +} + } // namespace mir diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index 4335ea9..c432d74 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -1,64 +1,954 @@ #include "mir/MIR.h" +#include +#include +#include #include +#include +#include +#include #include "utils/Log.h" namespace mir { namespace { -bool IsAllowedReg(PhysReg reg) { - switch (reg) { - case PhysReg::W0: - case 
PhysReg::W1: - case PhysReg::W2: - case PhysReg::W3: - case PhysReg::W4: - case PhysReg::W5: - case PhysReg::W6: - case PhysReg::W7: - case PhysReg::W8: - case PhysReg::W9: - case PhysReg::W10: - case PhysReg::X0: - case PhysReg::X1: - case PhysReg::X2: - case PhysReg::X3: - case PhysReg::X4: - case PhysReg::X5: - case PhysReg::X6: - case PhysReg::X7: - case PhysReg::X8: - case PhysReg::X9: - case PhysReg::X10: - case PhysReg::X29: - case PhysReg::X30: - case PhysReg::SP: - case PhysReg::S0: - case PhysReg::S1: - case PhysReg::S2: - case PhysReg::S3: - case PhysReg::S4: - case PhysReg::S5: - case PhysReg::S6: - case PhysReg::S7: - case PhysReg::S8: - case PhysReg::S9: - case PhysReg::S10: - return true; +// ============================================================================ +// 物理寄存器定义 +// ============================================================================ + +// 可分配的 GPR (32-bit):w0-w11(共12个) +static const std::vector kAllocGPR = { + PhysReg::W0, PhysReg::W1, PhysReg::W2, PhysReg::W3, + PhysReg::W4, PhysReg::W5, PhysReg::W6, PhysReg::W7, + PhysReg::W8, PhysReg::W9, PhysReg::W10, PhysReg::W11, +}; + +// 可分配的 GPR64 (64-bit):x0-x11(共12个) +static const std::vector kAllocGPR64 = { + PhysReg::X0, PhysReg::X1, PhysReg::X2, PhysReg::X3, + PhysReg::X4, PhysReg::X5, PhysReg::X6, PhysReg::X7, + PhysReg::X8, PhysReg::X9, PhysReg::X10, PhysReg::X11, +}; + +// 可分配的 FPR:s0-s10(共11个) +static const std::vector kAllocFPR = { + PhysReg::S0, PhysReg::S1, PhysReg::S2, PhysReg::S3, + PhysReg::S4, PhysReg::S5, PhysReg::S6, PhysReg::S7, + PhysReg::S8, PhysReg::S9, PhysReg::S10, +}; + +// Callee-saved 寄存器 +static const std::set kCalleeSavedGPR = { + PhysReg::W8, PhysReg::W9, PhysReg::W10, PhysReg::W11, +}; +static const std::set kCalleeSavedGPR64 = { + PhysReg::X8, PhysReg::X9, PhysReg::X10, PhysReg::X11, +}; +static const std::set kCalleeSavedFPR = { + PhysReg::S8, PhysReg::S9, PhysReg::S10, +}; + +// Caller-saved 寄存器(被函数调用破坏) +static const std::set kCallerSavedGPR = { + 
PhysReg::W0, PhysReg::W1, PhysReg::W2, PhysReg::W3, + PhysReg::W4, PhysReg::W5, PhysReg::W6, PhysReg::W7, +}; +static const std::set kCallerSavedGPR64 = { + PhysReg::X0, PhysReg::X1, PhysReg::X2, PhysReg::X3, + PhysReg::X4, PhysReg::X5, PhysReg::X6, PhysReg::X7, +}; +static const std::set kCallerSavedFPR = { + PhysReg::S0, PhysReg::S1, PhysReg::S2, PhysReg::S3, + PhysReg::S4, PhysReg::S5, PhysReg::S6, PhysReg::S7, +}; + +bool IsFloatReg(PhysReg r) { + return r >= PhysReg::S0 && r <= PhysReg::S10; +} + +// W/X 寄存器别名:w0-w11 和 x0-x11 是同一物理寄存器的 32/64 位视图 +PhysReg WToX(PhysReg w) { + int idx = static_cast(w) - static_cast(PhysReg::W0); + if (idx >= 0 && idx <= 11) + return static_cast(static_cast(PhysReg::X0) + idx); + return w; +} + +PhysReg XToW(PhysReg x) { + int idx = static_cast(x) - static_cast(PhysReg::X0); + if (idx >= 0 && idx <= 11) + return static_cast(static_cast(PhysReg::W0) + idx); + return x; +} + +// 将物理寄存器及其别名都加入集合 +void InsertWithAlias(std::set& s, PhysReg r) { + s.insert(r); + if (r >= PhysReg::W0 && r <= PhysReg::W11) { + s.insert(WToX(r)); + } else if (r >= PhysReg::X0 && r <= PhysReg::X11) { + s.insert(XToW(r)); + } +} + +bool IsGPR64(PhysReg r) { + return (r >= PhysReg::X0 && r <= PhysReg::X11) || + r == PhysReg::X29 || r == PhysReg::X30 || r == PhysReg::SP; +} + +// 确定物理寄存器对应的 VRegClass +VRegClass ClassForPhysReg(PhysReg r) { + if (IsFloatReg(r)) return VRegClass::FPR; + if (IsGPR64(r)) return VRegClass::GPR64; + return VRegClass::GPR; +} + +// 获取虚拟寄存器类别的可分配物理寄存器列表 +const std::vector& GetAllocRegs(VRegClass rc) { + switch (rc) { + case VRegClass::GPR: return kAllocGPR; + case VRegClass::GPR64: return kAllocGPR64; + case VRegClass::FPR: return kAllocFPR; + } + return kAllocGPR; +} + +// 判断是否应该提升的寄存器 +// 只提升临时计算寄存器 (w8-w11, x8-x11, s8-s10) +// 不提升 ABI 寄存器 (w0-w7, x0-x7, s0-s7) 和特殊寄存器 (x29, x30, sp) +bool ShouldPromote(PhysReg r) { + if (r == PhysReg::X29 || r == PhysReg::X30 || r == PhysReg::SP) return false; + // 不提升 ABI caller-saved 寄存器(w0-w7, 
x0-x7, s0-s7) + if (r >= PhysReg::W0 && r <= PhysReg::W7) return false; + if (r >= PhysReg::X0 && r <= PhysReg::X7) return false; + if (r >= PhysReg::S0 && r <= PhysReg::S7) return false; + return true; +} + +// ============================================================================ +// 1. 构建 CFG +// ============================================================================ + +void BuildCFG(MachineFunction& func) { + for (auto& bb_ptr : func.GetBlocks()) { + bb_ptr->ClearCFG(); + } + + for (auto& bb_ptr : func.GetBlocks()) { + auto& bb = *bb_ptr; + for (const auto& inst : bb.GetInstructions()) { + auto op = inst.GetOpcode(); + if (op == Opcode::B) { + auto* target = func.FindBlock(inst.GetOperands()[0].GetSymbol()); + if (target) { + bb.AddSuccessor(target); + target->AddPredecessor(&bb); + } + } else if (op == Opcode::Bcond || op == Opcode::FBcond) { + auto* target = func.FindBlock(inst.GetOperands()[0].GetSymbol()); + if (target) { + bb.AddSuccessor(target); + target->AddPredecessor(&bb); + } + } else if (op == Opcode::Cbnz || op == Opcode::Cbz) { + auto* target = func.FindBlock(inst.GetOperands()[1].GetSymbol()); + if (target) { + bb.AddSuccessor(target); + target->AddPredecessor(&bb); + } + } + } + // fall-through:如果最后一条指令不是无条件跳转/ret,则下一个块是后继 + if (!bb.GetInstructions().empty()) { + auto last_op = bb.GetInstructions().back().GetOpcode(); + if (last_op != Opcode::B && last_op != Opcode::Ret) { + const auto& blocks = func.GetBlocks(); + for (size_t i = 0; i + 1 < blocks.size(); ++i) { + if (blocks[i].get() == &bb) { + auto* next = blocks[i + 1].get(); + bb.AddSuccessor(next); + next->AddPredecessor(&bb); + break; + } + } + } + } + } +} + +// ============================================================================ +// 2. 
VRegInfo 结构 +// ============================================================================ + +struct VRegInfo { + int id = 0; + VRegClass rc = VRegClass::GPR; + int spill_slot = -1; + bool is_spilled = false; +}; + +// ============================================================================ +// 3. 将 MIR 中的物理寄存器提升为虚拟寄存器 +// ============================================================================ +// +// 策略:只提升临时计算寄存器 (w8-w11, x8-x11, s8-s10)。 +// ABI 寄存器 (w0-w7, x0-x7, s0-s7) 保持物理寄存器不变,因为它们用于: +// - 函数参数传递、返回值、调用约定 +// 跨块的值通过栈槽传递,所以无需跨块跟踪 vreg。 +// 图着色分配器仍然可以将 vreg 分配到 w0-w7 等寄存器。 + +void PromoteToVRegs(MachineFunction& func, + std::vector& vreg_infos) { + for (auto& bb_ptr : func.GetBlocks()) { + auto& insts = bb_ptr->GetInstructions(); + // 块内物理寄存器 → vreg 的映射 + std::unordered_map phys_to_vreg; + + auto EnsureSize = [&](int vreg_id) { + if (static_cast(vreg_infos.size()) <= vreg_id) + vreg_infos.resize(vreg_id + 1); + }; + + auto GetOrCreateVReg = [&](PhysReg reg) -> int { + int key = static_cast(reg); + auto it = phys_to_vreg.find(key); + if (it != phys_to_vreg.end()) return it->second; + VRegClass rc = ClassForPhysReg(reg); + int vreg_id = func.CreateVReg(rc); + EnsureSize(vreg_id); + vreg_infos[vreg_id] = {vreg_id, rc, -1, false}; + phys_to_vreg[key] = vreg_id; + return vreg_id; + }; + + auto CreateNewVReg = [&](PhysReg reg) -> int { + VRegClass rc = ClassForPhysReg(reg); + int vreg_id = func.CreateVReg(rc); + EnsureSize(vreg_id); + vreg_infos[vreg_id] = {vreg_id, rc, -1, false}; + phys_to_vreg[static_cast(reg)] = vreg_id; + return vreg_id; + }; + + // 辅助:提升 use 位置的操作数 + auto PromoteUse = [&](Operand& op) { + if (op.IsReg() && ShouldPromote(op.GetReg())) { + PhysReg orig = op.GetReg(); + int vreg = GetOrCreateVReg(orig); + op = Operand::VReg(vreg, ClassForPhysReg(orig)); + } + }; + + // 辅助:提升 def 位置的操作数(为 def 创建新 vreg) + auto PromoteDef = [&](Operand& op) { + if (op.IsReg() && ShouldPromote(op.GetReg())) { + PhysReg orig = op.GetReg(); + int vreg = 
CreateNewVReg(orig); + op = Operand::VReg(vreg, ClassForPhysReg(orig)); + } + }; + + for (size_t i = 0; i < insts.size(); ++i) { + auto& inst = insts[i]; + auto opcode = inst.GetOpcode(); + + // 跳过控制流/伪指令 + if (opcode == Opcode::Prologue || opcode == Opcode::Epilogue || + opcode == Opcode::B || opcode == Opcode::Bcond || + opcode == Opcode::FBcond || opcode == Opcode::Ret) + continue; + + // Bl:清除所有 caller-saved 映射(它们被调用破坏了) + if (opcode == Opcode::Bl) { + // callee-saved (w8-w11等) 的映射也要清除,因为调用可能破坏它们 + // (虽然按约定不应该,但这里保守处理) + phys_to_vreg.clear(); + continue; + } + + auto& ops = inst.GetOperands(); + + // Cbnz/Cbz: ops[0] 是 use + if (opcode == Opcode::Cbnz || opcode == Opcode::Cbz) { + if (!ops.empty()) PromoteUse(ops[0]); + continue; + } + + // CmpOnlyRR / FCmpOnlyRR: ops[0], ops[1] 都是 use(不写寄存器,只设条件码) + if (opcode == Opcode::CmpOnlyRR || opcode == Opcode::FCmpOnlyRR) { + for (size_t j = 0; j < std::min(ops.size(), size_t(2)); ++j) + PromoteUse(ops[j]); + continue; + } + + // ---- 先处理 use 操作数 ---- + switch (opcode) { + case Opcode::MovReg: + case Opcode::FMovReg: + if (ops.size() > 1) PromoteUse(ops[1]); + break; + + case Opcode::StoreStack: + case Opcode::StoreStackOffset: + case Opcode::StoreGlobal: + if (!ops.empty()) PromoteUse(ops[0]); + break; + + case Opcode::StoreIndirect: + if (!ops.empty()) PromoteUse(ops[0]); + if (ops.size() > 1) PromoteUse(ops[1]); + break; + + case Opcode::LoadIndirect: + if (ops.size() > 1) PromoteUse(ops[1]); + break; + + case Opcode::AddRI: case Opcode::SubRI: + case Opcode::LsrRI: case Opcode::LslRI: + case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI: + if (ops.size() > 1) PromoteUse(ops[1]); + break; + + case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR: + case Opcode::MulRR: case Opcode::DivRR: + case Opcode::ModRR: case Opcode::LslRR: + case Opcode::FAddRR: case Opcode::FSubRR: + case Opcode::FMulRR: case Opcode::FDivRR: + if (ops.size() > 1) PromoteUse(ops[1]); + if (ops.size() > 2) 
PromoteUse(ops[2]); + break; + + case Opcode::CmpRR: case Opcode::FCmpRR: + if (ops.size() > 1) PromoteUse(ops[1]); + if (ops.size() > 2) PromoteUse(ops[2]); + break; + + default: + break; + } + + // ---- 然后处理 def 操作数 ---- + switch (opcode) { + case Opcode::MovImm: case Opcode::FMovImm: + case Opcode::MovReg: case Opcode::FMovReg: + case Opcode::LoadStack: case Opcode::LoadStackOffset: + case Opcode::LoadStackAddr: case Opcode::LoadIndirect: + case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr: + case Opcode::AddRI: case Opcode::SubRI: + case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR: + case Opcode::MulRR: case Opcode::DivRR: + case Opcode::ModRR: case Opcode::LsrRI: + case Opcode::LslRI: case Opcode::LslRR: + case Opcode::FAddRR: case Opcode::FSubRR: + case Opcode::FMulRR: case Opcode::FDivRR: + case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI: + case Opcode::CmpRR: case Opcode::FCmpRR: + if (!ops.empty()) PromoteDef(ops[0]); + break; + default: + break; + } + } + } +} + +// ============================================================================ +// 4. 
VReg 活跃性分析 +// ============================================================================ + +struct VRegLiveInfo { + std::set live_in; + std::set live_out; +}; + +using VRegLiveMap = std::unordered_map; + +void GetVRegDefsUses(const MachineInstr& inst, + std::set& defs, std::set& uses) { + const auto& ops = inst.GetOperands(); + auto opcode = inst.GetOpcode(); + + // Use 操作数 + switch (opcode) { + case Opcode::MovReg: case Opcode::FMovReg: + if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId()); + break; + case Opcode::StoreStack: case Opcode::StoreStackOffset: + case Opcode::StoreGlobal: + if (!ops.empty() && ops[0].IsVReg()) uses.insert(ops[0].GetVRegId()); + break; + case Opcode::StoreIndirect: + if (!ops.empty() && ops[0].IsVReg()) uses.insert(ops[0].GetVRegId()); + if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId()); + break; + case Opcode::LoadIndirect: + if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId()); + break; + case Opcode::AddRI: case Opcode::SubRI: + case Opcode::LsrRI: case Opcode::LslRI: + case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI: + if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId()); + break; + case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR: + case Opcode::MulRR: case Opcode::DivRR: + case Opcode::ModRR: case Opcode::LslRR: + case Opcode::FAddRR: case Opcode::FSubRR: + case Opcode::FMulRR: case Opcode::FDivRR: + if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId()); + if (ops.size() > 2 && ops[2].IsVReg()) uses.insert(ops[2].GetVRegId()); + break; + case Opcode::CmpRR: case Opcode::FCmpRR: + if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId()); + if (ops.size() > 2 && ops[2].IsVReg()) uses.insert(ops[2].GetVRegId()); + break; + case Opcode::CmpOnlyRR: case Opcode::FCmpOnlyRR: + if (!ops.empty() && ops[0].IsVReg()) uses.insert(ops[0].GetVRegId()); + if (ops.size() > 1 && ops[1].IsVReg()) 
uses.insert(ops[1].GetVRegId()); + break; + case Opcode::Cbnz: case Opcode::Cbz: + if (!ops.empty() && ops[0].IsVReg()) uses.insert(ops[0].GetVRegId()); + break; + default: + break; + } + + // Def 操作数 + switch (opcode) { + case Opcode::MovImm: case Opcode::FMovImm: + case Opcode::MovReg: case Opcode::FMovReg: + case Opcode::LoadStack: case Opcode::LoadStackOffset: + case Opcode::LoadStackAddr: case Opcode::LoadIndirect: + case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr: + case Opcode::AddRI: case Opcode::SubRI: + case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR: + case Opcode::MulRR: case Opcode::DivRR: + case Opcode::ModRR: case Opcode::LsrRI: + case Opcode::LslRI: case Opcode::LslRR: + case Opcode::FAddRR: case Opcode::FSubRR: + case Opcode::FMulRR: case Opcode::FDivRR: + case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI: + case Opcode::CmpRR: case Opcode::FCmpRR: + if (!ops.empty() && ops[0].IsVReg()) defs.insert(ops[0].GetVRegId()); + break; + default: + break; + } +} + +VRegLiveMap ComputeVRegLiveness(MachineFunction& func) { + VRegLiveMap live_map; + for (auto& bb_ptr : func.GetBlocks()) { + live_map[bb_ptr.get()] = VRegLiveInfo{}; + } + + bool changed = true; + while (changed) { + changed = false; + const auto& blocks = func.GetBlocks(); + for (int bi = static_cast(blocks.size()) - 1; bi >= 0; --bi) { + auto* bb = blocks[bi].get(); + auto& info = live_map[bb]; + + // live_out = union of successors' live_in + std::set new_live_out; + for (auto* succ : bb->GetSuccessors()) { + const auto& succ_in = live_map[succ].live_in; + new_live_out.insert(succ_in.begin(), succ_in.end()); + } + + // 从块末尾向前扫描计算 live_in + std::set live = new_live_out; + const auto& insts = bb->GetInstructions(); + for (int i = static_cast(insts.size()) - 1; i >= 0; --i) { + std::set defs, uses_set; + GetVRegDefsUses(insts[i], defs, uses_set); + for (auto v : defs) live.erase(v); + for (auto v : uses_set) live.insert(v); + } + + if (live != info.live_in || 
new_live_out != info.live_out) { + info.live_in = std::move(live); + info.live_out = std::move(new_live_out); + changed = true; + } + } + } + + return live_map; +} + +// ============================================================================ +// 5. 干涉图构建 +// ============================================================================ + +// 获取指令中定义和使用的物理寄存器(非 vreg 的) +void GetPhysRegDefsUses(const MachineInstr& inst, + std::set& phys_defs, + std::set& phys_uses) { + const auto& ops = inst.GetOperands(); + auto opcode = inst.GetOpcode(); + + // 对于 Bl 指令,所有 caller-saved 寄存器被隐式定义(破坏) + if (opcode == Opcode::Bl) { + for (auto r : kCallerSavedGPR) phys_defs.insert(r); + for (auto r : kCallerSavedGPR64) phys_defs.insert(r); + for (auto r : kCallerSavedFPR) phys_defs.insert(r); + return; + } + + // Use 位置的物理寄存器 + switch (opcode) { + case Opcode::MovReg: case Opcode::FMovReg: + if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg())) + phys_uses.insert(ops[1].GetReg()); + break; + case Opcode::StoreStack: case Opcode::StoreStackOffset: + case Opcode::StoreGlobal: + if (!ops.empty() && ops[0].IsReg() && !ShouldPromote(ops[0].GetReg())) + phys_uses.insert(ops[0].GetReg()); + break; + case Opcode::StoreIndirect: + if (!ops.empty() && ops[0].IsReg() && !ShouldPromote(ops[0].GetReg())) + phys_uses.insert(ops[0].GetReg()); + if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg())) + phys_uses.insert(ops[1].GetReg()); + break; + case Opcode::AddRI: case Opcode::SubRI: + case Opcode::LsrRI: case Opcode::LslRI: + case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI: + if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg())) + phys_uses.insert(ops[1].GetReg()); + break; + case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR: + case Opcode::MulRR: case Opcode::DivRR: + case Opcode::ModRR: case Opcode::LslRR: + case Opcode::FAddRR: case Opcode::FSubRR: + case Opcode::FMulRR: case Opcode::FDivRR: + if (ops.size() 
> 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg())) + phys_uses.insert(ops[1].GetReg()); + if (ops.size() > 2 && ops[2].IsReg() && !ShouldPromote(ops[2].GetReg())) + phys_uses.insert(ops[2].GetReg()); + break; + case Opcode::CmpRR: case Opcode::FCmpRR: + case Opcode::CmpOnlyRR: case Opcode::FCmpOnlyRR: + if (!ops.empty() && ops[0].IsReg() && !ShouldPromote(ops[0].GetReg())) + phys_uses.insert(ops[0].GetReg()); + if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg())) + phys_uses.insert(ops[1].GetReg()); + break; + case Opcode::Cbnz: case Opcode::Cbz: + if (!ops.empty() && ops[0].IsReg() && !ShouldPromote(ops[0].GetReg())) + phys_uses.insert(ops[0].GetReg()); + break; + case Opcode::LoadIndirect: + if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg())) + phys_uses.insert(ops[1].GetReg()); + break; + default: + break; + } + + // Def 位置的物理寄存器 + switch (opcode) { + case Opcode::MovImm: case Opcode::FMovImm: + case Opcode::MovReg: case Opcode::FMovReg: + case Opcode::LoadStack: case Opcode::LoadStackOffset: + case Opcode::LoadStackAddr: case Opcode::LoadIndirect: + case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr: + case Opcode::AddRI: case Opcode::SubRI: + case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR: + case Opcode::MulRR: case Opcode::DivRR: + case Opcode::ModRR: case Opcode::LsrRI: + case Opcode::LslRI: case Opcode::LslRR: + case Opcode::FAddRR: case Opcode::FSubRR: + case Opcode::FMulRR: case Opcode::FDivRR: + case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI: + case Opcode::CmpRR: case Opcode::FCmpRR: + if (!ops.empty() && ops[0].IsReg() && !ShouldPromote(ops[0].GetReg())) + phys_defs.insert(ops[0].GetReg()); + break; + default: + break; + } +} + +struct InterferenceGraph { + int num_vregs = 0; + std::vector> adj; + std::vector degree; + // 每个 vreg 不能使用的物理寄存器集合(与物理寄存器的干涉) + std::vector> phys_interference; + + void Init(int n) { + num_vregs = n; + adj.resize(n); + degree.resize(n, 0); + 
phys_interference.resize(n); + } + + void AddEdge(int u, int v) { + if (u == v) return; + if (u < 0 || u >= num_vregs || v < 0 || v >= num_vregs) return; + if (adj[u].count(v)) return; + adj[u].insert(v); + adj[v].insert(u); + degree[u]++; + degree[v]++; + } + + void AddPhysInterference(int vreg, PhysReg phys) { + if (vreg < 0 || vreg >= num_vregs) return; + // 同时插入物理寄存器及其 W/X 别名 + InsertWithAlias(phys_interference[vreg], phys); + } +}; + +InterferenceGraph BuildInterferenceGraph( + MachineFunction& func, + const VRegLiveMap& live_map, + int num_vregs) { + InterferenceGraph ig; + ig.Init(num_vregs); + + for (auto& bb_ptr : func.GetBlocks()) { + auto* bb = bb_ptr.get(); + const auto& insts = bb->GetInstructions(); + auto it = live_map.find(bb); + if (it == live_map.end()) continue; + + std::set live = it->second.live_out; + // 追踪当前活跃的物理寄存器 + std::set phys_live; + + // 从块末尾向前扫描 + for (int i = static_cast(insts.size()) - 1; i >= 0; --i) { + std::set defs, uses_set; + GetVRegDefsUses(insts[i], defs, uses_set); + + std::set phys_defs, phys_uses; + GetPhysRegDefsUses(insts[i], phys_defs, phys_uses); + + // 对每个 vreg def,与所有当前 live 的 vreg 添加干涉边 + for (int d : defs) { + for (int l : live) { + ig.AddEdge(d, l); + } + // vreg def 与当前活跃的物理寄存器干涉 + for (auto pr : phys_live) { + ig.AddPhysInterference(d, pr); + } + } + + // 物理寄存器 def 与当前活跃的 vreg 干涉 + for (auto pr : phys_defs) { + for (int l : live) { + ig.AddPhysInterference(l, pr); + } + } + + // 更新 vreg 活跃集合 + for (int d : defs) live.erase(d); + for (int u : uses_set) live.insert(u); + + // 更新物理寄存器活跃集合 + for (auto pr : phys_defs) phys_live.erase(pr); + for (auto pr : phys_uses) phys_live.insert(pr); + } + } + + return ig; +} + +// ============================================================================ +// 6. 
图着色(Simplify → Select → Spill) +// ============================================================================ + +struct ColoringResult { + std::unordered_map assignment; + std::vector spilled; +}; + +bool IsCalleeSaved(PhysReg r, VRegClass rc) { + switch (rc) { + case VRegClass::GPR: return kCalleeSavedGPR.count(r) > 0; + case VRegClass::GPR64: return kCalleeSavedGPR64.count(r) > 0; + case VRegClass::FPR: return kCalleeSavedFPR.count(r) > 0; } return false; } +ColoringResult ColorGraph( + InterferenceGraph& ig, + const std::vector& vreg_infos, + MachineFunction& func) { + int n = ig.num_vregs; + ColoringResult result; + if (n == 0) return result; + + // Simplify: 迭代移除度 < K 的节点入栈 + std::vector removed(n, false); + std::stack select_stack; + std::vector cur_degree(ig.degree); + + auto K = [&](int v) -> int { + return static_cast(GetAllocRegs(vreg_infos[v].rc).size()); + }; + + int remaining = n; + while (remaining > 0) { + bool found = false; + for (int v = 0; v < n; ++v) { + if (removed[v]) continue; + if (cur_degree[v] < K(v)) { + select_stack.push(v); + removed[v] = true; + remaining--; + for (int nb : ig.adj[v]) { + if (!removed[nb]) cur_degree[nb]--; + } + found = true; + break; + } + } + + if (!found) { + // Potential spill:选度数最大的节点 + int best = -1; + int best_degree = -1; + for (int v = 0; v < n; ++v) { + if (removed[v]) continue; + if (cur_degree[v] > best_degree) { + best_degree = cur_degree[v]; + best = v; + } + } + if (best >= 0) { + select_stack.push(best); + removed[best] = true; + remaining--; + for (int nb : ig.adj[best]) { + if (!removed[nb]) cur_degree[nb]--; + } + } + } + } + + // Select: 从栈中弹出,尝试着色 + std::vector color(n, -1); + + while (!select_stack.empty()) { + int v = select_stack.top(); + select_stack.pop(); + + const auto& alloc_regs = GetAllocRegs(vreg_infos[v].rc); + + // 收集邻居已使用的颜色 + 物理寄存器干涉 + std::set used_colors; + for (int nb : ig.adj[v]) { + if (color[nb] >= 0) { + PhysReg nb_color = static_cast(color[nb]); + // 插入颜色及其 W/X 别名,防止跨类别冲突 
+ InsertWithAlias(used_colors, nb_color); + } + } + // 加入与此 vreg 干涉的物理寄存器(已含别名) + for (auto pr : ig.phys_interference[v]) { + used_colors.insert(pr); + } + + // 优先使用 callee-saved(减少保存开销;对仅在块内使用的 vreg 更优) + PhysReg chosen = PhysReg::W0; + bool colored = false; + + for (auto r : alloc_regs) { + if (used_colors.count(r)) continue; + if (IsCalleeSaved(r, vreg_infos[v].rc)) { + chosen = r; + colored = true; + break; + } + } + if (!colored) { + for (auto r : alloc_regs) { + if (!used_colors.count(r)) { + chosen = r; + colored = true; + break; + } + } + } + + if (colored) { + color[v] = static_cast(chosen); + result.assignment[v] = chosen; + if (IsCalleeSaved(chosen, vreg_infos[v].rc)) { + func.AddUsedCalleeSaved(chosen); + } + } else { + result.spilled.push_back(v); + } + } + + return result; +} + +// ============================================================================ +// 7. Spill 处理 +// ============================================================================ + +void RewriteSpills(MachineFunction& func, + const std::vector& spilled, + std::vector& vreg_infos) { + for (int v : spilled) { + int slot_size = (vreg_infos[v].rc == VRegClass::GPR64) ? 
8 : 4; + int slot = func.CreateFrameIndex(slot_size); + vreg_infos[v].spill_slot = slot; + vreg_infos[v].is_spilled = true; + } + + std::set spilled_set(spilled.begin(), spilled.end()); + + for (auto& bb_ptr : func.GetBlocks()) { + auto& insts = bb_ptr->GetInstructions(); + std::vector new_insts; + new_insts.reserve(insts.size() * 2); + + for (auto& inst : insts) { + auto& ops = inst.GetOperands(); + + std::set inst_defs, inst_uses; + GetVRegDefsUses(inst, inst_defs, inst_uses); + + // 收集此指令中 spilled vreg 的 use 和 def 位置 + std::vector> use_positions; + std::vector> def_positions; + + for (size_t j = 0; j < ops.size(); ++j) { + if (!ops[j].IsVReg()) continue; + int vid = ops[j].GetVRegId(); + if (!spilled_set.count(vid)) continue; + if (inst_uses.count(vid)) use_positions.emplace_back(j, vid); + if (inst_defs.count(vid)) def_positions.emplace_back(j, vid); + } + + auto EnsureSize = [&](int id) { + if (static_cast(vreg_infos.size()) <= id) + vreg_infos.resize(id + 1); + }; + + // 为 spilled use 插入 LoadStack + for (auto& [op_idx, vid] : use_positions) { + VRegClass rc = vreg_infos[vid].rc; + int new_vreg = func.CreateVReg(rc); + EnsureSize(new_vreg); + vreg_infos[new_vreg] = {new_vreg, rc, -1, false}; + new_insts.emplace_back(Opcode::LoadStack, + std::vector{ + Operand::VReg(new_vreg, rc), + Operand::FrameIndex(vreg_infos[vid].spill_slot)}); + ops[op_idx] = Operand::VReg(new_vreg, rc); + } + + new_insts.push_back(std::move(inst)); + + // 为 spilled def 插入 StoreStack + for (auto& [op_idx, vid] : def_positions) { + VRegClass rc = vreg_infos[vid].rc; + int new_vreg = func.CreateVReg(rc); + EnsureSize(new_vreg); + vreg_infos[new_vreg] = {new_vreg, rc, -1, false}; + // 替换 emitted 指令中的 def 操作数 + new_insts.back().GetOperands()[op_idx] = Operand::VReg(new_vreg, rc); + new_insts.emplace_back(Opcode::StoreStack, + std::vector{ + Operand::VReg(new_vreg, rc), + Operand::FrameIndex(vreg_infos[vid].spill_slot)}); + } + } + + insts = std::move(new_insts); + } +} + +// 
============================================================================ +// 8. 应用着色结果 +// ============================================================================ + +void ApplyColoring(MachineFunction& func, + const std::unordered_map& assignment) { + for (auto& bb_ptr : func.GetBlocks()) { + for (auto& inst : bb_ptr->GetInstructions()) { + for (auto& op : inst.GetOperands()) { + if (op.IsVReg()) { + auto it = assignment.find(op.GetVRegId()); + if (it != assignment.end()) { + op.AssignPhysReg(it->second); + } + } + } + } + } +} + } // namespace -void RunRegAlloc(MachineFunction& function) { - for (const auto& bb_ptr : function.GetBlocks()) { +// ============================================================================ +// 9. 主入口 +// ============================================================================ + +void RunRegAlloc(MachineFunction& func) { + // 步骤 1: 构建 CFG + BuildCFG(func); + + // 步骤 2: 将物理寄存器提升为虚拟寄存器 + std::vector vreg_infos; + PromoteToVRegs(func, vreg_infos); + + int num_vregs = func.GetNumVRegs(); + if (num_vregs == 0) return; + + vreg_infos.resize(num_vregs); + + // 图着色 + spill 迭代 + const int kMaxIterations = 10; + for (int iter = 0; iter < kMaxIterations; ++iter) { + // 重建 CFG(spill 后指令变了) + if (iter > 0) BuildCFG(func); + + VRegLiveMap live_map = ComputeVRegLiveness(func); + + num_vregs = func.GetNumVRegs(); + vreg_infos.resize(num_vregs); + InterferenceGraph ig = BuildInterferenceGraph(func, live_map, num_vregs); + + ColoringResult result = ColorGraph(ig, vreg_infos, func); + + if (result.spilled.empty()) { + ApplyColoring(func, result.assignment); + return; + } + + RewriteSpills(func, result.spilled, vreg_infos); + } + + // 最后一次尝试 + BuildCFG(func); + VRegLiveMap live_map = ComputeVRegLiveness(func); + num_vregs = func.GetNumVRegs(); + vreg_infos.resize(num_vregs); + InterferenceGraph ig = BuildInterferenceGraph(func, live_map, num_vregs); + ColoringResult result = ColorGraph(ig, vreg_infos, func); + ApplyColoring(func, 
result.assignment); + + // 最终检查:不应有残留 VReg + for (auto& bb_ptr : func.GetBlocks()) { for (const auto& inst : bb_ptr->GetInstructions()) { - for (const auto& operand : inst.GetOperands()) { - if (operand.GetKind() == Operand::Kind::Reg && - !IsAllowedReg(operand.GetReg())) { - throw std::runtime_error(FormatError("mir", "寄存器分配失败")); + for (const auto& op : inst.GetOperands()) { + if (op.IsVReg()) { + throw std::runtime_error( + FormatError("mir", "寄存器分配失败:存在未着色的虚拟寄存器 v" + + std::to_string(op.GetVRegId()))); } } } diff --git a/src/mir/passes/Peephole.cpp b/src/mir/passes/Peephole.cpp index a6f1b85..087cd7c 100644 --- a/src/mir/passes/Peephole.cpp +++ b/src/mir/passes/Peephole.cpp @@ -72,6 +72,7 @@ std::optional GetWrittenReg(const MachineInstr& inst) { case Opcode::AddRI: case Opcode::SubRI: case Opcode::AddRR: + case Opcode::AddRR_UXTW: case Opcode::SubRR: case Opcode::MulRR: case Opcode::DivRR: @@ -125,6 +126,23 @@ void RecordStore(std::unordered_map& slot_to_reg, slot_to_reg[ops[1].GetFrameIndex()] = ops[0].GetReg(); } +bool IsWReg(PhysReg reg) { + return reg >= PhysReg::W0 && reg <= PhysReg::W11; +} + +bool IsXReg(PhysReg reg) { + return (reg >= PhysReg::X0 && reg <= PhysReg::X11) || + reg == PhysReg::X29 || reg == PhysReg::X30; +} + +// 检查两个寄存器宽度是否兼容(同为 W,同为 X,或同为 S) +bool SameRegWidth(PhysReg a, PhysReg b) { + if (IsWReg(a) && IsWReg(b)) return true; + if (IsXReg(a) && IsXReg(b)) return true; + if (IsFloatReg(a) && IsFloatReg(b)) return true; + return false; +} + bool TryForwardLoad(std::vector& out, std::unordered_map& slot_to_reg, const MachineInstr& load) { @@ -142,6 +160,12 @@ bool TryForwardLoad(std::vector& out, } const PhysReg src = it->second; + + // 宽度不匹配时不能转发(如 W8 → X8 会生成非法的 mov x8, w8) + if (!SameRegWidth(src, dst)) { + return false; + } + if (RegAlias(src, dst)) { slot_to_reg[slot] = dst; return true; @@ -242,7 +266,10 @@ void RunPeephole(MachineFunction& function) { IsLoadStack(insts[i + 1]) && IsSameFrameIndex(cur, insts[i + 1])) { 
optimized.push_back(cur); RecordStore(slot_to_reg, cur); - TryForwardLoad(optimized, slot_to_reg, insts[i + 1]); + if (!TryForwardLoad(optimized, slot_to_reg, insts[i + 1])) { + // 转发失败(如宽度不匹配),保留原始 load + optimized.push_back(insts[i + 1]); + } ++i; continue; }