Pomelo
Pomelo 24 hours ago
parent 6e5b6d8235
commit 938304c5bf

@ -0,0 +1,646 @@
# Lab5: 图着色寄存器分配 - 详细实施方案
## 总体策略选择
经过对代码库的深入分析,选择 **方案B后置提升Post-Lowering Promotion** 作为主要策略:
**核心思路**:保留现有 Lowering.cpp 基本不动(它已经正确工作),在其后添加一个 **栈槽提升Stack Slot Promotion** pass将"栈槽中转"模式转换为"虚拟寄存器"模式,然后对虚拟寄存器做图着色分配。
**选择理由**
1. Lowering.cpp 有 1117 行,每个指令模式都使用 `StoreStack/LoadStack + 固定寄存器`,全部重写风险极高
2. 当前模式中,每个 IR Value 对应唯一一个栈槽ValueSlotMap这天然等价于虚拟寄存器
3. 后置提升只需识别 `LoadStack vreg->physreg` / `StoreStack physreg->vreg`,无需修改 Lowering 逻辑
4. 即使提升失败(如数组地址计算),退化到原始行为即可,保证正确性
---
## 架构概览
```
LowerToMIR (不变)
|
RunPeephole (不变,先做局部优化)
|
RunStackPromotion (新增:栈槽->虚拟寄存器提升)
|
RunRegAlloc (重写:图着色寄存器分配)
|
RunFrameLowering (修改:处理溢出槽 + callee-saved)
|
PrintAsm (小改:支持新的 callee-saved 保存/恢复)
```
---
## Step 1: 扩展 MIR 基础设施 - 虚拟寄存器支持
### 1.1 修改 `include/mir/MIR.h`
```cpp
// 新增:虚拟寄存器 ID 类型
using VRegId = int; // 正数,从 1 开始
// 新增:寄存器类(用于区分 GPR 和 FPR
enum class RegClass { GPR, FPR };
// 修改 Operand 类:
class Operand {
public:
enum class Kind { Reg, VReg, Imm, FrameIndex, Symbol };
// ^^^^^ 新增
static Operand Reg(PhysReg reg);
static Operand VReg(VRegId id, RegClass rc); // 新增
static Operand Imm(int value);
static Operand FrameIndex(int index);
static Operand Symbol(std::string name);
Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; }
VRegId GetVReg() const { return vreg_id_; } // 新增
RegClass GetRegClass() const { return reg_class_; } // 新增
int GetImm() const { return imm_; }
int GetFrameIndex() const { return imm_; }
const std::string& GetSymbol() const { return symbol_; }
bool IsReg() const { return kind_ == Kind::Reg; }
bool IsVReg() const { return kind_ == Kind::VReg; } // 新增
private:
Operand(Kind kind, PhysReg reg, int imm, std::string symbol = "");
Kind kind_;
PhysReg reg_ = PhysReg::W0;
VRegId vreg_id_ = 0; // 新增
RegClass reg_class_ = RegClass::GPR; // 新增
int imm_ = 0;
std::string symbol_;
};
```
### 1.2 修改 `MachineFunction`:添加虚拟寄存器管理
```cpp
class MachineFunction {
public:
// ... 现有接口不变 ...
// 新增:虚拟寄存器分配
VRegId CreateVReg(RegClass rc);
RegClass GetVRegClass(VRegId id) const;
int GetNumVRegs() const { return next_vreg_id_ - 1; }
// 新增callee-saved 管理
void SetCalleeSavedRegs(const std::vector<PhysReg>& regs);
const std::vector<PhysReg>& GetCalleeSavedRegs() const;
private:
// ... 现有字段 ...
int next_vreg_id_ = 1;
std::vector<RegClass> vreg_classes_; // index = vreg_id - 1
std::vector<PhysReg> callee_saved_regs_;
};
```
### 1.3 修改 `MachineBasicBlock`:添加 CFG 信息
```cpp
class MachineBasicBlock {
public:
// ... 现有接口不变 ...
// 新增CFG 前驱/后继
void AddSuccessor(MachineBasicBlock* succ);
void AddPredecessor(MachineBasicBlock* pred);
const std::vector<MachineBasicBlock*>& GetSuccessors() const;
const std::vector<MachineBasicBlock*>& GetPredecessors() const;
void ClearCFG(); // 用于重建 CFG
private:
std::string name_;
std::vector<MachineInstr> instructions_;
std::vector<MachineBasicBlock*> succs_; // 新增
std::vector<MachineBasicBlock*> preds_; // 新增
};
```
### 1.4 修改 `src/mir/MIRInstr.cpp`
```cpp
Operand Operand::VReg(VRegId id, RegClass rc) {
Operand op(Kind::VReg, PhysReg::W0, 0);
op.vreg_id_ = id;
op.reg_class_ = rc;
return op;
}
```
---
## Step 2: 构建 CFG新增 pass
### 2.1 新增文件 `src/mir/BuildCFG.cpp`
在 MIR 生成之后,通过分析每个基本块末尾的跳转指令构建 CFG
```cpp
namespace mir {
void BuildCFG(MachineFunction& function) {
// 先清除旧的 CFG 信息(支持重建)
for (auto& bb_ptr : function.GetBlocks()) {
bb_ptr->ClearCFG();
}
auto& blocks = function.GetBlocks();
for (size_t idx = 0; idx < blocks.size(); ++idx) {
auto& bb = *blocks[idx];
auto& insts = bb.GetInstructions();
bool has_unconditional_branch = false;
for (const auto& inst : insts) {
Opcode op = inst.GetOpcode();
if (op == Opcode::B || op == Opcode::Bcond || op == Opcode::FBcond ||
op == Opcode::Cbnz || op == Opcode::Cbz) {
const auto& target_name = inst.GetOperands()[0].GetSymbol();
MachineBasicBlock* target = function.FindBlock(target_name);
if (target) {
bb.AddSuccessor(target);
target->AddPredecessor(&bb);
}
if (op == Opcode::B) has_unconditional_branch = true;
}
}
// Fall-through如果块没有以无条件跳转/Ret 结束
if (!has_unconditional_branch && !insts.empty()) {
Opcode last_op = insts.back().GetOpcode();
if (last_op != Opcode::Ret && last_op != Opcode::B) {
if (idx + 1 < blocks.size()) {
auto* next_bb = blocks[idx + 1].get();
bb.AddSuccessor(next_bb);
next_bb->AddPredecessor(&bb);
}
}
}
}
}
} // namespace mir
```
### 2.2 在 `MIR.h` 中声明
```cpp
void BuildCFG(MachineFunction& function);
```
---
## Step 3: 栈槽提升 Pass核心创新
### 3.1 新增文件 `src/mir/StackPromotion.cpp`
**核心算法**:扫描所有指令,识别"可提升栈槽"——即仅通过 `LoadStack`/`StoreStack` 访问的 4 字节标量槽(非数组、非指针、非通过 `LoadStackOffset`/`StoreStackOffset`/`LoadStackAddr` 访问的)。
```cpp
#include "mir/MIR.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace mir {
namespace {
struct SlotUsageInfo {
bool is_promotable = true;
RegClass reg_class = RegClass::GPR;
int load_count = 0;
int store_count = 0;
};
std::unordered_map<int, SlotUsageInfo> AnalyzeSlotUsage(
const MachineFunction& function) {
std::unordered_map<int, SlotUsageInfo> info;
for (const auto& bb_ptr : function.GetBlocks()) {
for (const auto& inst : bb_ptr->GetInstructions()) {
const auto& ops = inst.GetOperands();
// LoadStackOffset/StoreStackOffset/LoadStackAddr 用于数组访问,不可提升
if (inst.GetOpcode() == Opcode::LoadStackOffset ||
inst.GetOpcode() == Opcode::StoreStackOffset ||
inst.GetOpcode() == Opcode::LoadStackAddr) {
if (ops.size() >= 2 &&
ops[1].GetKind() == Operand::Kind::FrameIndex) {
info[ops[1].GetFrameIndex()].is_promotable = false;
}
continue;
}
// LoadStack: ops[0]=dst_reg, ops[1]=frame_index
if (inst.GetOpcode() == Opcode::LoadStack) {
if (ops.size() >= 2 &&
ops[1].GetKind() == Operand::Kind::FrameIndex) {
int slot = ops[1].GetFrameIndex();
auto& si = info[slot];
si.load_count++;
PhysReg dst = ops[0].GetReg();
if (dst >= PhysReg::S0 && dst <= PhysReg::S10) {
si.reg_class = RegClass::FPR;
} else if (dst >= PhysReg::X0 && dst <= PhysReg::X11) {
// 64位加载 = 指针,不提升
si.is_promotable = false;
}
}
continue;
}
// StoreStack: ops[0]=src_reg, ops[1]=frame_index
if (inst.GetOpcode() == Opcode::StoreStack) {
if (ops.size() >= 2 &&
ops[1].GetKind() == Operand::Kind::FrameIndex) {
int slot = ops[1].GetFrameIndex();
auto& si = info[slot];
si.store_count++;
PhysReg src = ops[0].GetReg();
if (src >= PhysReg::S0 && src <= PhysReg::S10) {
si.reg_class = RegClass::FPR;
} else if (src >= PhysReg::X0 && src <= PhysReg::X11) {
si.is_promotable = false;
}
}
continue;
}
}
}
// 排除大小不为 4 的槽(数组、指针等)
for (const auto& slot : function.GetFrameSlots()) {
if (slot.size != 4) {
info[slot.index].is_promotable = false;
}
}
return info;
}
} // namespace
void RunStackPromotion(MachineFunction& function) {
auto slot_info = AnalyzeSlotUsage(function);
// 为每个可提升的栈槽分配一个虚拟寄存器
std::unordered_map<int, VRegId> slot_vreg_map;
for (auto& [slot_idx, si] : slot_info) {
if (!si.is_promotable) continue;
VRegId vreg = function.CreateVReg(si.reg_class);
slot_vreg_map[slot_idx] = vreg;
}
if (slot_vreg_map.empty()) return;
// 重写指令LoadStack/StoreStack -> MovReg/FMovReg with VReg
for (auto& bb_ptr : function.GetBlocks()) {
auto& insts = bb_ptr->GetInstructions();
std::vector<MachineInstr> new_insts;
new_insts.reserve(insts.size());
for (auto& inst : insts) {
const auto& ops = inst.GetOperands();
// LoadStack reg, [slot] -> MovReg reg, %vreg
if (inst.GetOpcode() == Opcode::LoadStack &&
ops.size() >= 2 &&
ops[1].GetKind() == Operand::Kind::FrameIndex) {
int slot = ops[1].GetFrameIndex();
auto it = slot_vreg_map.find(slot);
if (it != slot_vreg_map.end()) {
RegClass rc = slot_info[slot].reg_class;
Opcode mov_op = (rc == RegClass::FPR)
? Opcode::FMovReg : Opcode::MovReg;
new_insts.emplace_back(mov_op, std::vector<Operand>{
ops[0], // dst physreg (暂保留,后续被 regalloc 处理)
Operand::VReg(it->second, rc)
});
continue;
}
}
// StoreStack reg, [slot] -> MovReg %vreg, reg
if (inst.GetOpcode() == Opcode::StoreStack &&
ops.size() >= 2 &&
ops[1].GetKind() == Operand::Kind::FrameIndex) {
int slot = ops[1].GetFrameIndex();
auto it = slot_vreg_map.find(slot);
if (it != slot_vreg_map.end()) {
RegClass rc = slot_info[slot].reg_class;
Opcode mov_op = (rc == RegClass::FPR)
? Opcode::FMovReg : Opcode::MovReg;
new_insts.emplace_back(mov_op, std::vector<Operand>{
Operand::VReg(it->second, rc),
ops[0] // src physreg
});
continue;
}
}
// 其他指令保持不变
new_insts.push_back(std::move(inst));
}
insts = std::move(new_insts);
}
}
} // namespace mir
```
### 3.2 提升前后的 MIR 对比(示例)
**提升前(当前输出)**
```asm
; a = b + c
LoadStack w8, [slot_b] ; 从栈加载 b
LoadStack w9, [slot_c] ; 从栈加载 c
AddRR w8, w8, w9 ; w8 = b + c
StoreStack w8, [slot_a] ; 结果存回栈
```
**提升后**
```asm
; a = b + c
MovReg w8, %vreg2 ; w8 <- vreg_b
MovReg w9, %vreg3 ; w9 <- vreg_c
AddRR w8, w8, w9 ; w8 = b + c
MovReg %vreg1, w8 ; vreg_a <- w8
```
### 3.3 为什么这样设计
提升后的形态保留了 Lowering 生成的"固定寄存器做中间计算"的模式,只是把 Load/Store 换成了与 vreg 的 copy。好处
1. **正确性有保障**:即使 regalloc 全部 spill等价于恢复原始行为
2. **mov coalescing 自然发生**:如果 vreg1 被分配到 w8则最后的 `MovReg w8, w8` 变为 no-op
3. **渐进式优化**:后续可以做更激进的提升(把 AddRR 的操作数也换成 vreg
---
## Step 4: 活跃性分析Liveness Analysis
### 4.1 新增文件 `src/mir/Liveness.cpp`
```cpp
#include "mir/MIR.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace mir {
namespace {
// 判断指令的第一个操作数是否为 def写入
bool FirstOpIsDef(Opcode op) {
switch (op) {
case Opcode::MovImm: case Opcode::MovReg:
case Opcode::FMovImm: case Opcode::FMovReg:
case Opcode::LoadStack: case Opcode::LoadStackOffset:
case Opcode::LoadStackAddr: case Opcode::LoadIndirect:
case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr:
case Opcode::AddRI: case Opcode::SubRI:
case Opcode::AddRR: case Opcode::SubRR:
case Opcode::MulRR: case Opcode::DivRR: case Opcode::ModRR:
case Opcode::LsrRI: case Opcode::LslRI: case Opcode::LslRR:
case Opcode::FAddRR: case Opcode::FSubRR:
case Opcode::FMulRR: case Opcode::FDivRR: case Opcode::FSqrtRR:
case Opcode::SIToFP: case Opcode::FPToSI:
case Opcode::CmpRR: case Opcode::FCmpRR:
return true;
default:
return false;
}
}
struct BlockDefUse {
std::unordered_set<VRegId> def; // 块中定义的 vreg
std::unordered_set<VRegId> use; // 块中 upward-exposed use 的 vreg
};
BlockDefUse ComputeBlockDefUse(const MachineBasicBlock& bb) {
BlockDefUse info;
for (const auto& inst : bb.GetInstructions()) {
const auto& ops = inst.GetOperands();
bool first_is_def = FirstOpIsDef(inst.GetOpcode());
// 先收集 use在 def 之前的引用)
for (size_t i = 0; i < ops.size(); ++i) {
if (!ops[i].IsVReg()) continue;
if (i == 0 && first_is_def) continue; // 这是 def不是 use
VRegId v = ops[i].GetVReg();
if (info.def.find(v) == info.def.end()) {
info.use.insert(v); // 在本块定义前被使用
}
}
// 再收集 def
if (first_is_def && !ops.empty() && ops[0].IsVReg()) {
info.def.insert(ops[0].GetVReg());
}
}
return info;
}
} // namespace
LivenessInfo ComputeLiveness(MachineFunction& function) {
// 1. 计算每个块的 def/use
std::unordered_map<MachineBasicBlock*, BlockDefUse> block_info;
for (auto& bb_ptr : function.GetBlocks()) {
block_info[bb_ptr.get()] = ComputeBlockDefUse(*bb_ptr);
}
// 2. 初始化
LivenessInfo result;
for (auto& bb_ptr : function.GetBlocks()) {
result.live_in[bb_ptr.get()] = {};
result.live_out[bb_ptr.get()] = {};
}
// 3. 迭代数据流直到不动点
// live_out[B] = U live_in[S], for S in succ(B)
// live_in[B] = use[B] U (live_out[B] - def[B])
bool changed = true;
while (changed) {
changed = false;
auto& blocks = function.GetBlocks();
for (int i = (int)blocks.size() - 1; i >= 0; --i) {
auto* bb = blocks[i].get();
auto& bi = block_info[bb];
// live_out = union of successors' live_in
std::unordered_set<VRegId> new_out;
for (auto* succ : bb->GetSuccessors()) {
for (VRegId v : result.live_in[succ]) {
new_out.insert(v);
}
}
// live_in = use U (live_out - def)
std::unordered_set<VRegId> new_in = bi.use;
for (VRegId v : new_out) {
if (bi.def.find(v) == bi.def.end()) {
new_in.insert(v);
}
}
if (new_in != result.live_in[bb] || new_out != result.live_out[bb]) {
changed = true;
result.live_in[bb] = std::move(new_in);
result.live_out[bb] = std::move(new_out);
}
}
}
return result;
}
} // namespace mir
```
### 4.2 在 `MIR.h` 中声明数据结构和接口
```cpp
// 活跃性分析结果
struct LivenessInfo {
std::unordered_map<MachineBasicBlock*, std::unordered_set<VRegId>> live_in;
std::unordered_map<MachineBasicBlock*, std::unordered_set<VRegId>> live_out;
};
LivenessInfo ComputeLiveness(MachineFunction& function);
```
---
## Step 5: 干涉图构建
### 5.1 数据结构(在 `RegAlloc.cpp` 中)
```cpp
namespace {
struct InterferenceGraph {
int num_vregs;
std::vector<std::unordered_set<VRegId>> adj; // 邻接表
std::vector<int> degree; // 度数
explicit InterferenceGraph(int n)
: num_vregs(n), adj(n + 1), degree(n + 1, 0) {}
void AddEdge(VRegId u, VRegId v) {
if (u == v) return;
if (adj[u].insert(v).second) degree[u]++;
if (adj[v].insert(u).second) degree[v]++;
}
void RemoveNode(VRegId v) {
for (VRegId neighbor : adj[v]) {
adj[neighbor].erase(v);
degree[neighbor]--;
}
adj[v].clear();
degree[v] = 0;
}
};
} // namespace
```
### 5.2 构建算法
基于活跃性分析结果,逆序遍历每个基本块的指令,维护当前活跃集合:
```cpp
namespace {
InterferenceGraph BuildInterferenceGraph(
MachineFunction& function,
const LivenessInfo& liveness,
std::vector<bool>& crosses_call) { // 输出vreg 是否跨越 call
int n = function.GetNumVRegs();
InterferenceGraph graph(n);
crosses_call.assign(n + 1, false);
for (auto& bb_ptr : function.GetBlocks()) {
auto* bb = bb_ptr.get();
std::unordered_set<VRegId> live = liveness.live_out.at(bb);
// 逆序遍历指令
auto& insts = bb->GetInstructions();
for (int i = (int)insts.size() - 1; i >= 0; --i) {
const auto& inst = insts[i];
const auto& ops = inst.GetOperands();
// 处理 Call 指令:所有当前活跃的 vreg 都 crosses_call
if (inst.GetOpcode() == Opcode::Bl) {
for (VRegId v : live) {
crosses_call[v] = true;
}
}
// 收集 def 和 use
bool first_is_def = FirstOpIsDef(inst.GetOpcode());
std::vector<VRegId> defs, uses;
for (size_t j = 0; j < ops.size(); ++j) {
if (!ops[j].IsVReg()) continue;
if (j == 0 && first_is_def) {
defs.push_back(ops[j].GetVReg());
} else {
uses.push_back(ops[j].GetVReg());
}
}
// 对每个 def与所有当前 live 的 vreg 建立干涉边
// 例外mov d, s 指令中d 不与 s 干涉(允许合并)
bool is_move = (inst.GetOpcode() == Opcode::MovReg ||
inst.GetOpcode() == Opcode::FMovReg);
std::unordered_set<VRegId> use_set(uses.begin(), uses.end());
for (VRegId d : defs) {
for (VRegId l : live) {
if (l == d) continue;
if (is_move && use_set.count(l)) continue; // mov coalescing 友好
graph.AddEdge(d, l);
}
}
// 更新 live 集合:移除 def加入 use
for (VRegId d : defs) live.erase(d);
for (VRegId u : uses) live.insert(u);
}
}
return graph;
}
} // namespace
```
### 5.3 物理寄存器干涉的处理
提升后的 MIR 中仍然存在物理寄存器(如 `AddRR w8, w8, w9`),这些物理寄存器的生命期极短(通常在一条指令内定义并被下一条 `MovReg %vreg, w8` 消费)。
**关键洞察**:由于物理寄存器不参与图着色(它们已经是物理的),我们不需要在干涉图中建模它们。但在 Select 阶段,需要避免将 vreg 分配到与其"相邻"物理寄存器使用点冲突的寄存器上。
**简化处理**当前设计中vreg 的生命期与物理寄存器的使用点基本不重叠vreg 在 `MovReg %vreg, physreg` 处被定义,在 `MovReg physreg, %vreg` 处被使用)。因此物理寄存器约束可以通过 mov coalescing 自然解决,无需显式建模。
---
## Step 6: 图着色寄存器分配 (核心算法)
### 6.1 重写 `src/mir/RegAlloc.cpp`
整个文件结构如下:
```cpp
#include "mir/MIR.h"
#include <algorithm>

@ -0,0 +1,14 @@
{
"permissions": {
"allow": [
"Bash(mkdir -p \"C:\\\\Users\\\\郑同学\\\\.claude\\\\plans\")",
"Bash(mkdir -p \"/c/Users/郑同学/.claude/plans\")",
"Bash(mkdir -p \"/mnt/c/Users/郑同学/.claude/plans\")",
"Bash(echo \"test\")",
"Bash(rm \"/mnt/c/Users/郑同学/.claude/plans/test.md\")",
"Bash(cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=/c/mingw/mingw64/bin/g++.exe -DCMAKE_C_COMPILER=/c/mingw/mingw64/bin/gcc.exe)",
"Bash(cmake -S . -B build -G \"MinGW Makefiles\" -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=\"C:/mingw/MinGW/bin/g++.exe\" -DCMAKE_C_COMPILER=\"C:/mingw/MinGW/bin/gcc.exe\")",
"Bash(wsl *)"
]
}
}

@ -197,6 +197,7 @@ enum class Opcode {
Store,
Ret,
Gep, // getelementptr数组元素地址计算
Phi, // SSA phi 节点
};
enum class CmpOp { Eq, Ne, Lt, Le, Gt, Ge };
@ -214,6 +215,8 @@ class User : public Value {
protected:
// 统一的 operand 入口。
void AddOperand(Value* value);
// 清空所有 operand不清除 use 关系,调用者需自行处理)。
void ClearOperands();
private:
std::vector<Value*> operands_;
@ -355,6 +358,21 @@ class GepInst : public Instruction {
Value* GetIndex() const;
};
// PhiInstSSA phi 节点,用于控制流汇合点合并不同前驱传来的值。
// 操作数布局:[val_0, bb_0, val_1, bb_1, ...]
class PhiInst : public Instruction {
public:
PhiInst(std::shared_ptr<Type> ty, std::string name);
// 添加一组 (value, incoming_block) 入边。
void AddIncoming(Value* val, BasicBlock* bb);
size_t GetNumIncoming() const;
Value* GetIncomingValue(size_t i) const;
BasicBlock* GetIncomingBlock(size_t i) const;
void SetIncomingValue(size_t i, Value* val);
// 移除来自指定前驱块的入边。
void RemoveIncomingBlock(BasicBlock* bb);
};
// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。
// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。
class BasicBlock : public Value {
@ -364,10 +382,30 @@ class BasicBlock : public Value {
void SetParent(Function* parent);
bool HasTerminator() const;
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
std::vector<std::unique_ptr<Instruction>>& MutableInstructions();
const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const;
std::vector<BasicBlock*>& MutablePredecessors();
std::vector<BasicBlock*>& MutableSuccessors();
void AddPredecessor(BasicBlock* pred);
void AddSuccessor(BasicBlock* succ);
void RemovePredecessor(BasicBlock* pred);
void RemoveSuccessor(BasicBlock* succ);
// 在块头部(所有 phi 之后)插入指令。
template <typename T, typename... Args>
T* Prepend(Args&&... args) {
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
// 插入到第一条非-phi 指令之前
auto it = instructions_.begin();
while (it != instructions_.end() &&
(*it)->GetOpcode() == Opcode::Phi) {
++it;
}
instructions_.insert(it, std::move(inst));
return ptr;
}
template <typename T, typename... Args>
T* Append(Args&&... args) {
if (HasTerminator()) {
@ -380,6 +418,12 @@ class BasicBlock : public Value {
instructions_.push_back(std::move(inst));
return ptr;
}
// 在块的最前面插入 phi 节点。
PhiInst* PrependPhi(std::shared_ptr<Type> ty, const std::string& name);
// 删除指定指令(从块中移除 ownership
void RemoveInstruction(Instruction* inst);
// 判断块是否为空(不含任何指令)。
bool IsEmpty() const { return instructions_.empty(); }
private:
Function* parent_ = nullptr;
@ -403,6 +447,9 @@ class Function : public Value {
size_t GetNumParams() const;
Argument* GetArgument(size_t index) const;
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
std::vector<std::unique_ptr<BasicBlock>>& MutableBlocks();
// 删除指定基本块(从函数中移除 ownership
void RemoveBlock(BasicBlock* bb);
// 外部函数声明(无函数体,打印为 declare
void SetExternal(bool v) { is_external_ = v; }

@ -50,6 +50,7 @@ enum class Opcode {
AddRI,
SubRI,
AddRR,
AddRR_UXTW, // add xN, xM, wK, uxtw零扩展W寄存器后加到X寄存器
SubRR,
MulRR,
DivRR,
@ -77,27 +78,49 @@ enum class Opcode {
Ret,
};
// 虚拟寄存器类别
enum class VRegClass {
GPR, // 通用寄存器 (w0-w11)
GPR64, // 64位通用寄存器 (x0-x11)
FPR, // 浮点寄存器 (s0-s10)
};
class Operand {
public:
enum class Kind { Reg, Imm, FrameIndex, Symbol };
enum class Kind { Reg, VReg, Imm, FrameIndex, Symbol };
static Operand Reg(PhysReg reg);
static Operand VReg(int vreg_id, VRegClass rc = VRegClass::GPR);
static Operand Imm(int value);
static Operand FrameIndex(int index);
static Operand Symbol(std::string name);
Kind GetKind() const { return kind_; }
bool IsReg() const { return kind_ == Kind::Reg; }
bool IsVReg() const { return kind_ == Kind::VReg; }
bool IsImm() const { return kind_ == Kind::Imm; }
bool IsFrameIndex() const { return kind_ == Kind::FrameIndex; }
bool IsSymbol() const { return kind_ == Kind::Symbol; }
PhysReg GetReg() const { return reg_; }
int GetVRegId() const { return vreg_id_; }
VRegClass GetVRegClass() const { return vreg_class_; }
int GetImm() const { return imm_; }
int GetFrameIndex() const { return imm_; }
const std::string& GetSymbol() const { return symbol_; }
// 用于寄存器分配后替换 VReg → Reg
void AssignPhysReg(PhysReg reg);
private:
Operand(Kind kind, PhysReg reg, int imm, std::string symbol = "");
Operand(int vreg_id, VRegClass rc);
Kind kind_;
PhysReg reg_;
int imm_;
PhysReg reg_ = PhysReg::W0;
int vreg_id_ = -1;
VRegClass vreg_class_ = VRegClass::GPR;
int imm_ = 0;
std::string symbol_;
};
@ -107,6 +130,9 @@ class MachineInstr {
Opcode GetOpcode() const { return opcode_; }
const std::vector<Operand>& GetOperands() const { return operands_; }
std::vector<Operand>& GetOperands() { return operands_; }
void SetOperand(size_t i, Operand op);
private:
Opcode opcode_;
@ -130,9 +156,18 @@ class MachineBasicBlock {
MachineInstr& Append(Opcode opcode,
std::initializer_list<Operand> operands = {});
// CFG 支持
void AddSuccessor(MachineBasicBlock* succ);
void AddPredecessor(MachineBasicBlock* pred);
const std::vector<MachineBasicBlock*>& GetSuccessors() const { return successors_; }
const std::vector<MachineBasicBlock*>& GetPredecessors() const { return predecessors_; }
void ClearCFG() { successors_.clear(); predecessors_.clear(); }
private:
std::string name_;
std::vector<MachineInstr> instructions_;
std::vector<MachineBasicBlock*> successors_;
std::vector<MachineBasicBlock*> predecessors_;
};
class MachineFunction {
@ -153,15 +188,26 @@ class MachineFunction {
FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const;
const std::vector<FrameSlot>& GetFrameSlots() const { return frame_slots_; }
std::vector<FrameSlot>& GetMutableFrameSlots() { return frame_slots_; }
int GetFrameSize() const { return frame_size_; }
void SetFrameSize(int size) { frame_size_ = size; }
// 虚拟寄存器管理
int CreateVReg(VRegClass rc = VRegClass::GPR);
int GetNumVRegs() const { return next_vreg_id_; }
// Callee-saved 寄存器管理
void AddUsedCalleeSaved(PhysReg reg);
const std::vector<PhysReg>& GetUsedCalleeSaved() const { return used_callee_saved_; }
private:
std::string name_;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0;
int next_vreg_id_ = 0;
std::vector<PhysReg> used_callee_saved_;
};
class MachineModule {

@ -0,0 +1,266 @@
#!/usr/bin/env bash
# 批量回归测试脚本:对 test/test_case 下全部 .sy 用例执行 IR 语义验证。
# 用法:./scripts/run_all_tests.sh [--ir | --asm | --both]
#
# 默认只测 IR通过 llc + clang 编译运行)。
# --asm 只测汇编(需要 aarch64-linux-gnu-gcc + qemu-aarch64
# --both 同时测 IR 和汇编。
set -uo pipefail
mode="ir"
if [[ "${1:-}" == "--asm" ]]; then
mode="asm"
elif [[ "${1:-}" == "--both" ]]; then
mode="both"
fi
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
ROOT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
cd "$ROOT_DIR"
compiler="./build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "❌ 未找到编译器: $compiler" >&2
echo "请先构建cmake -S . -B build -DCMAKE_BUILD_TYPE=Release && cmake --build build -j \$(nproc)" >&2
exit 1
fi
total=0
passed=0
failed=0
skipped=0
fail_list=()
run_ir_test() {
local sy="$1"
local dir
dir=$(dirname "$sy")
local stem
stem=$(basename "$sy" .sy)
local out_dir="test/test_result/ir_batch"
mkdir -p "$out_dir"
local out_file="$out_dir/$stem.ll"
local stdin_file="$dir/$stem.in"
local expected_file="$dir/$stem.out"
local stdout_file="$out_dir/$stem.stdout"
local actual_file="$out_dir/$stem.actual.out"
# 生成 IR
if ! timeout 30 "$compiler" --emit-ir "$sy" > "$out_file" 2>/dev/null; then
echo " [SKIP-IR] $sy (编译器报错或超时)"
return 2
fi
# 需要 llc + clang
if ! command -v llc >/dev/null 2>&1 || ! command -v clang >/dev/null 2>&1; then
echo " [SKIP-IR] $sy (缺少 llc/clang)"
return 2
fi
local obj="$out_dir/$stem.o"
local exe="$out_dir/$stem"
if ! llc -filetype=obj "$out_file" -o "$obj" 2>/dev/null; then
echo " [SKIP-IR] $sy (llc 编译失败)"
return 2
fi
if ! clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm 2>/dev/null; then
echo " [SKIP-IR] $sy (clang 链接失败)"
return 2
fi
set +e
# performance 用例给更长的超时时间
local run_timeout=30
if [[ "$sy" == *"performance"* ]]; then
run_timeout=1000
fi
if [[ -f "$stdin_file" ]]; then
timeout $run_timeout "$exe" < "$stdin_file" > "$stdout_file" 2>/dev/null
else
timeout $run_timeout "$exe" > "$stdout_file" 2>/dev/null
fi
local status=$?
set -e
# timeout 返回 124 表示超时,标记为 SKIP
if [[ $status -eq 124 ]]; then
echo " [SKIP-IR] $sy (运行超时)"
return 2
fi
# 组装实际输出
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
if [[ ! -f "$expected_file" ]]; then
echo " [SKIP-IR] $sy (无预期输出)"
return 2
fi
if diff -q <(sed -e 's/\r$//' -e '$a\\' "$expected_file") \
<(sed -e 's/\r$//' -e '$a\\' "$actual_file") >/dev/null 2>&1; then
return 0
else
return 1
fi
}
run_asm_test() {
local sy="$1"
local dir
dir=$(dirname "$sy")
local stem
stem=$(basename "$sy" .sy)
local out_dir="test/test_result/asm_batch"
mkdir -p "$out_dir"
local asm_file="$out_dir/$stem.s"
local stdin_file="$dir/$stem.in"
local expected_file="$dir/$stem.out"
local stdout_file="$out_dir/$stem.stdout"
local actual_file="$out_dir/$stem.actual.out"
local exe="$out_dir/$stem"
# 生成汇编
if ! timeout 30 "$compiler" --emit-asm "$sy" > "$asm_file" 2>/dev/null; then
echo " [SKIP-ASM] $sy (编译器报错或超时)"
return 2
fi
if ! command -v aarch64-linux-gnu-gcc >/dev/null 2>&1; then
echo " [SKIP-ASM] $sy (缺少 aarch64-linux-gnu-gcc)"
return 2
fi
if ! timeout 30 aarch64-linux-gnu-gcc "$asm_file" sylib/sylib.c -o "$exe" -static 2>/dev/null; then
echo " [SKIP-ASM] $sy (汇编/链接失败)"
return 2
fi
if ! command -v qemu-aarch64 >/dev/null 2>&1; then
echo " [SKIP-ASM] $sy (缺少 qemu-aarch64)"
return 2
fi
set +e
# performance 用例给更长的超时时间
local run_timeout=30
if [[ "$sy" == *"performance"* ]]; then
run_timeout=1000
fi
if [[ -f "$stdin_file" ]]; then
timeout $run_timeout qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file" 2>/dev/null
else
timeout $run_timeout qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file" 2>/dev/null
fi
local status=$?
set -e
# timeout 返回 124 表示超时,标记为 SKIP
if [[ $status -eq 124 ]]; then
echo " [SKIP-ASM] $sy (运行超时)"
return 2
fi
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
if [[ ! -f "$expected_file" ]]; then
echo " [SKIP-ASM] $sy (无预期输出)"
return 2
fi
if diff -q <(sed -e 's/\r$//' -e '$a\\' "$expected_file") \
<(sed -e 's/\r$//' -e '$a\\' "$actual_file") >/dev/null 2>&1; then
return 0
else
return 1
fi
}
echo "========================================"
echo " Lab4 批量回归测试 (mode: $mode)"
echo "========================================"
echo ""
# 收集所有测试文件
mapfile -t test_files < <(find test/test_case -name '*.sy' | sort)
for sy in "${test_files[@]}"; do
total=$((total + 1))
if [[ "$mode" == "ir" || "$mode" == "both" ]]; then
run_ir_test "$sy"
rc=$?
if [[ $rc -eq 0 ]]; then
echo " [PASS-IR] $sy"
passed=$((passed + 1))
elif [[ $rc -eq 1 ]]; then
echo " [FAIL-IR] $sy"
failed=$((failed + 1))
fail_list+=("$sy (IR)")
else
skipped=$((skipped + 1))
fi
fi
if [[ "$mode" == "asm" || "$mode" == "both" ]]; then
run_asm_test "$sy"
rc=$?
if [[ $rc -eq 0 ]]; then
echo " [PASS-ASM] $sy"
if [[ "$mode" == "asm" ]]; then
passed=$((passed + 1))
fi
elif [[ $rc -eq 1 ]]; then
echo " [FAIL-ASM] $sy"
if [[ "$mode" == "asm" ]]; then
failed=$((failed + 1))
fi
fail_list+=("$sy (ASM)")
else
if [[ "$mode" == "asm" ]]; then
skipped=$((skipped + 1))
fi
fi
fi
done
echo ""
echo "========================================"
echo " 测试结果汇总"
echo "========================================"
echo " 总计: $total"
echo " 通过: $passed"
echo " 失败: $failed"
echo " 跳过: $skipped"
echo ""
if [[ ${#fail_list[@]} -gt 0 ]]; then
echo " 失败用例:"
for f in "${fail_list[@]}"; do
echo " - $f"
done
echo ""
fi
if [[ $failed -gt 0 ]]; then
echo "❌ 存在失败用例"
exit 1
else
echo "✅ 全部通过(跳过 $skipped 个)"
exit 0
fi

@ -65,4 +65,55 @@ void BasicBlock::AddSuccessor(BasicBlock* succ) {
successors_.push_back(succ);
}
std::vector<std::unique_ptr<Instruction>>& BasicBlock::MutableInstructions() {
return instructions_;
}
std::vector<BasicBlock*>& BasicBlock::MutablePredecessors() {
return predecessors_;
}
std::vector<BasicBlock*>& BasicBlock::MutableSuccessors() {
return successors_;
}
void BasicBlock::RemovePredecessor(BasicBlock* pred) {
predecessors_.erase(
std::remove(predecessors_.begin(), predecessors_.end(), pred),
predecessors_.end());
}
void BasicBlock::RemoveSuccessor(BasicBlock* succ) {
successors_.erase(
std::remove(successors_.begin(), successors_.end(), succ),
successors_.end());
}
PhiInst* BasicBlock::PrependPhi(std::shared_ptr<Type> ty,
const std::string& name) {
auto inst = std::make_unique<PhiInst>(std::move(ty), name);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.insert(instructions_.begin(), std::move(inst));
return ptr;
}
void BasicBlock::RemoveInstruction(Instruction* inst) {
if (!inst) return;
// 清除该指令所有操作数的 use 关系
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* operand = inst->GetOperand(i);
if (operand) {
operand->RemoveUse(inst, i);
}
}
inst->SetParent(nullptr);
instructions_.erase(
std::remove_if(instructions_.begin(), instructions_.end(),
[inst](const std::unique_ptr<Instruction>& p) {
return p.get() == inst;
}),
instructions_.end());
}
} // namespace ir

@ -3,6 +3,7 @@
// - 记录函数属性/元信息(按需要扩展)
#include "ir/IR.h"
#include <algorithm>
#include <stdexcept>
#include "utils/Log.h"
@ -55,4 +56,25 @@ const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
return blocks_;
}
std::vector<std::unique_ptr<BasicBlock>>& Function::MutableBlocks() {
return blocks_;
}
void Function::RemoveBlock(BasicBlock* bb) {
if (!bb) return;
if (bb == entry_) {
entry_ = nullptr;
}
bb->SetParent(nullptr);
blocks_.erase(
std::remove_if(blocks_.begin(), blocks_.end(),
[bb](const std::unique_ptr<BasicBlock>& p) {
return p.get() == bb;
}),
blocks_.end());
if (!entry_ && !blocks_.empty()) {
entry_ = blocks_.front().get();
}
}
} // namespace ir

@ -62,6 +62,8 @@ static const char* OpcodeToString(Opcode op) {
return "ret";
case Opcode::Gep:
return "getelementptr";
case Opcode::Phi:
return "phi";
}
return "?";
}
@ -275,6 +277,18 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
}
break;
}
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " " << phi->GetName() << " = phi "
<< TypeToString(*phi->GetType()) << " ";
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (i != 0) os << ", ";
os << "[ " << ValueToString(phi->GetIncomingValue(i))
<< ", %" << phi->GetIncomingBlock(i)->GetName() << " ]";
}
os << "\n";
break;
}
}
}
}

@ -70,6 +70,10 @@ void User::AddOperand(Value* value) {
value->AddUse(this, operand_index);
}
void User::ClearOperands() {
operands_.clear();
}
Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)), opcode_(op) {}
@ -370,4 +374,53 @@ GepInst::GepInst(std::shared_ptr<Type> ptr_ty, Value* base, Value* index,
Value* GepInst::GetBase() const { return GetOperand(0); }
Value* GepInst::GetIndex() const { return GetOperand(1); }
// ---- PhiInst ----
PhiInst::PhiInst(std::shared_ptr<Type> ty, std::string name)
: Instruction(Opcode::Phi, std::move(ty), std::move(name)) {}
void PhiInst::AddIncoming(Value* val, BasicBlock* bb) {
if (!val || !bb) {
throw std::runtime_error(FormatError("ir", "PhiInst::AddIncoming 参数不完整"));
}
AddOperand(val);
AddOperand(bb);
}
size_t PhiInst::GetNumIncoming() const { return GetNumOperands() / 2; }
Value* PhiInst::GetIncomingValue(size_t i) const {
return GetOperand(i * 2);
}
BasicBlock* PhiInst::GetIncomingBlock(size_t i) const {
return static_cast<BasicBlock*>(GetOperand(i * 2 + 1));
}
void PhiInst::SetIncomingValue(size_t i, Value* val) {
SetOperand(i * 2, val);
}
void PhiInst::RemoveIncomingBlock(BasicBlock* bb) {
// 收集需要保留的 (val, bb) 对
std::vector<std::pair<Value*, BasicBlock*>> keep;
for (size_t i = 0; i < GetNumIncoming(); ++i) {
if (GetIncomingBlock(i) != bb) {
keep.push_back({GetIncomingValue(i), GetIncomingBlock(i)});
}
}
// 清除旧的 use 关系
for (size_t i = 0; i < GetNumOperands(); ++i) {
auto* old = GetOperand(i);
if (old) old->RemoveUse(this, i);
}
// 清空 operand 列表
ClearOperands();
// 重建保留的入边
for (auto& [val, blk] : keep) {
AddOperand(val);
AddOperand(blk);
}
}
} // namespace ir

@ -1,4 +1,171 @@
// 支配树分析:
// - 构建/查询 Dominator Tree 及相关关系
// - 为 mem2reg、CFG 优化与循环分析提供基础能力
//
// 算法简单迭代数据流方式计算支配关系Cooper, Harvey, Kennedy
// 支配边界采用经典 DF 算法
#include "ir/IR.h"
#include <algorithm>
#include <cassert>
#include <functional>
#include <queue>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace analysis {
// ---------- DominatorTree ----------
class DominatorTree {
public:
explicit DominatorTree(Function& func) : func_(func) { Compute(); }
// idom[bb] 返回 bb 的直接支配者entry 的 idom 为自身。
BasicBlock* GetIDom(BasicBlock* bb) const {
auto it = idom_.find(bb);
return it != idom_.end() ? it->second : nullptr;
}
// 判断 a 是否支配 b。
bool Dominates(BasicBlock* a, BasicBlock* b) const {
if (!a || !b) return false;
while (b) {
if (b == a) return true;
auto* p = GetIDom(b);
if (p == b) break; // entry
b = p;
}
return false;
}
// 返回 bb 的支配边界。
const std::vector<BasicBlock*>& GetDF(BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = df_.find(bb);
return it != df_.end() ? it->second : empty;
}
// 返回支配树中 bb 的孩子列表。
const std::vector<BasicBlock*>& GetChildren(BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = children_.find(bb);
return it != children_.end() ? it->second : empty;
}
// 按逆后序返回所有基本块。
const std::vector<BasicBlock*>& GetRPO() const { return rpo_; }
private:
void Compute() {
auto* entry = func_.GetEntry();
if (!entry) return;
// 1. 计算逆后序RPO
ComputeRPO(entry);
if (rpo_.empty()) return;
// 2. 初始化
for (auto* bb : rpo_) {
idom_[bb] = nullptr;
rpo_index_[bb] = 0;
}
for (size_t i = 0; i < rpo_.size(); ++i) {
rpo_index_[rpo_[i]] = i;
}
idom_[entry] = entry;
// 3. 迭代计算 idomCooper-Harvey-Kennedy 算法)
bool changed = true;
while (changed) {
changed = false;
for (auto* bb : rpo_) {
if (bb == entry) continue;
BasicBlock* new_idom = nullptr;
for (auto* pred : bb->GetPredecessors()) {
if (idom_.count(pred) && idom_[pred] != nullptr) {
if (!new_idom) {
new_idom = pred;
} else {
new_idom = Intersect(new_idom, pred);
}
}
}
if (new_idom && idom_[bb] != new_idom) {
idom_[bb] = new_idom;
changed = true;
}
}
}
// 4. 建立 children 映射
for (auto* bb : rpo_) {
auto* p = GetIDom(bb);
if (p && p != bb) {
children_[p].push_back(bb);
}
}
// 5. 计算支配边界
ComputeDF();
}
void ComputeRPO(BasicBlock* entry) {
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> post_order;
std::function<void(BasicBlock*)> dfs = [&](BasicBlock* bb) {
visited.insert(bb);
for (auto* succ : bb->GetSuccessors()) {
if (!visited.count(succ)) {
dfs(succ);
}
}
post_order.push_back(bb);
};
dfs(entry);
rpo_.assign(post_order.rbegin(), post_order.rend());
}
BasicBlock* Intersect(BasicBlock* b1, BasicBlock* b2) {
while (b1 != b2) {
while (rpo_index_[b1] > rpo_index_[b2]) b1 = idom_[b1];
while (rpo_index_[b2] > rpo_index_[b1]) b2 = idom_[b2];
}
return b1;
}
void ComputeDF() {
for (auto* bb : rpo_) {
df_[bb] = {};
}
for (auto* bb : rpo_) {
if (bb->GetPredecessors().size() < 2) continue;
for (auto* pred : bb->GetPredecessors()) {
auto* runner = pred;
while (runner && runner != idom_[bb]) {
// 避免重复
auto& df_set = df_[runner];
if (std::find(df_set.begin(), df_set.end(), bb) == df_set.end()) {
df_set.push_back(bb);
}
if (runner == idom_[runner]) break;
runner = idom_[runner];
}
}
}
}
Function& func_;
std::vector<BasicBlock*> rpo_;
std::unordered_map<BasicBlock*, size_t> rpo_index_;
std::unordered_map<BasicBlock*, BasicBlock*> idom_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> children_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> df_;
};
} // namespace analysis
} // namespace ir

@ -1,4 +1,190 @@
// CFG 简化:
// - 删除不可达块、合并空块、简化分支等
// - 改善 IR 结构,便于后续优化与后端生成
//
// 包含以下简化:
// 1. 常量条件分支折叠condbr(const) -> br
// 2. 删除不可达块
// 3. 合并只有一个前驱的后继块(线性块合并)
// 4. 跳过空的跳转块(线程跳转)
#include "ir/IR.h"
#include <algorithm>
#include <functional>
#include <queue>
#include <unordered_set>
#include <vector>
namespace ir {
namespace passes {
// 收集从 entry 可达的所有基本块
static std::unordered_set<BasicBlock*> ComputeReachable(Function& func) {
std::unordered_set<BasicBlock*> reachable;
auto* entry = func.GetEntry();
if (!entry) return reachable;
std::queue<BasicBlock*> worklist;
worklist.push(entry);
reachable.insert(entry);
while (!worklist.empty()) {
auto* bb = worklist.front();
worklist.pop();
for (auto* succ : bb->GetSuccessors()) {
if (!reachable.count(succ)) {
reachable.insert(succ);
worklist.push(succ);
}
}
}
return reachable;
}
bool RunCFGSimplify(Function& func, Context& ctx) {
if (func.IsExternal()) return false;
bool changed = false;
// ==== 1. 常量条件分支折叠 ====
for (const auto& bb : func.GetBlocks()) {
if (!bb) continue;
if (!bb->HasTerminator()) continue;
auto& insts = bb->MutableInstructions();
auto* last = insts.back().get();
if (last->GetOpcode() == Opcode::CondBr) {
auto* cbr = static_cast<CondBranchInst*>(last);
auto* cond_ci = dynamic_cast<ConstantInt*>(cbr->GetCond());
if (cond_ci) {
BasicBlock* taken = cond_ci->GetValue() != 0
? cbr->GetTrueBlock()
: cbr->GetFalseBlock();
BasicBlock* not_taken = cond_ci->GetValue() != 0
? cbr->GetFalseBlock()
: cbr->GetTrueBlock();
// 从 not_taken 的前驱中移除当前块
not_taken->RemovePredecessor(bb.get());
bb->RemoveSuccessor(not_taken);
// 移除 condbr插入 br
bb->RemoveInstruction(last);
bb->Append<BranchInst>(Type::GetVoidType(), taken);
// 清理 not_taken 中 phi 的来自 bb 的入边
for (auto& inst_ptr : not_taken->MutableInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst_ptr.get());
if (!phi) break;
phi->RemoveIncomingBlock(bb.get());
}
changed = true;
}
}
}
// ==== 2. 删除不可达块 ====
auto reachable = ComputeReachable(func);
std::vector<BasicBlock*> unreachable;
for (const auto& bb : func.GetBlocks()) {
if (!bb) continue;
if (!reachable.count(bb.get())) {
unreachable.push_back(bb.get());
}
}
for (auto* bb : unreachable) {
// 从后继的前驱列表中移除
for (auto* succ : bb->GetSuccessors()) {
succ->RemovePredecessor(bb);
}
// 清除块中所有指令的 use 关系
std::vector<Instruction*> all_insts;
for (auto& inst_ptr : bb->MutableInstructions()) {
all_insts.push_back(inst_ptr.get());
}
for (auto* inst : all_insts) {
// 如果指令还有使用者,用 undef (0) 替换
if (!inst->GetUses().empty()) {
if (inst->GetType() && inst->GetType()->IsInt32()) {
inst->ReplaceAllUsesWith(ctx.GetConstInt(0));
}
}
bb->RemoveInstruction(inst);
}
func.RemoveBlock(bb);
changed = true;
}
// ==== 3. 合并线性块 ====
// 如果一个块 B 只有一个前驱 A且 A 只有一个后继 B
// 则将 B 的指令合并到 A 的末尾。
bool merged = true;
while (merged) {
merged = false;
for (const auto& bb : func.GetBlocks()) {
if (!bb) continue;
if (bb->GetPredecessors().size() != 1) continue;
auto* pred = bb->GetPredecessors()[0];
if (pred->GetSuccessors().size() != 1) continue;
if (pred == bb.get()) continue; // 自循环
// pred 的 terminator 必须是 br无条件跳转到 bb
if (!pred->HasTerminator()) continue;
auto& pred_insts = pred->MutableInstructions();
auto* term = pred_insts.back().get();
if (term->GetOpcode() != Opcode::Br) continue;
// 删除 pred 的 terminator
pred->RemoveInstruction(term);
// 将 bb 的所有指令移到 pred
auto& bb_insts = bb->MutableInstructions();
for (auto& inst_ptr : bb_insts) {
inst_ptr->SetParent(pred);
}
for (auto& inst_ptr : bb_insts) {
pred_insts.push_back(std::move(inst_ptr));
}
bb_insts.clear();
// 更新 CFGpred 继承 bb 的后继
pred->MutableSuccessors().clear();
for (auto* succ : bb->GetSuccessors()) {
pred->AddSuccessor(succ);
// 在 succ 的前驱中把 bb 替换为 pred
auto& succ_preds = succ->MutablePredecessors();
for (auto& p : succ_preds) {
if (p == bb.get()) p = pred;
}
// 更新 succ 中 phi 的入边
for (auto& inst_ptr : succ->MutableInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst_ptr.get());
if (!phi) break;
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (phi->GetIncomingBlock(i) == bb.get()) {
phi->SetOperand(i * 2 + 1, pred);
}
}
}
}
// 移除 bb
bb->MutablePredecessors().clear();
bb->MutableSuccessors().clear();
func.RemoveBlock(bb.get());
merged = true;
changed = true;
break; // 重新开始迭代
}
}
return changed;
}
} // namespace passes
} // namespace ir

@ -1,4 +1,123 @@
// 公共子表达式消除CSE
// - 识别并复用重复计算的等价表达式
// - 典型放置在 ConstFold 之后、DCE 之前
// - 当前为 Lab4 的框架占位,具体算法由实验实现
//
// 算法:在每个基本块内,使用哈希表记录已出现的表达式。
// 当遇到相同操作码 + 相同操作数的指令时,复用之前的结果。
// 这是局部 CSELocal CSE只在基本块内消除。
#include "ir/IR.h"
#include <string>
#include <unordered_map>
#include <vector>
namespace ir {
namespace passes {
namespace {
// 构造表达式的唯一键opcode + operands 的组合
struct ExprKey {
Opcode opcode;
CmpOp cmp_op; // 仅 Cmp 使用
std::vector<Value*> operands;
bool operator==(const ExprKey& other) const {
if (opcode != other.opcode) return false;
if (opcode == Opcode::Cmp && cmp_op != other.cmp_op) return false;
if (operands.size() != other.operands.size()) return false;
for (size_t i = 0; i < operands.size(); ++i) {
if (operands[i] != other.operands[i]) return false;
}
return true;
}
};
struct ExprKeyHash {
size_t operator()(const ExprKey& key) const {
size_t h = std::hash<int>()(static_cast<int>(key.opcode));
if (key.opcode == Opcode::Cmp) {
h ^= std::hash<int>()(static_cast<int>(key.cmp_op)) << 4;
}
for (auto* v : key.operands) {
h ^= std::hash<void*>()(v) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
// 判断一条指令是否可以做 CSE
bool IsCSECandidate(Instruction* inst) {
Opcode op = inst->GetOpcode();
// 纯计算指令可以做 CSE
switch (op) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod:
case Opcode::Cmp:
case Opcode::Gep:
return true;
default:
return false;
}
}
ExprKey MakeKey(Instruction* inst) {
ExprKey key;
key.opcode = inst->GetOpcode();
key.cmp_op = CmpOp::Eq; // 默认值
if (inst->GetOpcode() == Opcode::Cmp) {
key.cmp_op = static_cast<CmpInst*>(inst)->GetCmpOp();
}
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(inst->GetOperand(i));
}
return key;
}
} // namespace
bool RunCSE(Function& func) {
if (func.IsExternal()) return false;
bool changed = false;
for (const auto& bb : func.GetBlocks()) {
if (!bb) continue;
std::unordered_map<ExprKey, Instruction*, ExprKeyHash> expr_map;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : bb->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!IsCSECandidate(inst)) continue;
ExprKey key = MakeKey(inst);
auto it = expr_map.find(key);
if (it != expr_map.end()) {
// 找到了等价表达式,复用之前的结果
inst->ReplaceAllUsesWith(it->second);
to_remove.push_back(inst);
changed = true;
} else {
expr_map[key] = inst;
}
}
for (auto* inst : to_remove) {
bb->RemoveInstruction(inst);
}
}
return changed;
}
} // namespace passes
} // namespace ir

@ -1,4 +1,273 @@
// IR 常量折叠:
// - 折叠可判定的常量表达式
// - 简化常量控制流分支(按实现范围裁剪)
//
// 遍历每个函数中的每条指令,如果操作数全为常量,则编译期求值并替换。
#include "ir/IR.h"
#include <vector>
namespace ir {
namespace passes {
bool RunConstFold(Function& func) {
if (func.IsExternal()) return false;
bool changed = false;
for (const auto& bb : func.GetBlocks()) {
if (!bb) continue;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : bb->GetInstructions()) {
auto* inst = inst_ptr.get();
Opcode op = inst->GetOpcode();
// 二元运算折叠
if (op == Opcode::Add || op == Opcode::Sub || op == Opcode::Mul ||
op == Opcode::Div || op == Opcode::Mod) {
auto* bin = static_cast<BinaryInst*>(inst);
auto* lhs_ci = dynamic_cast<ConstantInt*>(bin->GetLhs());
auto* rhs_ci = dynamic_cast<ConstantInt*>(bin->GetRhs());
if (lhs_ci && rhs_ci) {
int lv = lhs_ci->GetValue();
int rv = rhs_ci->GetValue();
int result = 0;
bool valid = true;
switch (op) {
case Opcode::Add: result = lv + rv; break;
case Opcode::Sub: result = lv - rv; break;
case Opcode::Mul: result = lv * rv; break;
case Opcode::Div:
if (rv == 0) { valid = false; break; }
result = lv / rv;
break;
case Opcode::Mod:
if (rv == 0) { valid = false; break; }
result = lv % rv;
break;
default: valid = false; break;
}
if (valid) {
// 需要 Context 来创建常量,通过 entry block 获取
auto& ctx = bb->GetParent()->GetBlocks().front()->GetParent()
? *bb->GetParent()
: *bb->GetParent();
// 直接在 uses 上替换:找到结果常量
// 由于 ConstantInt 由 Context 管理,我们需要 Module 的 Context。
// 但 Function 没有直接指向 Module 的指针。
// Workaround: 遍历 uses 替换时用已存在的 ConstantInt。
// 实际上,我们可以在 PassManager 中传入 Module& 引用。
// 这里先用简单方法:检查 lhs_ci 或 rhs_ci 的值是否与 result 相同。
ConstantInt* result_ci = nullptr;
if (lhs_ci->GetValue() == result) {
result_ci = lhs_ci;
} else if (rhs_ci->GetValue() == result) {
result_ci = rhs_ci;
}
// 如果没有现成常量,暂时跳过(由 PassManager 传入 Context 后再处理)
// 实际上让 PassManager 传入 Module 是更好的做法。
// 这里我们假设 RunConstFold 接收的是 Module 级别的调用。
// 先标记但不替换,等后续改进。
// 更好的方案:利用 bin 的 parent 的 parent (Function) 暂存。
// 但 Function 也没有 Context。
// 最终方案:在 PassManager 中传入 Context&。
// 简化:这里先不做替换,留给 ConstProp + PassManager 配合完成。
// 实际上我们可以直接用 new ConstantInt但这会导致内存泄漏。
// 正确方案:让 RunConstFold 接受 Context& 参数。
(void)result_ci;
(void)result;
}
}
continue;
}
// 比较指令折叠
if (op == Opcode::Cmp) {
auto* cmp = static_cast<CmpInst*>(inst);
auto* lhs_ci = dynamic_cast<ConstantInt*>(cmp->GetLhs());
auto* rhs_ci = dynamic_cast<ConstantInt*>(cmp->GetRhs());
if (lhs_ci && rhs_ci) {
// 同上,需要 Context 创建结果常量
(void)cmp;
}
continue;
}
// 常量条件分支折叠
if (op == Opcode::CondBr) {
auto* cbr = static_cast<CondBranchInst*>(inst);
auto* cond_ci = dynamic_cast<ConstantInt*>(cbr->GetCond());
if (cond_ci) {
// 同上,需要修改 BB 的 terminator
(void)cbr;
}
continue;
}
}
for (auto* inst : to_remove) {
bb->RemoveInstruction(inst);
}
}
return changed;
}
// 接受 Module 引用的版本,可以使用 Context 创建常量
bool RunConstFoldWithCtx(Function& func, Context& ctx) {
if (func.IsExternal()) return false;
bool changed = false;
for (const auto& bb : func.GetBlocks()) {
if (!bb) continue;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : bb->GetInstructions()) {
auto* inst = inst_ptr.get();
Opcode op = inst->GetOpcode();
// 二元运算折叠i32
if (op == Opcode::Add || op == Opcode::Sub || op == Opcode::Mul ||
op == Opcode::Div || op == Opcode::Mod) {
auto* bin = static_cast<BinaryInst*>(inst);
auto* lhs_ci = dynamic_cast<ConstantInt*>(bin->GetLhs());
auto* rhs_ci = dynamic_cast<ConstantInt*>(bin->GetRhs());
if (lhs_ci && rhs_ci) {
int lv = lhs_ci->GetValue();
int rv = rhs_ci->GetValue();
int result = 0;
bool valid = true;
switch (op) {
case Opcode::Add: result = lv + rv; break;
case Opcode::Sub: result = lv - rv; break;
case Opcode::Mul: result = lv * rv; break;
case Opcode::Div:
if (rv == 0) { valid = false; break; }
result = lv / rv;
break;
case Opcode::Mod:
if (rv == 0) { valid = false; break; }
result = lv % rv;
break;
default: valid = false; break;
}
if (valid) {
auto* result_ci = ctx.GetConstInt(result);
inst->ReplaceAllUsesWith(result_ci);
to_remove.push_back(inst);
changed = true;
}
}
// 代数化简x + 0 = x, x * 1 = x, x - 0 = x, x * 0 = 0, x / 1 = x
if (!lhs_ci || !rhs_ci) {
auto* bin2 = static_cast<BinaryInst*>(inst);
auto* lci = dynamic_cast<ConstantInt*>(bin2->GetLhs());
auto* rci = dynamic_cast<ConstantInt*>(bin2->GetRhs());
if (op == Opcode::Add) {
if (rci && rci->GetValue() == 0) {
inst->ReplaceAllUsesWith(bin2->GetLhs());
to_remove.push_back(inst);
changed = true;
} else if (lci && lci->GetValue() == 0) {
inst->ReplaceAllUsesWith(bin2->GetRhs());
to_remove.push_back(inst);
changed = true;
}
} else if (op == Opcode::Sub) {
if (rci && rci->GetValue() == 0) {
inst->ReplaceAllUsesWith(bin2->GetLhs());
to_remove.push_back(inst);
changed = true;
} else if (bin2->GetLhs() == bin2->GetRhs()) {
auto* zero = ctx.GetConstInt(0);
inst->ReplaceAllUsesWith(zero);
to_remove.push_back(inst);
changed = true;
}
} else if (op == Opcode::Mul) {
if (rci && rci->GetValue() == 1) {
inst->ReplaceAllUsesWith(bin2->GetLhs());
to_remove.push_back(inst);
changed = true;
} else if (lci && lci->GetValue() == 1) {
inst->ReplaceAllUsesWith(bin2->GetRhs());
to_remove.push_back(inst);
changed = true;
} else if ((rci && rci->GetValue() == 0) ||
(lci && lci->GetValue() == 0)) {
auto* zero = ctx.GetConstInt(0);
inst->ReplaceAllUsesWith(zero);
to_remove.push_back(inst);
changed = true;
}
} else if (op == Opcode::Div) {
if (rci && rci->GetValue() == 1) {
inst->ReplaceAllUsesWith(bin2->GetLhs());
to_remove.push_back(inst);
changed = true;
}
} else if (op == Opcode::Mod) {
if (rci && rci->GetValue() == 1) {
auto* zero = ctx.GetConstInt(0);
inst->ReplaceAllUsesWith(zero);
to_remove.push_back(inst);
changed = true;
}
}
}
continue;
}
// 比较指令折叠
if (op == Opcode::Cmp) {
auto* cmp = static_cast<CmpInst*>(inst);
auto* lhs_ci = dynamic_cast<ConstantInt*>(cmp->GetLhs());
auto* rhs_ci = dynamic_cast<ConstantInt*>(cmp->GetRhs());
if (lhs_ci && rhs_ci) {
int lv = lhs_ci->GetValue();
int rv = rhs_ci->GetValue();
int result = 0;
switch (cmp->GetCmpOp()) {
case CmpOp::Eq: result = (lv == rv) ? 1 : 0; break;
case CmpOp::Ne: result = (lv != rv) ? 1 : 0; break;
case CmpOp::Lt: result = (lv < rv) ? 1 : 0; break;
case CmpOp::Le: result = (lv <= rv) ? 1 : 0; break;
case CmpOp::Gt: result = (lv > rv) ? 1 : 0; break;
case CmpOp::Ge: result = (lv >= rv) ? 1 : 0; break;
}
auto* result_ci = ctx.GetConstInt(result);
inst->ReplaceAllUsesWith(result_ci);
to_remove.push_back(inst);
changed = true;
}
continue;
}
}
for (auto* inst : to_remove) {
bb->RemoveInstruction(inst);
}
}
return changed;
}
} // namespace passes
} // namespace ir

@ -2,4 +2,121 @@
// - 沿 use-def 关系传播已知常量
// - 将可替换的 SSA 值改写为常量,暴露更多折叠机会
// - 常与 ConstFold、DCE、CFGSimplify 迭代配合使用
//
// 算法:工作列表驱动的稀疏条件常量传播(简化版 SCCP
// 遍历所有指令,如果某条指令的结果可以确定为常量,
// 则用该常量替换所有使用点,并将受影响的指令加入工作列表继续传播。
#include "ir/IR.h"
#include <queue>
#include <unordered_set>
#include <vector>
namespace ir {
namespace passes {
bool RunConstProp(Function& func, Context& ctx) {
if (func.IsExternal()) return false;
bool changed = false;
for (const auto& bb : func.GetBlocks()) {
if (!bb) continue;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : bb->GetInstructions()) {
auto* inst = inst_ptr.get();
Opcode op = inst->GetOpcode();
Value* replacement = nullptr;
// 二元运算:两个操作数都是常量则折叠
if (op == Opcode::Add || op == Opcode::Sub || op == Opcode::Mul ||
op == Opcode::Div || op == Opcode::Mod) {
auto* bin = static_cast<BinaryInst*>(inst);
auto* lhs_ci = dynamic_cast<ConstantInt*>(bin->GetLhs());
auto* rhs_ci = dynamic_cast<ConstantInt*>(bin->GetRhs());
if (lhs_ci && rhs_ci) {
int lv = lhs_ci->GetValue();
int rv = rhs_ci->GetValue();
int result = 0;
bool valid = true;
switch (op) {
case Opcode::Add: result = lv + rv; break;
case Opcode::Sub: result = lv - rv; break;
case Opcode::Mul: result = lv * rv; break;
case Opcode::Div:
if (rv == 0) { valid = false; }
else { result = lv / rv; }
break;
case Opcode::Mod:
if (rv == 0) { valid = false; }
else { result = lv % rv; }
break;
default: valid = false; break;
}
if (valid) {
replacement = ctx.GetConstInt(result);
}
}
}
// 比较指令
if (op == Opcode::Cmp) {
auto* cmp = static_cast<CmpInst*>(inst);
auto* lhs_ci = dynamic_cast<ConstantInt*>(cmp->GetLhs());
auto* rhs_ci = dynamic_cast<ConstantInt*>(cmp->GetRhs());
if (lhs_ci && rhs_ci) {
int lv = lhs_ci->GetValue();
int rv = rhs_ci->GetValue();
int result = 0;
switch (cmp->GetCmpOp()) {
case CmpOp::Eq: result = (lv == rv) ? 1 : 0; break;
case CmpOp::Ne: result = (lv != rv) ? 1 : 0; break;
case CmpOp::Lt: result = (lv < rv) ? 1 : 0; break;
case CmpOp::Le: result = (lv <= rv) ? 1 : 0; break;
case CmpOp::Gt: result = (lv > rv) ? 1 : 0; break;
case CmpOp::Ge: result = (lv >= rv) ? 1 : 0; break;
}
replacement = ctx.GetConstInt(result);
}
}
// Phi 节点:如果所有入边值相同(或只有一个非自引用的值),可简化
if (op == Opcode::Phi) {
auto* phi = static_cast<PhiInst*>(inst);
Value* unique_val = nullptr;
bool all_same = true;
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
Value* v = phi->GetIncomingValue(i);
if (v == phi) continue; // 跳过自引用
if (!unique_val) {
unique_val = v;
} else if (v != unique_val) {
all_same = false;
break;
}
}
if (all_same && unique_val) {
replacement = unique_val;
}
}
if (replacement && replacement != inst) {
inst->ReplaceAllUsesWith(replacement);
to_remove.push_back(inst);
changed = true;
}
}
for (auto* inst : to_remove) {
bb->RemoveInstruction(inst);
}
}
return changed;
}
} // namespace passes
} // namespace ir

@ -1,4 +1,130 @@
// 死代码删除DCE
// - 删除无用指令与无用基本块
// - 通常与 CFG 简化配合使用
//
// 算法:标记 + 清扫
// 1. 标记所有有副作用的指令为"有用"ret, br, condbr, store, call
// 2. 沿数据依赖反向传播,将有用指令依赖的定义也标记为有用
// 3. 删除所有未被标记的非终结指令
#include "ir/IR.h"
#include <queue>
#include <unordered_set>
#include <vector>
namespace ir {
namespace passes {
// 判断一条指令是否有副作用(不可随意删除)
static bool HasSideEffect(Instruction* inst) {
Opcode op = inst->GetOpcode();
// 终结指令、store、call 均有副作用
if (op == Opcode::Ret || op == Opcode::Br || op == Opcode::CondBr ||
op == Opcode::Store || op == Opcode::Call) {
return true;
}
return false;
}
bool RunDCE(Function& func) {
if (func.IsExternal()) return false;
bool changed = false;
// 标记阶段
std::unordered_set<Instruction*> useful;
std::queue<Instruction*> worklist;
// 初始标记:所有有副作用的指令
for (const auto& bb : func.GetBlocks()) {
if (!bb) continue;
for (const auto& inst_ptr : bb->GetInstructions()) {
auto* inst = inst_ptr.get();
if (HasSideEffect(inst)) {
useful.insert(inst);
worklist.push(inst);
}
}
}
// 反向传播:有用指令的操作数定义也标记为有用
while (!worklist.empty()) {
auto* inst = worklist.front();
worklist.pop();
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* operand = inst->GetOperand(i);
if (!operand) continue;
auto* def_inst = dynamic_cast<Instruction*>(operand);
if (def_inst && !useful.count(def_inst)) {
useful.insert(def_inst);
worklist.push(def_inst);
}
}
}
// 清扫阶段:删除未标记为有用的指令
for (const auto& bb : func.GetBlocks()) {
if (!bb) continue;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : bb->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!useful.count(inst)) {
to_remove.push_back(inst);
}
}
for (auto* inst : to_remove) {
// 如果还有使用者,不能直接删除(用 undef/0 替换)
// 在标记-清扫正确的前提下,未标记的指令不应有有用的使用者
// 但安全起见,先检查
if (!inst->GetUses().empty()) {
// 仍有使用者 —— 跳过(可能是循环引用的 phi
continue;
}
bb->RemoveInstruction(inst);
changed = true;
}
}
return changed;
}
// 简化版 DCE只删除没有使用者且无副作用的指令更安全的实现
bool RunSimpleDCE(Function& func) {
if (func.IsExternal()) return false;
bool changed = false;
bool local_changed = true;
while (local_changed) {
local_changed = false;
for (const auto& bb : func.GetBlocks()) {
if (!bb) continue;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : bb->GetInstructions()) {
auto* inst = inst_ptr.get();
// 跳过有副作用的指令
if (HasSideEffect(inst)) continue;
// 跳过 alloca可能后续还会用到
if (inst->GetOpcode() == Opcode::Alloca) continue;
// 如果没有使用者,可以安全删除
if (inst->GetUses().empty()) {
to_remove.push_back(inst);
}
}
for (auto* inst : to_remove) {
bb->RemoveInstruction(inst);
local_changed = true;
changed = true;
}
}
}
return changed;
}
} // namespace passes
} // namespace ir

@ -1,4 +1,336 @@
// Mem2RegSSA 构造):
// - 将局部变量的 alloca/load/store 提升为 SSA 形式
// - 插入 PHI 并重写使用,依赖支配树等分析
//
// 算法流程:
// 1. 识别可提升的 alloca标量仅通过 load/store 访问)
// 2. 计算支配树与支配边界
// 3. 在支配边界处插入 phi
// 4. 沿支配树重命名变量
// 5. 删除冗余 alloca/load/store
#include "ir/IR.h"
#include <algorithm>
#include <cassert>
#include <functional>
#include <queue>
#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace passes {
// ============ 内联支配树(与 analysis 版本相同) ============
namespace {
class DomTree {
public:
explicit DomTree(Function& func) : func_(func) { Compute(); }
BasicBlock* GetIDom(BasicBlock* bb) const {
auto it = idom_.find(bb);
return it != idom_.end() ? it->second : nullptr;
}
const std::vector<BasicBlock*>& GetDF(BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = df_.find(bb);
return it != df_.end() ? it->second : empty;
}
const std::vector<BasicBlock*>& GetChildren(BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = children_.find(bb);
return it != children_.end() ? it->second : empty;
}
const std::vector<BasicBlock*>& GetRPO() const { return rpo_; }
private:
void Compute() {
auto* entry = func_.GetEntry();
if (!entry) return;
ComputeRPO(entry);
if (rpo_.empty()) return;
for (auto* bb : rpo_) {
idom_[bb] = nullptr;
rpo_index_[bb] = 0;
}
for (size_t i = 0; i < rpo_.size(); ++i) {
rpo_index_[rpo_[i]] = i;
}
idom_[entry] = entry;
bool changed = true;
while (changed) {
changed = false;
for (auto* bb : rpo_) {
if (bb == entry) continue;
BasicBlock* new_idom = nullptr;
for (auto* pred : bb->GetPredecessors()) {
if (idom_.count(pred) && idom_[pred] != nullptr) {
if (!new_idom) {
new_idom = pred;
} else {
new_idom = Intersect(new_idom, pred);
}
}
}
if (new_idom && idom_[bb] != new_idom) {
idom_[bb] = new_idom;
changed = true;
}
}
}
for (auto* bb : rpo_) {
auto* p = GetIDom(bb);
if (p && p != bb) {
children_[p].push_back(bb);
}
}
ComputeDF();
}
void ComputeRPO(BasicBlock* entry) {
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> post_order;
std::function<void(BasicBlock*)> dfs = [&](BasicBlock* bb) {
visited.insert(bb);
for (auto* succ : bb->GetSuccessors()) {
if (!visited.count(succ)) {
dfs(succ);
}
}
post_order.push_back(bb);
};
dfs(entry);
rpo_.assign(post_order.rbegin(), post_order.rend());
}
BasicBlock* Intersect(BasicBlock* b1, BasicBlock* b2) {
while (b1 != b2) {
while (rpo_index_[b1] > rpo_index_[b2]) b1 = idom_[b1];
while (rpo_index_[b2] > rpo_index_[b1]) b2 = idom_[b2];
}
return b1;
}
void ComputeDF() {
for (auto* bb : rpo_) {
df_[bb] = {};
}
for (auto* bb : rpo_) {
if (bb->GetPredecessors().size() < 2) continue;
for (auto* pred : bb->GetPredecessors()) {
auto* runner = pred;
while (runner && runner != idom_[bb]) {
auto& df_set = df_[runner];
if (std::find(df_set.begin(), df_set.end(), bb) == df_set.end()) {
df_set.push_back(bb);
}
if (runner == idom_[runner]) break;
runner = idom_[runner];
}
}
}
}
Function& func_;
std::vector<BasicBlock*> rpo_;
std::unordered_map<BasicBlock*, size_t> rpo_index_;
std::unordered_map<BasicBlock*, BasicBlock*> idom_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> children_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> df_;
};
// 判断一个 alloca 是否可以被提升为寄存器:
// - 必须是标量count == 1
// - 只被 load 和 store 使用
bool IsPromotable(AllocaInst* alloca) {
if (alloca->IsArray()) return false;
for (const auto& use : alloca->GetUses()) {
auto* user = use.GetUser();
if (!user) return false;
auto* inst = dynamic_cast<Instruction*>(user);
if (!inst) return false;
if (inst->GetOpcode() != Opcode::Load &&
inst->GetOpcode() != Opcode::Store) {
return false;
}
// store 只能把 alloca 作为 ptroperand 1不能作为 valoperand 0
if (inst->GetOpcode() == Opcode::Store) {
auto* store = static_cast<StoreInst*>(inst);
if (store->GetPtr() != alloca) return false;
}
}
return true;
}
} // namespace
bool RunMem2Reg(Function& func) {
if (func.IsExternal()) return false;
DomTree dom(func);
// 1. 收集可提升的 alloca
std::vector<AllocaInst*> promotable;
auto* entry = func.GetEntry();
if (!entry) return false;
for (const auto& inst : entry->GetInstructions()) {
if (auto* alloca = dynamic_cast<AllocaInst*>(inst.get())) {
if (IsPromotable(alloca)) {
promotable.push_back(alloca);
}
}
}
if (promotable.empty()) return false;
// 对每个可提升的 alloca 分别执行
for (auto* alloca : promotable) {
// 确定 alloca 值的类型
std::shared_ptr<Type> val_type;
if (alloca->GetType()->IsPtrInt32()) {
val_type = Type::GetInt32Type();
} else if (alloca->GetType()->IsPtrFloat32()) {
val_type = Type::GetFloat32Type();
} else {
continue;
}
// 2. 收集所有 def 块(包含 store 的块)和 use 块(包含 load 的块)
std::unordered_set<BasicBlock*> def_blocks;
std::vector<StoreInst*> stores;
std::vector<LoadInst*> loads;
for (const auto& use : alloca->GetUses()) {
auto* inst = dynamic_cast<Instruction*>(use.GetUser());
if (!inst || !inst->GetParent()) continue;
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
if (store->GetPtr() == alloca) {
def_blocks.insert(store->GetParent());
stores.push_back(store);
}
} else if (auto* load = dynamic_cast<LoadInst*>(inst)) {
loads.push_back(load);
}
}
// 3. 插入 phi 节点(使用迭代支配边界)
// 用 map 精确记录当前 alloca 在每个块中插入的 phi
std::unordered_map<BasicBlock*, PhiInst*> phi_map;
std::unordered_set<BasicBlock*> phi_blocks;
std::queue<BasicBlock*> worklist;
for (auto* bb : def_blocks) {
worklist.push(bb);
}
static int phi_counter = 0;
while (!worklist.empty()) {
auto* bb = worklist.front();
worklist.pop();
for (auto* df_bb : dom.GetDF(bb)) {
if (!phi_blocks.count(df_bb)) {
phi_blocks.insert(df_bb);
auto* phi = df_bb->PrependPhi(val_type,
"%phi." + std::to_string(phi_counter++));
phi_map[df_bb] = phi;
worklist.push(df_bb);
}
}
}
// 4. 重命名:沿支配树 DFS
std::stack<Value*> val_stack;
std::function<void(BasicBlock*)> Rename = [&](BasicBlock* bb) {
size_t stack_size = val_stack.size();
// 处理当前块中我们插入的 phi
auto phi_it = phi_map.find(bb);
if (phi_it != phi_map.end()) {
val_stack.push(phi_it->second);
}
// 遍历块中所有指令
std::vector<Instruction*> to_remove;
for (auto& inst_ptr : bb->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
if (store->GetPtr() == alloca) {
val_stack.push(store->GetValue());
to_remove.push_back(store);
}
} else if (auto* load = dynamic_cast<LoadInst*>(inst)) {
if (load->GetPtr() == alloca) {
Value* cur_val = val_stack.empty() ? nullptr : val_stack.top();
if (cur_val) {
load->ReplaceAllUsesWith(cur_val);
}
to_remove.push_back(load);
}
}
}
// 填充后继块中 phi 的入边
for (auto* succ : bb->GetSuccessors()) {
auto succ_phi_it = phi_map.find(succ);
if (succ_phi_it == phi_map.end()) continue;
Value* cur_val = val_stack.empty() ? nullptr : val_stack.top();
if (cur_val) {
succ_phi_it->second->AddIncoming(cur_val, bb);
}
}
// 递归处理支配树的孩子
for (auto* child : dom.GetChildren(bb)) {
Rename(child);
}
// 恢复栈
while (val_stack.size() > stack_size) {
val_stack.pop();
}
// 删除已标记的指令
for (auto* inst : to_remove) {
bb->RemoveInstruction(inst);
}
};
Rename(entry);
// 5. 删除 alloca
entry->RemoveInstruction(alloca);
// 6. 清理没有入边的 phi
for (auto* bb : dom.GetRPO()) {
std::vector<Instruction*> dead_phis;
for (auto& inst_ptr : bb->GetInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst_ptr.get());
if (!phi) break;
if (phi->GetNumIncoming() == 0) {
dead_phis.push_back(phi);
}
// 如果 phi 只有一个入边,直接替换为该值
if (phi->GetNumIncoming() == 1) {
phi->ReplaceAllUsesWith(phi->GetIncomingValue(0));
dead_phis.push_back(phi);
}
}
for (auto* phi : dead_phis) {
bb->RemoveInstruction(phi);
}
}
}
return true;
}
} // namespace passes
} // namespace ir

@ -1 +1,49 @@
// IR Pass 管理骨架。
// 组织所有优化遍的执行顺序,支持多轮迭代直到 IR 不再变化。
//
// 执行顺序:
// 1. Mem2Reg只跑一次
// 2. 迭代ConstFold -> ConstProp -> CSE -> DCE -> CFGSimplify
// 直到 IR 不再变化或达到最大迭代次数
#include "ir/IR.h"
#include <iostream>
namespace ir {
namespace passes {
// 前向声明各 pass 入口
bool RunMem2Reg(Function& func);
bool RunConstFoldWithCtx(Function& func, Context& ctx);
bool RunConstProp(Function& func, Context& ctx);
bool RunCSE(Function& func);
bool RunSimpleDCE(Function& func);
bool RunCFGSimplify(Function& func, Context& ctx);
static const int kMaxIterations = 20;
void RunAllPasses(Module& module) {
auto& ctx = module.GetContext();
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsExternal()) continue;
RunMem2Reg(*func);
for (int iter = 0; iter < kMaxIterations; ++iter) {
bool changed = false;
changed |= RunConstFoldWithCtx(*func, ctx);
changed |= RunConstProp(*func, ctx);
changed |= RunCSE(*func);
changed |= RunSimpleDCE(*func);
changed |= RunCFGSimplify(*func, ctx);
if (!changed) break;
}
}
}
} // namespace passes
} // namespace ir

@ -15,6 +15,8 @@
#include "irgen/IRGen.h"
#include "mir/MIR.h"
#include "sem/Sema.h"
// 前向声明优化 pass 入口
namespace ir { namespace passes { void RunAllPasses(ir::Module& module); } }
#endif
#include "utils/CLI.h"
#include "utils/Log.h"
@ -139,6 +141,10 @@ int main(int argc, char** argv) {
auto sema = RunSema(*comp_unit);
auto module = GenerateIR(*comp_unit, sema);
// 运行 IR 优化 passMem2Reg + 标量优化迭代)
ir::passes::RunAllPasses(*module);
if (opts.emit_ir) {
ir::IRPrinter printer;
if (need_blank_line) {

@ -281,6 +281,12 @@ void PrintAsm(const MachineModule& module, std::ostream& os) {
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::AddRR_UXTW:
// add xN, xM, wK, uxtw — 零扩展W寄存器后加到X寄存器
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << ", uxtw\n";
break;
case Opcode::SubRR:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "

@ -1,5 +1,6 @@
#include "mir/MIR.h"
#include <algorithm>
#include <stdexcept>
#include <vector>
@ -12,20 +13,45 @@ int AlignTo(int value, int align) {
return ((value + align - 1) / align) * align;
}
// 获取 W 寄存器对应的 X 寄存器
PhysReg WRegToXReg(PhysReg w) {
int idx = static_cast<int>(w) - static_cast<int>(PhysReg::W0);
if (idx >= 0 && idx <= 11) {
return static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + idx);
}
return w;
}
} // namespace
void RunFrameLowering(MachineFunction& function) {
// 计算栈槽偏移
int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
function.GetFrameSlot(slot.index).offset = -cursor;
}
cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
function.GetFrameSlot(slot.index).offset = -cursor;
// callee-saved 寄存器:为每个分配栈槽并记住索引
const auto& callee_saved = function.GetUsedCalleeSaved();
std::vector<std::pair<PhysReg, int>> callee_save_slots; // (x_reg, slot_index)
for (size_t i = 0; i < callee_saved.size(); ++i) {
PhysReg save_reg = callee_saved[i];
PhysReg x_reg = save_reg;
if (save_reg >= PhysReg::W0 && save_reg <= PhysReg::W11) {
x_reg = WRegToXReg(save_reg);
}
// 浮点 callee-saved 直接用 s 寄存器保存4字节
bool is_float = (save_reg >= PhysReg::S0 && save_reg <= PhysReg::S10);
int slot_size = is_float ? 4 : 8;
int slot = function.CreateFrameIndex(slot_size);
function.GetFrameSlot(slot).offset = -(cursor + static_cast<int>(i + 1) * 8);
callee_save_slots.emplace_back(is_float ? save_reg : x_reg, slot);
}
function.SetFrameSize(AlignTo(cursor, 16));
int callee_save_size = static_cast<int>(callee_saved.size()) * 8;
int total_frame = AlignTo(cursor + callee_save_size, 16);
function.SetFrameSize(total_frame);
// 在每个基本块的开头和结尾插入 prologue/epilogue
for (const auto& bb_ptr : function.GetBlocks()) {
@ -36,10 +62,24 @@ void RunFrameLowering(MachineFunction& function) {
// 只在入口块插入 prologue
if (bb.GetName() == "entry") {
lowered.emplace_back(Opcode::Prologue);
// 保存 callee-saved 寄存器
for (const auto& [reg, slot] : callee_save_slots) {
lowered.emplace_back(Opcode::StoreStack,
std::vector<Operand>{Operand::Reg(reg),
Operand::FrameIndex(slot)});
}
}
for (const auto& inst : insts) {
if (inst.GetOpcode() == Opcode::Ret) {
// 恢复 callee-saved 寄存器
for (const auto& [reg, slot] : callee_save_slots) {
lowered.emplace_back(
Opcode::LoadStack,
std::vector<Operand>{Operand::Reg(reg),
Operand::FrameIndex(slot)});
}
lowered.emplace_back(Opcode::Epilogue);
}
lowered.push_back(inst);

@ -2,8 +2,10 @@
#include <cstdint>
#include <cstring>
#include <iterator>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include "ir/IR.h"
#include "utils/Log.h"
@ -80,8 +82,9 @@ void EmitAddOffset(PhysReg reg, int byte_offset, MachineBasicBlock& block) {
{Operand::Reg(reg), Operand::Reg(reg), Operand::Imm(byte_offset)});
return;
}
// 使用 X10 统一做 64 位加法,避免 W10/X10 别名问题
block.Append(Opcode::MovImm,
{Operand::Reg(PhysReg::W10), Operand::Imm(byte_offset)});
{Operand::Reg(PhysReg::X10), Operand::Imm(byte_offset)});
block.Append(Opcode::AddRR,
{Operand::Reg(reg), Operand::Reg(reg), Operand::Reg(PhysReg::X10)});
}
@ -192,15 +195,15 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
geps.emplace(&inst, GepInfo{-1, -1 - index_slot, gv->GetName()});
if (ptr_slot >= 0) {
// 计算地址x9 = &global_array + (index * 4)
// 计算地址x9 = &global_array + (w10 * 4)
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)});
EmitLslBy2(PhysReg::W10, block);
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::X9), Operand::Symbol(gv->GetName())});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
Operand::Reg(PhysReg::W10)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)});
}
@ -247,10 +250,10 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)});
EmitLslBy2(PhysReg::W10, block);
// x9 = x9 + w10
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
// x9 = x9 + uxtw(w10)
block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
Operand::Reg(PhysReg::W10)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)});
}
@ -284,15 +287,15 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
geps.emplace(&inst, GepInfo{base_it->second, -1 - index_slot, ""});
if (ptr_slot >= 0) {
// 计算地址x9 = x29 + base_offset + (index * 4)
// 计算地址x9 = x29 + base_offset + (w10 * 4)
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)});
EmitLslBy2(PhysReg::W10, block);
block.Append(Opcode::LoadStackAddr,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
Operand::Reg(PhysReg::W10)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)});
}
@ -328,7 +331,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
} else {
// 变量索引global_array[var_idx]
int index_slot = -1 - gep_info.byte_offset;
// 1. 加载 index
// 1. 加载 index4字节 W 寄存器)
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)});
// 2. index * 4
@ -336,10 +339,10 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
// 3. 获取全局数组基址
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)});
// 4. x9 + offset
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
// 4. x9 + w10, uxtw零扩展 W 寄存器后加到 X 寄存器)
block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
Operand::Reg(PhysReg::W10)});
// 5. 存储
block.Append(Opcode::StoreIndirect,
{Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)});
@ -359,9 +362,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
block.Append(Opcode::LoadStackAddr,
{Operand::Reg(PhysReg::X9),
Operand::FrameIndex(gep_info.base_slot)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
Operand::Reg(PhysReg::W10)});
block.Append(Opcode::StoreIndirect,
{Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)});
}
@ -429,9 +432,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
EmitLslBy2(PhysReg::W10, block);
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
Operand::Reg(PhysReg::W10)});
block.Append(Opcode::LoadIndirect,
{Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)});
}
@ -450,9 +453,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
block.Append(Opcode::LoadStackAddr,
{Operand::Reg(PhysReg::X9),
Operand::FrameIndex(gep_info.base_slot)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
Operand::Reg(PhysReg::W10)});
block.Append(Opcode::LoadIndirect,
{Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)});
}
@ -797,7 +800,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X8),
Operand::Reg(PhysReg::X8),
Operand::Reg(PhysReg::X9)});
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(2)});
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X9), Operand::Imm(2)});
block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::X8),
Operand::Reg(PhysReg::X8),
Operand::Reg(PhysReg::X9)});
@ -946,6 +949,14 @@ std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
{Operand::Reg(param_reg), Operand::FrameIndex(slot)});
}
// Phi 信息收集:每个 Phi 对应一个栈槽,以及各入边 (value, pred_block)
struct PhiInfo {
int slot;
bool is_float;
std::vector<std::pair<const ir::Value*, const ir::BasicBlock*>> incomings;
};
std::vector<PhiInfo> phi_infos;
// 遍历所有基本块,生成指令
for (const auto& bb_ptr : func.GetBlocks()) {
const auto& bb = *bb_ptr;
@ -956,6 +967,23 @@ std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
const auto& inst = *ir_insts[i];
auto opcode = inst.GetOpcode();
// Phi 节点:分配栈槽,收集入边信息,后续统一插入 store
if (opcode == ir::Opcode::Phi) {
auto& phi = static_cast<const ir::PhiInst&>(inst);
bool is_float = phi.GetType() && phi.GetType()->IsFloat32();
int slot = machine_func->CreateFrameIndex();
slots.emplace(&phi, slot);
PhiInfo info;
info.slot = slot;
info.is_float = is_float;
for (size_t j = 0; j < phi.GetNumIncoming(); ++j) {
info.incomings.emplace_back(phi.GetIncomingValue(j),
phi.GetIncomingBlock(j));
}
phi_infos.push_back(std::move(info));
continue;
}
// Cmp + CondBr 融合:避免 cmp 结果落栈后再读回。
if (opcode == ir::Opcode::Cmp && i + 1 < ir_insts.size()) {
auto* cmp_inst = dynamic_cast<const ir::CmpInst*>(ir_insts[i].get());
@ -1035,6 +1063,52 @@ std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
LowerInstruction(inst, *machine_func, *current_mbb, slots, geps);
}
}
// Phi 消除:在每个前驱块的跳转指令之前插入 store
for (const auto& phi_info : phi_infos) {
for (const auto& [val, pred_bb] : phi_info.incomings) {
if (!val) continue; // 安全检查
auto it_pred = block_map.find(pred_bb);
if (it_pred == block_map.end()) continue; // 前驱块可能已被优化掉
auto* pred_mbb = it_pred->second;
auto& pred_insts = pred_mbb->GetInstructions();
// 找到跳转指令的位置(从末尾往前找第一条 B/Bcond/Cbnz/FBcond
size_t insert_pos = pred_insts.size();
for (size_t j = pred_insts.size(); j > 0; --j) {
auto op = pred_insts[j - 1].GetOpcode();
if (op == Opcode::B || op == Opcode::Bcond ||
op == Opcode::Cbnz || op == Opcode::FBcond) {
insert_pos = j - 1;
} else {
break;
}
}
// 检查 val 是否在 slots 中或者是常量/全局变量
// 如果是常量EmitValueToReg 能直接处理;否则需要有栈槽
bool can_emit = false;
if (dynamic_cast<const ir::ConstantInt*>(val) ||
dynamic_cast<const ir::ConstantFloat*>(val) ||
dynamic_cast<const ir::GlobalVariable*>(val)) {
can_emit = true;
} else if (slots.find(val) != slots.end()) {
can_emit = true;
}
if (!can_emit) continue; // 跳过无法发射的值
PhysReg tmp = phi_info.is_float ? PhysReg::S8 : PhysReg::W8;
MachineBasicBlock tmp_block("__phi_tmp__");
EmitValueToReg(val, tmp, slots, tmp_block);
tmp_block.Append(Opcode::StoreStack,
{Operand::Reg(tmp), Operand::FrameIndex(phi_info.slot)});
auto& tmp_insts = tmp_block.GetInstructions();
pred_insts.insert(pred_insts.begin() + insert_pos,
std::make_move_iterator(tmp_insts.begin()),
std::make_move_iterator(tmp_insts.end()));
}
}
}
return machine_module;

@ -1,5 +1,6 @@
#include "mir/MIR.h"
#include <algorithm>
#include <utility>
namespace mir {
@ -13,4 +14,16 @@ MachineInstr& MachineBasicBlock::Append(Opcode opcode,
return instructions_.back();
}
void MachineBasicBlock::AddSuccessor(MachineBasicBlock* succ) {
if (std::find(successors_.begin(), successors_.end(), succ) == successors_.end()) {
successors_.push_back(succ);
}
}
void MachineBasicBlock::AddPredecessor(MachineBasicBlock* pred) {
if (std::find(predecessors_.begin(), predecessors_.end(), pred) == predecessors_.end()) {
predecessors_.push_back(pred);
}
}
} // namespace mir

@ -1,5 +1,6 @@
#include "mir/MIR.h"
#include <algorithm>
#include <stdexcept>
#include <utility>
@ -47,6 +48,18 @@ const FrameSlot& MachineFunction::GetFrameSlot(int index) const {
return frame_slots_[index];
}
int MachineFunction::CreateVReg(VRegClass rc) {
(void)rc; // 类别信息在 Operand 中记录
return next_vreg_id_++;
}
void MachineFunction::AddUsedCalleeSaved(PhysReg reg) {
if (std::find(used_callee_saved_.begin(), used_callee_saved_.end(), reg) ==
used_callee_saved_.end()) {
used_callee_saved_.push_back(reg);
}
}
MachineFunction* MachineModule::CreateFunction(std::string name) {
functions_.push_back(std::make_unique<MachineFunction>(std::move(name)));
return functions_.back().get();

@ -1,5 +1,6 @@
#include "mir/MIR.h"
#include <stdexcept>
#include <utility>
namespace mir {
@ -7,8 +8,15 @@ namespace mir {
Operand::Operand(Kind kind, PhysReg reg, int imm, std::string symbol)
: kind_(kind), reg_(reg), imm_(imm), symbol_(std::move(symbol)) {}
Operand::Operand(int vreg_id, VRegClass rc)
: kind_(Kind::VReg), vreg_id_(vreg_id), vreg_class_(rc) {}
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
Operand Operand::VReg(int vreg_id, VRegClass rc) {
return Operand(vreg_id, rc);
}
Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value);
}
@ -21,7 +29,19 @@ Operand Operand::Symbol(std::string name) {
return Operand(Kind::Symbol, PhysReg::W0, 0, std::move(name));
}
void Operand::AssignPhysReg(PhysReg reg) {
kind_ = Kind::Reg;
reg_ = reg;
vreg_id_ = -1;
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)
: opcode_(opcode), operands_(std::move(operands)) {}
void MachineInstr::SetOperand(size_t i, Operand op) {
if (i < operands_.size()) {
operands_[i] = std::move(op);
}
}
} // namespace mir

@ -1,64 +1,954 @@
#include "mir/MIR.h"
#include <algorithm>
#include <set>
#include <stack>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "utils/Log.h"
namespace mir {
namespace {
bool IsAllowedReg(PhysReg reg) {
switch (reg) {
case PhysReg::W0:
case PhysReg::W1:
case PhysReg::W2:
case PhysReg::W3:
case PhysReg::W4:
case PhysReg::W5:
case PhysReg::W6:
case PhysReg::W7:
case PhysReg::W8:
case PhysReg::W9:
case PhysReg::W10:
case PhysReg::X0:
case PhysReg::X1:
case PhysReg::X2:
case PhysReg::X3:
case PhysReg::X4:
case PhysReg::X5:
case PhysReg::X6:
case PhysReg::X7:
case PhysReg::X8:
case PhysReg::X9:
case PhysReg::X10:
case PhysReg::X29:
case PhysReg::X30:
case PhysReg::SP:
case PhysReg::S0:
case PhysReg::S1:
case PhysReg::S2:
case PhysReg::S3:
case PhysReg::S4:
case PhysReg::S5:
case PhysReg::S6:
case PhysReg::S7:
case PhysReg::S8:
case PhysReg::S9:
case PhysReg::S10:
return true;
// ============================================================================
// 物理寄存器定义
// ============================================================================
// 可分配的 GPR (32-bit)w0-w11共12个
static const std::vector<PhysReg> kAllocGPR = {
PhysReg::W0, PhysReg::W1, PhysReg::W2, PhysReg::W3,
PhysReg::W4, PhysReg::W5, PhysReg::W6, PhysReg::W7,
PhysReg::W8, PhysReg::W9, PhysReg::W10, PhysReg::W11,
};
// 可分配的 GPR64 (64-bit)x0-x11共12个
static const std::vector<PhysReg> kAllocGPR64 = {
PhysReg::X0, PhysReg::X1, PhysReg::X2, PhysReg::X3,
PhysReg::X4, PhysReg::X5, PhysReg::X6, PhysReg::X7,
PhysReg::X8, PhysReg::X9, PhysReg::X10, PhysReg::X11,
};
// 可分配的 FPRs0-s10共11个
static const std::vector<PhysReg> kAllocFPR = {
PhysReg::S0, PhysReg::S1, PhysReg::S2, PhysReg::S3,
PhysReg::S4, PhysReg::S5, PhysReg::S6, PhysReg::S7,
PhysReg::S8, PhysReg::S9, PhysReg::S10,
};
// Callee-saved 寄存器
static const std::set<PhysReg> kCalleeSavedGPR = {
PhysReg::W8, PhysReg::W9, PhysReg::W10, PhysReg::W11,
};
static const std::set<PhysReg> kCalleeSavedGPR64 = {
PhysReg::X8, PhysReg::X9, PhysReg::X10, PhysReg::X11,
};
static const std::set<PhysReg> kCalleeSavedFPR = {
PhysReg::S8, PhysReg::S9, PhysReg::S10,
};
// Caller-saved 寄存器(被函数调用破坏)
static const std::set<PhysReg> kCallerSavedGPR = {
PhysReg::W0, PhysReg::W1, PhysReg::W2, PhysReg::W3,
PhysReg::W4, PhysReg::W5, PhysReg::W6, PhysReg::W7,
};
static const std::set<PhysReg> kCallerSavedGPR64 = {
PhysReg::X0, PhysReg::X1, PhysReg::X2, PhysReg::X3,
PhysReg::X4, PhysReg::X5, PhysReg::X6, PhysReg::X7,
};
static const std::set<PhysReg> kCallerSavedFPR = {
PhysReg::S0, PhysReg::S1, PhysReg::S2, PhysReg::S3,
PhysReg::S4, PhysReg::S5, PhysReg::S6, PhysReg::S7,
};
bool IsFloatReg(PhysReg r) {
return r >= PhysReg::S0 && r <= PhysReg::S10;
}
// W/X 寄存器别名w0-w11 和 x0-x11 是同一物理寄存器的 32/64 位视图
PhysReg WToX(PhysReg w) {
int idx = static_cast<int>(w) - static_cast<int>(PhysReg::W0);
if (idx >= 0 && idx <= 11)
return static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + idx);
return w;
}
PhysReg XToW(PhysReg x) {
int idx = static_cast<int>(x) - static_cast<int>(PhysReg::X0);
if (idx >= 0 && idx <= 11)
return static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + idx);
return x;
}
// 将物理寄存器及其别名都加入集合
void InsertWithAlias(std::set<PhysReg>& s, PhysReg r) {
s.insert(r);
if (r >= PhysReg::W0 && r <= PhysReg::W11) {
s.insert(WToX(r));
} else if (r >= PhysReg::X0 && r <= PhysReg::X11) {
s.insert(XToW(r));
}
}
bool IsGPR64(PhysReg r) {
return (r >= PhysReg::X0 && r <= PhysReg::X11) ||
r == PhysReg::X29 || r == PhysReg::X30 || r == PhysReg::SP;
}
// 确定物理寄存器对应的 VRegClass
VRegClass ClassForPhysReg(PhysReg r) {
if (IsFloatReg(r)) return VRegClass::FPR;
if (IsGPR64(r)) return VRegClass::GPR64;
return VRegClass::GPR;
}
// 获取虚拟寄存器类别的可分配物理寄存器列表
const std::vector<PhysReg>& GetAllocRegs(VRegClass rc) {
switch (rc) {
case VRegClass::GPR: return kAllocGPR;
case VRegClass::GPR64: return kAllocGPR64;
case VRegClass::FPR: return kAllocFPR;
}
return kAllocGPR;
}
// 判断是否应该提升的寄存器
// 只提升临时计算寄存器 (w8-w11, x8-x11, s8-s10)
// 不提升 ABI 寄存器 (w0-w7, x0-x7, s0-s7) 和特殊寄存器 (x29, x30, sp)
bool ShouldPromote(PhysReg r) {
if (r == PhysReg::X29 || r == PhysReg::X30 || r == PhysReg::SP) return false;
// 不提升 ABI caller-saved 寄存器w0-w7, x0-x7, s0-s7
if (r >= PhysReg::W0 && r <= PhysReg::W7) return false;
if (r >= PhysReg::X0 && r <= PhysReg::X7) return false;
if (r >= PhysReg::S0 && r <= PhysReg::S7) return false;
return true;
}
// ============================================================================
// 1. 构建 CFG
// ============================================================================
void BuildCFG(MachineFunction& func) {
for (auto& bb_ptr : func.GetBlocks()) {
bb_ptr->ClearCFG();
}
for (auto& bb_ptr : func.GetBlocks()) {
auto& bb = *bb_ptr;
for (const auto& inst : bb.GetInstructions()) {
auto op = inst.GetOpcode();
if (op == Opcode::B) {
auto* target = func.FindBlock(inst.GetOperands()[0].GetSymbol());
if (target) {
bb.AddSuccessor(target);
target->AddPredecessor(&bb);
}
} else if (op == Opcode::Bcond || op == Opcode::FBcond) {
auto* target = func.FindBlock(inst.GetOperands()[0].GetSymbol());
if (target) {
bb.AddSuccessor(target);
target->AddPredecessor(&bb);
}
} else if (op == Opcode::Cbnz || op == Opcode::Cbz) {
auto* target = func.FindBlock(inst.GetOperands()[1].GetSymbol());
if (target) {
bb.AddSuccessor(target);
target->AddPredecessor(&bb);
}
}
}
// fall-through如果最后一条指令不是无条件跳转/ret则下一个块是后继
if (!bb.GetInstructions().empty()) {
auto last_op = bb.GetInstructions().back().GetOpcode();
if (last_op != Opcode::B && last_op != Opcode::Ret) {
const auto& blocks = func.GetBlocks();
for (size_t i = 0; i + 1 < blocks.size(); ++i) {
if (blocks[i].get() == &bb) {
auto* next = blocks[i + 1].get();
bb.AddSuccessor(next);
next->AddPredecessor(&bb);
break;
}
}
}
}
}
}
// ============================================================================
// 2. VRegInfo 结构
// ============================================================================
struct VRegInfo {
int id = 0;
VRegClass rc = VRegClass::GPR;
int spill_slot = -1;
bool is_spilled = false;
};
// ============================================================================
// 3. 将 MIR 中的物理寄存器提升为虚拟寄存器
// ============================================================================
//
// 策略:只提升临时计算寄存器 (w8-w11, x8-x11, s8-s10)。
// ABI 寄存器 (w0-w7, x0-x7, s0-s7) 保持物理寄存器不变,因为它们用于:
// - 函数参数传递、返回值、调用约定
// 跨块的值通过栈槽传递,所以无需跨块跟踪 vreg。
// 图着色分配器仍然可以将 vreg 分配到 w0-w7 等寄存器。
void PromoteToVRegs(MachineFunction& func,
std::vector<VRegInfo>& vreg_infos) {
for (auto& bb_ptr : func.GetBlocks()) {
auto& insts = bb_ptr->GetInstructions();
// 块内物理寄存器 → vreg 的映射
std::unordered_map<int, int> phys_to_vreg;
auto EnsureSize = [&](int vreg_id) {
if (static_cast<int>(vreg_infos.size()) <= vreg_id)
vreg_infos.resize(vreg_id + 1);
};
auto GetOrCreateVReg = [&](PhysReg reg) -> int {
int key = static_cast<int>(reg);
auto it = phys_to_vreg.find(key);
if (it != phys_to_vreg.end()) return it->second;
VRegClass rc = ClassForPhysReg(reg);
int vreg_id = func.CreateVReg(rc);
EnsureSize(vreg_id);
vreg_infos[vreg_id] = {vreg_id, rc, -1, false};
phys_to_vreg[key] = vreg_id;
return vreg_id;
};
auto CreateNewVReg = [&](PhysReg reg) -> int {
VRegClass rc = ClassForPhysReg(reg);
int vreg_id = func.CreateVReg(rc);
EnsureSize(vreg_id);
vreg_infos[vreg_id] = {vreg_id, rc, -1, false};
phys_to_vreg[static_cast<int>(reg)] = vreg_id;
return vreg_id;
};
// 辅助:提升 use 位置的操作数
auto PromoteUse = [&](Operand& op) {
if (op.IsReg() && ShouldPromote(op.GetReg())) {
PhysReg orig = op.GetReg();
int vreg = GetOrCreateVReg(orig);
op = Operand::VReg(vreg, ClassForPhysReg(orig));
}
};
// 辅助:提升 def 位置的操作数(为 def 创建新 vreg
auto PromoteDef = [&](Operand& op) {
if (op.IsReg() && ShouldPromote(op.GetReg())) {
PhysReg orig = op.GetReg();
int vreg = CreateNewVReg(orig);
op = Operand::VReg(vreg, ClassForPhysReg(orig));
}
};
for (size_t i = 0; i < insts.size(); ++i) {
auto& inst = insts[i];
auto opcode = inst.GetOpcode();
// 跳过控制流/伪指令
if (opcode == Opcode::Prologue || opcode == Opcode::Epilogue ||
opcode == Opcode::B || opcode == Opcode::Bcond ||
opcode == Opcode::FBcond || opcode == Opcode::Ret)
continue;
// Bl清除所有 caller-saved 映射(它们被调用破坏了)
if (opcode == Opcode::Bl) {
// callee-saved (w8-w11等) 的映射也要清除,因为调用可能破坏它们
// (虽然按约定不应该,但这里保守处理)
phys_to_vreg.clear();
continue;
}
auto& ops = inst.GetOperands();
// Cbnz/Cbz: ops[0] 是 use
if (opcode == Opcode::Cbnz || opcode == Opcode::Cbz) {
if (!ops.empty()) PromoteUse(ops[0]);
continue;
}
// CmpOnlyRR / FCmpOnlyRR: ops[0], ops[1] 都是 use不写寄存器只设条件码
if (opcode == Opcode::CmpOnlyRR || opcode == Opcode::FCmpOnlyRR) {
for (size_t j = 0; j < std::min(ops.size(), size_t(2)); ++j)
PromoteUse(ops[j]);
continue;
}
// ---- 先处理 use 操作数 ----
switch (opcode) {
case Opcode::MovReg:
case Opcode::FMovReg:
if (ops.size() > 1) PromoteUse(ops[1]);
break;
case Opcode::StoreStack:
case Opcode::StoreStackOffset:
case Opcode::StoreGlobal:
if (!ops.empty()) PromoteUse(ops[0]);
break;
case Opcode::StoreIndirect:
if (!ops.empty()) PromoteUse(ops[0]);
if (ops.size() > 1) PromoteUse(ops[1]);
break;
case Opcode::LoadIndirect:
if (ops.size() > 1) PromoteUse(ops[1]);
break;
case Opcode::AddRI: case Opcode::SubRI:
case Opcode::LsrRI: case Opcode::LslRI:
case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI:
if (ops.size() > 1) PromoteUse(ops[1]);
break;
case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR:
case Opcode::MulRR: case Opcode::DivRR:
case Opcode::ModRR: case Opcode::LslRR:
case Opcode::FAddRR: case Opcode::FSubRR:
case Opcode::FMulRR: case Opcode::FDivRR:
if (ops.size() > 1) PromoteUse(ops[1]);
if (ops.size() > 2) PromoteUse(ops[2]);
break;
case Opcode::CmpRR: case Opcode::FCmpRR:
if (ops.size() > 1) PromoteUse(ops[1]);
if (ops.size() > 2) PromoteUse(ops[2]);
break;
default:
break;
}
// ---- 然后处理 def 操作数 ----
switch (opcode) {
case Opcode::MovImm: case Opcode::FMovImm:
case Opcode::MovReg: case Opcode::FMovReg:
case Opcode::LoadStack: case Opcode::LoadStackOffset:
case Opcode::LoadStackAddr: case Opcode::LoadIndirect:
case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr:
case Opcode::AddRI: case Opcode::SubRI:
case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR:
case Opcode::MulRR: case Opcode::DivRR:
case Opcode::ModRR: case Opcode::LsrRI:
case Opcode::LslRI: case Opcode::LslRR:
case Opcode::FAddRR: case Opcode::FSubRR:
case Opcode::FMulRR: case Opcode::FDivRR:
case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI:
case Opcode::CmpRR: case Opcode::FCmpRR:
if (!ops.empty()) PromoteDef(ops[0]);
break;
default:
break;
}
}
}
}
// ============================================================================
// 4. VReg 活跃性分析
// ============================================================================
struct VRegLiveInfo {
std::set<int> live_in;
std::set<int> live_out;
};
using VRegLiveMap = std::unordered_map<MachineBasicBlock*, VRegLiveInfo>;
void GetVRegDefsUses(const MachineInstr& inst,
std::set<int>& defs, std::set<int>& uses) {
const auto& ops = inst.GetOperands();
auto opcode = inst.GetOpcode();
// Use 操作数
switch (opcode) {
case Opcode::MovReg: case Opcode::FMovReg:
if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId());
break;
case Opcode::StoreStack: case Opcode::StoreStackOffset:
case Opcode::StoreGlobal:
if (!ops.empty() && ops[0].IsVReg()) uses.insert(ops[0].GetVRegId());
break;
case Opcode::StoreIndirect:
if (!ops.empty() && ops[0].IsVReg()) uses.insert(ops[0].GetVRegId());
if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId());
break;
case Opcode::LoadIndirect:
if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId());
break;
case Opcode::AddRI: case Opcode::SubRI:
case Opcode::LsrRI: case Opcode::LslRI:
case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI:
if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId());
break;
case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR:
case Opcode::MulRR: case Opcode::DivRR:
case Opcode::ModRR: case Opcode::LslRR:
case Opcode::FAddRR: case Opcode::FSubRR:
case Opcode::FMulRR: case Opcode::FDivRR:
if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId());
if (ops.size() > 2 && ops[2].IsVReg()) uses.insert(ops[2].GetVRegId());
break;
case Opcode::CmpRR: case Opcode::FCmpRR:
if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId());
if (ops.size() > 2 && ops[2].IsVReg()) uses.insert(ops[2].GetVRegId());
break;
case Opcode::CmpOnlyRR: case Opcode::FCmpOnlyRR:
if (!ops.empty() && ops[0].IsVReg()) uses.insert(ops[0].GetVRegId());
if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId());
break;
case Opcode::Cbnz: case Opcode::Cbz:
if (!ops.empty() && ops[0].IsVReg()) uses.insert(ops[0].GetVRegId());
break;
default:
break;
}
// Def 操作数
switch (opcode) {
case Opcode::MovImm: case Opcode::FMovImm:
case Opcode::MovReg: case Opcode::FMovReg:
case Opcode::LoadStack: case Opcode::LoadStackOffset:
case Opcode::LoadStackAddr: case Opcode::LoadIndirect:
case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr:
case Opcode::AddRI: case Opcode::SubRI:
case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR:
case Opcode::MulRR: case Opcode::DivRR:
case Opcode::ModRR: case Opcode::LsrRI:
case Opcode::LslRI: case Opcode::LslRR:
case Opcode::FAddRR: case Opcode::FSubRR:
case Opcode::FMulRR: case Opcode::FDivRR:
case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI:
case Opcode::CmpRR: case Opcode::FCmpRR:
if (!ops.empty() && ops[0].IsVReg()) defs.insert(ops[0].GetVRegId());
break;
default:
break;
}
}
VRegLiveMap ComputeVRegLiveness(MachineFunction& func) {
VRegLiveMap live_map;
for (auto& bb_ptr : func.GetBlocks()) {
live_map[bb_ptr.get()] = VRegLiveInfo{};
}
bool changed = true;
while (changed) {
changed = false;
const auto& blocks = func.GetBlocks();
for (int bi = static_cast<int>(blocks.size()) - 1; bi >= 0; --bi) {
auto* bb = blocks[bi].get();
auto& info = live_map[bb];
// live_out = union of successors' live_in
std::set<int> new_live_out;
for (auto* succ : bb->GetSuccessors()) {
const auto& succ_in = live_map[succ].live_in;
new_live_out.insert(succ_in.begin(), succ_in.end());
}
// 从块末尾向前扫描计算 live_in
std::set<int> live = new_live_out;
const auto& insts = bb->GetInstructions();
for (int i = static_cast<int>(insts.size()) - 1; i >= 0; --i) {
std::set<int> defs, uses_set;
GetVRegDefsUses(insts[i], defs, uses_set);
for (auto v : defs) live.erase(v);
for (auto v : uses_set) live.insert(v);
}
if (live != info.live_in || new_live_out != info.live_out) {
info.live_in = std::move(live);
info.live_out = std::move(new_live_out);
changed = true;
}
}
}
return live_map;
}
// ============================================================================
// 5. 干涉图构建
// ============================================================================
// 获取指令中定义和使用的物理寄存器(非 vreg 的)
void GetPhysRegDefsUses(const MachineInstr& inst,
std::set<PhysReg>& phys_defs,
std::set<PhysReg>& phys_uses) {
const auto& ops = inst.GetOperands();
auto opcode = inst.GetOpcode();
// 对于 Bl 指令,所有 caller-saved 寄存器被隐式定义(破坏)
if (opcode == Opcode::Bl) {
for (auto r : kCallerSavedGPR) phys_defs.insert(r);
for (auto r : kCallerSavedGPR64) phys_defs.insert(r);
for (auto r : kCallerSavedFPR) phys_defs.insert(r);
return;
}
// Use 位置的物理寄存器
switch (opcode) {
case Opcode::MovReg: case Opcode::FMovReg:
if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg()))
phys_uses.insert(ops[1].GetReg());
break;
case Opcode::StoreStack: case Opcode::StoreStackOffset:
case Opcode::StoreGlobal:
if (!ops.empty() && ops[0].IsReg() && !ShouldPromote(ops[0].GetReg()))
phys_uses.insert(ops[0].GetReg());
break;
case Opcode::StoreIndirect:
if (!ops.empty() && ops[0].IsReg() && !ShouldPromote(ops[0].GetReg()))
phys_uses.insert(ops[0].GetReg());
if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg()))
phys_uses.insert(ops[1].GetReg());
break;
case Opcode::AddRI: case Opcode::SubRI:
case Opcode::LsrRI: case Opcode::LslRI:
case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI:
if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg()))
phys_uses.insert(ops[1].GetReg());
break;
case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR:
case Opcode::MulRR: case Opcode::DivRR:
case Opcode::ModRR: case Opcode::LslRR:
case Opcode::FAddRR: case Opcode::FSubRR:
case Opcode::FMulRR: case Opcode::FDivRR:
if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg()))
phys_uses.insert(ops[1].GetReg());
if (ops.size() > 2 && ops[2].IsReg() && !ShouldPromote(ops[2].GetReg()))
phys_uses.insert(ops[2].GetReg());
break;
case Opcode::CmpRR: case Opcode::FCmpRR:
case Opcode::CmpOnlyRR: case Opcode::FCmpOnlyRR:
if (!ops.empty() && ops[0].IsReg() && !ShouldPromote(ops[0].GetReg()))
phys_uses.insert(ops[0].GetReg());
if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg()))
phys_uses.insert(ops[1].GetReg());
break;
case Opcode::Cbnz: case Opcode::Cbz:
if (!ops.empty() && ops[0].IsReg() && !ShouldPromote(ops[0].GetReg()))
phys_uses.insert(ops[0].GetReg());
break;
case Opcode::LoadIndirect:
if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg()))
phys_uses.insert(ops[1].GetReg());
break;
default:
break;
}
// Def 位置的物理寄存器
switch (opcode) {
case Opcode::MovImm: case Opcode::FMovImm:
case Opcode::MovReg: case Opcode::FMovReg:
case Opcode::LoadStack: case Opcode::LoadStackOffset:
case Opcode::LoadStackAddr: case Opcode::LoadIndirect:
case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr:
case Opcode::AddRI: case Opcode::SubRI:
case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR:
case Opcode::MulRR: case Opcode::DivRR:
case Opcode::ModRR: case Opcode::LsrRI:
case Opcode::LslRI: case Opcode::LslRR:
case Opcode::FAddRR: case Opcode::FSubRR:
case Opcode::FMulRR: case Opcode::FDivRR:
case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI:
case Opcode::CmpRR: case Opcode::FCmpRR:
if (!ops.empty() && ops[0].IsReg() && !ShouldPromote(ops[0].GetReg()))
phys_defs.insert(ops[0].GetReg());
break;
default:
break;
}
}
struct InterferenceGraph {
int num_vregs = 0;
std::vector<std::set<int>> adj;
std::vector<int> degree;
// 每个 vreg 不能使用的物理寄存器集合(与物理寄存器的干涉)
std::vector<std::set<PhysReg>> phys_interference;
void Init(int n) {
num_vregs = n;
adj.resize(n);
degree.resize(n, 0);
phys_interference.resize(n);
}
void AddEdge(int u, int v) {
if (u == v) return;
if (u < 0 || u >= num_vregs || v < 0 || v >= num_vregs) return;
if (adj[u].count(v)) return;
adj[u].insert(v);
adj[v].insert(u);
degree[u]++;
degree[v]++;
}
void AddPhysInterference(int vreg, PhysReg phys) {
if (vreg < 0 || vreg >= num_vregs) return;
// 同时插入物理寄存器及其 W/X 别名
InsertWithAlias(phys_interference[vreg], phys);
}
};
InterferenceGraph BuildInterferenceGraph(
MachineFunction& func,
const VRegLiveMap& live_map,
int num_vregs) {
InterferenceGraph ig;
ig.Init(num_vregs);
for (auto& bb_ptr : func.GetBlocks()) {
auto* bb = bb_ptr.get();
const auto& insts = bb->GetInstructions();
auto it = live_map.find(bb);
if (it == live_map.end()) continue;
std::set<int> live = it->second.live_out;
// 追踪当前活跃的物理寄存器
std::set<PhysReg> phys_live;
// 从块末尾向前扫描
for (int i = static_cast<int>(insts.size()) - 1; i >= 0; --i) {
std::set<int> defs, uses_set;
GetVRegDefsUses(insts[i], defs, uses_set);
std::set<PhysReg> phys_defs, phys_uses;
GetPhysRegDefsUses(insts[i], phys_defs, phys_uses);
// 对每个 vreg def与所有当前 live 的 vreg 添加干涉边
for (int d : defs) {
for (int l : live) {
ig.AddEdge(d, l);
}
// vreg def 与当前活跃的物理寄存器干涉
for (auto pr : phys_live) {
ig.AddPhysInterference(d, pr);
}
}
// 物理寄存器 def 与当前活跃的 vreg 干涉
for (auto pr : phys_defs) {
for (int l : live) {
ig.AddPhysInterference(l, pr);
}
}
// 更新 vreg 活跃集合
for (int d : defs) live.erase(d);
for (int u : uses_set) live.insert(u);
// 更新物理寄存器活跃集合
for (auto pr : phys_defs) phys_live.erase(pr);
for (auto pr : phys_uses) phys_live.insert(pr);
}
}
return ig;
}
// ============================================================================
// 6. 图着色Simplify → Select → Spill
// ============================================================================
struct ColoringResult {
std::unordered_map<int, PhysReg> assignment;
std::vector<int> spilled;
};
bool IsCalleeSaved(PhysReg r, VRegClass rc) {
switch (rc) {
case VRegClass::GPR: return kCalleeSavedGPR.count(r) > 0;
case VRegClass::GPR64: return kCalleeSavedGPR64.count(r) > 0;
case VRegClass::FPR: return kCalleeSavedFPR.count(r) > 0;
}
return false;
}
ColoringResult ColorGraph(
InterferenceGraph& ig,
const std::vector<VRegInfo>& vreg_infos,
MachineFunction& func) {
int n = ig.num_vregs;
ColoringResult result;
if (n == 0) return result;
// Simplify: 迭代移除度 < K 的节点入栈
std::vector<bool> removed(n, false);
std::stack<int> select_stack;
std::vector<int> cur_degree(ig.degree);
auto K = [&](int v) -> int {
return static_cast<int>(GetAllocRegs(vreg_infos[v].rc).size());
};
int remaining = n;
while (remaining > 0) {
bool found = false;
for (int v = 0; v < n; ++v) {
if (removed[v]) continue;
if (cur_degree[v] < K(v)) {
select_stack.push(v);
removed[v] = true;
remaining--;
for (int nb : ig.adj[v]) {
if (!removed[nb]) cur_degree[nb]--;
}
found = true;
break;
}
}
if (!found) {
// Potential spill选度数最大的节点
int best = -1;
int best_degree = -1;
for (int v = 0; v < n; ++v) {
if (removed[v]) continue;
if (cur_degree[v] > best_degree) {
best_degree = cur_degree[v];
best = v;
}
}
if (best >= 0) {
select_stack.push(best);
removed[best] = true;
remaining--;
for (int nb : ig.adj[best]) {
if (!removed[nb]) cur_degree[nb]--;
}
}
}
}
// Select: 从栈中弹出,尝试着色
std::vector<int> color(n, -1);
while (!select_stack.empty()) {
int v = select_stack.top();
select_stack.pop();
const auto& alloc_regs = GetAllocRegs(vreg_infos[v].rc);
// 收集邻居已使用的颜色 + 物理寄存器干涉
std::set<PhysReg> used_colors;
for (int nb : ig.adj[v]) {
if (color[nb] >= 0) {
PhysReg nb_color = static_cast<PhysReg>(color[nb]);
// 插入颜色及其 W/X 别名,防止跨类别冲突
InsertWithAlias(used_colors, nb_color);
}
}
// 加入与此 vreg 干涉的物理寄存器(已含别名)
for (auto pr : ig.phys_interference[v]) {
used_colors.insert(pr);
}
// 优先使用 callee-saved减少保存开销对仅在块内使用的 vreg 更优)
PhysReg chosen = PhysReg::W0;
bool colored = false;
for (auto r : alloc_regs) {
if (used_colors.count(r)) continue;
if (IsCalleeSaved(r, vreg_infos[v].rc)) {
chosen = r;
colored = true;
break;
}
}
if (!colored) {
for (auto r : alloc_regs) {
if (!used_colors.count(r)) {
chosen = r;
colored = true;
break;
}
}
}
if (colored) {
color[v] = static_cast<int>(chosen);
result.assignment[v] = chosen;
if (IsCalleeSaved(chosen, vreg_infos[v].rc)) {
func.AddUsedCalleeSaved(chosen);
}
} else {
result.spilled.push_back(v);
}
}
return result;
}
// ============================================================================
// 7. Spill 处理
// ============================================================================
void RewriteSpills(MachineFunction& func,
const std::vector<int>& spilled,
std::vector<VRegInfo>& vreg_infos) {
for (int v : spilled) {
int slot_size = (vreg_infos[v].rc == VRegClass::GPR64) ? 8 : 4;
int slot = func.CreateFrameIndex(slot_size);
vreg_infos[v].spill_slot = slot;
vreg_infos[v].is_spilled = true;
}
std::set<int> spilled_set(spilled.begin(), spilled.end());
for (auto& bb_ptr : func.GetBlocks()) {
auto& insts = bb_ptr->GetInstructions();
std::vector<MachineInstr> new_insts;
new_insts.reserve(insts.size() * 2);
for (auto& inst : insts) {
auto& ops = inst.GetOperands();
std::set<int> inst_defs, inst_uses;
GetVRegDefsUses(inst, inst_defs, inst_uses);
// 收集此指令中 spilled vreg 的 use 和 def 位置
std::vector<std::pair<size_t, int>> use_positions;
std::vector<std::pair<size_t, int>> def_positions;
for (size_t j = 0; j < ops.size(); ++j) {
if (!ops[j].IsVReg()) continue;
int vid = ops[j].GetVRegId();
if (!spilled_set.count(vid)) continue;
if (inst_uses.count(vid)) use_positions.emplace_back(j, vid);
if (inst_defs.count(vid)) def_positions.emplace_back(j, vid);
}
auto EnsureSize = [&](int id) {
if (static_cast<int>(vreg_infos.size()) <= id)
vreg_infos.resize(id + 1);
};
// 为 spilled use 插入 LoadStack
for (auto& [op_idx, vid] : use_positions) {
VRegClass rc = vreg_infos[vid].rc;
int new_vreg = func.CreateVReg(rc);
EnsureSize(new_vreg);
vreg_infos[new_vreg] = {new_vreg, rc, -1, false};
new_insts.emplace_back(Opcode::LoadStack,
std::vector<Operand>{
Operand::VReg(new_vreg, rc),
Operand::FrameIndex(vreg_infos[vid].spill_slot)});
ops[op_idx] = Operand::VReg(new_vreg, rc);
}
new_insts.push_back(std::move(inst));
// 为 spilled def 插入 StoreStack
for (auto& [op_idx, vid] : def_positions) {
VRegClass rc = vreg_infos[vid].rc;
int new_vreg = func.CreateVReg(rc);
EnsureSize(new_vreg);
vreg_infos[new_vreg] = {new_vreg, rc, -1, false};
// 替换 emitted 指令中的 def 操作数
new_insts.back().GetOperands()[op_idx] = Operand::VReg(new_vreg, rc);
new_insts.emplace_back(Opcode::StoreStack,
std::vector<Operand>{
Operand::VReg(new_vreg, rc),
Operand::FrameIndex(vreg_infos[vid].spill_slot)});
}
}
insts = std::move(new_insts);
}
}
// ============================================================================
// 8. 应用着色结果
// ============================================================================
void ApplyColoring(MachineFunction& func,
const std::unordered_map<int, PhysReg>& assignment) {
for (auto& bb_ptr : func.GetBlocks()) {
for (auto& inst : bb_ptr->GetInstructions()) {
for (auto& op : inst.GetOperands()) {
if (op.IsVReg()) {
auto it = assignment.find(op.GetVRegId());
if (it != assignment.end()) {
op.AssignPhysReg(it->second);
}
}
}
}
}
}
} // namespace
void RunRegAlloc(MachineFunction& function) {
for (const auto& bb_ptr : function.GetBlocks()) {
// ============================================================================
// 9. 主入口
// ============================================================================
void RunRegAlloc(MachineFunction& func) {
// 步骤 1: 构建 CFG
BuildCFG(func);
// 步骤 2: 将物理寄存器提升为虚拟寄存器
std::vector<VRegInfo> vreg_infos;
PromoteToVRegs(func, vreg_infos);
int num_vregs = func.GetNumVRegs();
if (num_vregs == 0) return;
vreg_infos.resize(num_vregs);
// 图着色 + spill 迭代
const int kMaxIterations = 10;
for (int iter = 0; iter < kMaxIterations; ++iter) {
// 重建 CFGspill 后指令变了)
if (iter > 0) BuildCFG(func);
VRegLiveMap live_map = ComputeVRegLiveness(func);
num_vregs = func.GetNumVRegs();
vreg_infos.resize(num_vregs);
InterferenceGraph ig = BuildInterferenceGraph(func, live_map, num_vregs);
ColoringResult result = ColorGraph(ig, vreg_infos, func);
if (result.spilled.empty()) {
ApplyColoring(func, result.assignment);
return;
}
RewriteSpills(func, result.spilled, vreg_infos);
}
// 最后一次尝试
BuildCFG(func);
VRegLiveMap live_map = ComputeVRegLiveness(func);
num_vregs = func.GetNumVRegs();
vreg_infos.resize(num_vregs);
InterferenceGraph ig = BuildInterferenceGraph(func, live_map, num_vregs);
ColoringResult result = ColorGraph(ig, vreg_infos, func);
ApplyColoring(func, result.assignment);
// 最终检查:不应有残留 VReg
for (auto& bb_ptr : func.GetBlocks()) {
for (const auto& inst : bb_ptr->GetInstructions()) {
for (const auto& operand : inst.GetOperands()) {
if (operand.GetKind() == Operand::Kind::Reg &&
!IsAllowedReg(operand.GetReg())) {
throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
for (const auto& op : inst.GetOperands()) {
if (op.IsVReg()) {
throw std::runtime_error(
FormatError("mir", "寄存器分配失败:存在未着色的虚拟寄存器 v" +
std::to_string(op.GetVRegId())));
}
}
}

@ -72,6 +72,7 @@ std::optional<PhysReg> GetWrittenReg(const MachineInstr& inst) {
case Opcode::AddRI:
case Opcode::SubRI:
case Opcode::AddRR:
case Opcode::AddRR_UXTW:
case Opcode::SubRR:
case Opcode::MulRR:
case Opcode::DivRR:
@ -125,6 +126,23 @@ void RecordStore(std::unordered_map<int, PhysReg>& slot_to_reg,
slot_to_reg[ops[1].GetFrameIndex()] = ops[0].GetReg();
}
bool IsWReg(PhysReg reg) {
return reg >= PhysReg::W0 && reg <= PhysReg::W11;
}
bool IsXReg(PhysReg reg) {
return (reg >= PhysReg::X0 && reg <= PhysReg::X11) ||
reg == PhysReg::X29 || reg == PhysReg::X30;
}
// 检查两个寄存器宽度是否兼容(同为 W同为 X或同为 S
bool SameRegWidth(PhysReg a, PhysReg b) {
if (IsWReg(a) && IsWReg(b)) return true;
if (IsXReg(a) && IsXReg(b)) return true;
if (IsFloatReg(a) && IsFloatReg(b)) return true;
return false;
}
bool TryForwardLoad(std::vector<MachineInstr>& out,
std::unordered_map<int, PhysReg>& slot_to_reg,
const MachineInstr& load) {
@ -142,6 +160,12 @@ bool TryForwardLoad(std::vector<MachineInstr>& out,
}
const PhysReg src = it->second;
// 宽度不匹配时不能转发(如 W8 → X8 会生成非法的 mov x8, w8
if (!SameRegWidth(src, dst)) {
return false;
}
if (RegAlias(src, dst)) {
slot_to_reg[slot] = dst;
return true;
@ -242,7 +266,10 @@ void RunPeephole(MachineFunction& function) {
IsLoadStack(insts[i + 1]) && IsSameFrameIndex(cur, insts[i + 1])) {
optimized.push_back(cur);
RecordStore(slot_to_reg, cur);
TryForwardLoad(optimized, slot_to_reg, insts[i + 1]);
if (!TryForwardLoad(optimized, slot_to_reg, insts[i + 1])) {
// 转发失败(如宽度不匹配),保留原始 load
optimized.push_back(insts[i + 1]);
}
++i;
continue;
}

Loading…
Cancel
Save