feat(mir): 实现 LLVM-style 贪婪寄存器分配器 —— 统一架构

核心变更:
- MIR.h: 增强 LiveInterval(VNInfo/UsePosition/Segment)+ LiveRegMatrix + RegClass
- GreedyAlloc.cpp: TryAssign/TryAnyFreeReg/TryEvict/TrySplit 贪婪分配 + RewriteSpills
- InstLiveness.cpp: EnhanceIntervals 前向 pass + ComputeInstLiveness 适配
- MIRBasicBlock.cpp: InsertInst/ReplaceVReg API
- main.cpp: 切换至 RunGreedyRegAlloc
- RegAlloc.cpp/LinearScanAlloc.cpp: #if 0 隔离

架构:优先级队列驱动分配(每轮全新分配),TryEvict 无条件驱逐,
StoreStack+LoadStack 溢出重写,区间分裂处理高寄存器压力。

功能测试通过率: 53/100(剩余 47 例需调试溢出重写循环)
lzk
lzkk 4 days ago
parent 0a29e6ac42
commit da1e456133

@ -0,0 +1,12 @@
#pragma once
namespace mir
{
class MachineFunction;
class MachineModule;
void RunGreedyRegAlloc(MachineFunction &function);
void RunGreedyRegAlloc(MachineModule &module);
} // namespace mir

@ -184,6 +184,17 @@ namespace mir
MovReg,
};
// ---- 寄存器类别 ----
enum class RegClass { GPR32, GPR64, FPR32, FPR64, Unknown };
inline RegClass ToRegClass(VRegClass vc) {
switch (vc) {
case VRegClass::Int: return RegClass::GPR32;
case VRegClass::Float: return RegClass::FPR32;
default: return RegClass::Unknown;
}
}
enum class CondCode
{
EQ,
@ -288,6 +299,9 @@ namespace mir
MachineInstr &Append(Opcode opcode,
std::initializer_list<Operand> operands = {});
void InsertInst(int local_idx, MachineInstr inst);
void ReplaceVReg(int local_idx, int old_vreg, int new_vreg);
private:
std::string name_;
int label_id_ = -1;
@ -421,11 +435,9 @@ namespace mir
std::unique_ptr<MachineModule> LowerModuleToMIR(const ir::Module &module);
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module &module);
void RunRegAlloc(MachineFunction &function);
void RunRegAlloc(MachineModule &module);
void RunLinearScanRegAlloc(MachineFunction &function);
void RunLinearScanRegAlloc(MachineModule &module);
// ---- 贪婪寄存器分配器入口 ----
void RunGreedyRegAlloc(MachineFunction &function);
void RunGreedyRegAlloc(MachineModule &module);
void RunFrameLowering(MachineFunction &function);
void RunFrameLowering(MachineModule &module);
@ -442,16 +454,109 @@ namespace mir
void PrintAsm(const MachineFunction &function, std::ostream &os);
void PrintAsm(const MachineModule &module, std::ostream &os);
struct VNInfo
{
int id = -1;
int def_pos = -1;
Opcode def_opcode = Opcode::Ret;
bool IsRematable() const
{
return def_opcode == Opcode::MovImm ||
def_opcode == Opcode::LoadStackAddr;
}
};
struct UsePosition
{
int pos = -1;
bool is_def = false;
int vn_id = -1;
Opcode opcode = Opcode::Ret;
};
struct Segment
{
int start = -1;
int end = -1;
int vn_id = -1;
bool crosses_call = false;
bool Contains(int pos) const { return start <= pos && pos <= end; }
bool Overlaps(const Segment &o) const
{
return !(end < o.start || o.end < start);
}
};
struct LiveInterval
{
int vreg;
int start; // instruction position (global index)
int end; // instruction position
VRegClass vreg_class;
int vreg = -1;
RegClass reg_class = RegClass::Unknown;
std::vector<VNInfo> valnos;
std::vector<Segment> segments;
std::vector<UsePosition> uses;
int assigned_reg = -1;
float spill_weight = 0.0f;
int hint_reg = -1;
int generation = 0;
// 保留旧字段以兼容 ComputeInstLiveness
int start = -1;
int end = -1;
VRegClass vreg_class = VRegClass::Int;
bool spilled = false;
int spill_slot = -1;
bool IsSpilled() const { return assigned_reg == -2; }
bool IsSplit() const { return assigned_reg == -3; }
bool IsAllocated() const { return assigned_reg >= 0; }
int FirstUsePos() const
{
if (!uses.empty()) return uses.front().pos;
return start;
}
int LastUsePos() const
{
if (!segments.empty()) return segments.back().end;
return end;
}
bool SegmentCrossesCall() const
{
for (auto &seg : segments)
if (seg.crosses_call) return true;
return false;
}
float Length() const
{
int total = 0;
for (auto &seg : segments)
total += seg.end - seg.start + 1;
return total > 0 ? (float)total : 1.0f;
}
};
class LiveRegMatrix
{
std::vector<std::vector<LiveInterval *>> reg_assignments_;
public:
void Init(int num_regs);
void Assign(LiveInterval *li, int phys_reg);
void Unassign(LiveInterval *li);
bool CheckInterference(const LiveInterval &li, int phys_reg) const;
LiveInterval *GetConflict(const LiveInterval &li, int phys_reg) const;
bool CheckInterferenceRange(int start, int end, int phys_reg) const;
};
// ---- 增强活跃分析 ----
std::vector<LiveInterval> EnhanceIntervals(
const std::vector<LiveInterval> &raw,
MachineFunction &function);
std::vector<LiveInterval> ComputeInstLiveness(MachineFunction &func);
} // namespace mir

@ -61,10 +61,7 @@ int main(int argc, char** argv) {
mir::VerifyMIR(*machine_module);
#endif
if (opts.regalloc == "linear")
mir::RunLinearScanRegAlloc(*machine_module);
else
mir::RunRegAlloc(*machine_module);
mir::RunGreedyRegAlloc(*machine_module);
#ifndef NDEBUG
mir::VerifyRegAlloc(*machine_module);

@ -0,0 +1,593 @@
#include "mir/GreedyAlloc.h"
#include "mir/MIR.h"
#include <algorithm>
#include <cmath>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace mir
{
namespace
{
// ---- 寄存器可分配集 ----
constexpr int GP_ALLOCATABLE[] = {8,9,10,11,12,15,16,17,19,20,21,22,23,24,25,26,27,28};
constexpr int GP_COUNT = 18;
constexpr int FP_ALLOCATABLE[] = {0,1,2,3,4,5,6,7,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31};
constexpr int FP_COUNT = 24;
constexpr int MAX_ROUNDS = 5;
bool IsCallerSavedGP(int phys_reg) { return phys_reg <= 17; }
const int* GetRegList(RegClass rc, int& count)
{
if (rc == RegClass::GPR32 || rc == RegClass::GPR64)
{ count = GP_COUNT; return GP_ALLOCATABLE; }
else
{ count = FP_COUNT; return FP_ALLOCATABLE; }
}
struct SpillWeightCmp
{
bool operator()(LiveInterval* a, LiveInterval* b) const
{ return a->spill_weight < b->spill_weight; }
};
// ---- def/use 提取(与 InstLiveness.cpp 保持一致)----
static bool HasVRegDef(Opcode opcode)
{
switch (opcode)
{
case Opcode::MovImm: case Opcode::LoadStack: case Opcode::LoadGlobal:
case Opcode::LoadGlobalAddr: case Opcode::LoadStackAddr: case Opcode::LoadMem:
case Opcode::AddRR: case Opcode::SubRR: case Opcode::AddImm:
case Opcode::SubImm: case Opcode::MulRR: case Opcode::DivRR:
case Opcode::ModRR: case Opcode::AndRR: case Opcode::OrRR:
case Opcode::XorRR: case Opcode::ShlRR: case Opcode::ShrRR:
case Opcode::AsrRR: case Opcode::Asr64RR: case Opcode::Uxtw:
case Opcode::Sxtw: case Opcode::CSet: case Opcode::Csel:
case Opcode::Smull: case Opcode::Msub: case Opcode::NegRR:
case Opcode::FAddRR: case Opcode::FSubRR: case Opcode::FMulRR:
case Opcode::FDivRR: case Opcode::Scvtf: case Opcode::FCvtzs:
case Opcode::FMovWS: case Opcode::MovReg:
return true;
default: return false;
}
}
static void ExtractDefUse(const MachineInstr &inst, int &def_vreg,
std::vector<int> &use_vregs)
{
def_vreg = -1;
use_vregs.clear();
const auto &ops = inst.GetOperands();
const auto opcode = inst.GetOpcode();
if (HasVRegDef(opcode) && !ops.empty() &&
ops[0].GetKind() == Operand::Kind::VReg)
def_vreg = ops[0].GetVRegId();
for (size_t i = 0; i < ops.size(); ++i)
{
if (HasVRegDef(opcode) && i == 0) continue;
if (ops[i].GetKind() == Operand::Kind::VReg)
use_vregs.push_back(ops[i].GetVRegId());
}
}
// ---- 循环深度分析 ----
std::vector<int> AnalyzeLoopDepth(MachineFunction &func)
{
auto &blocks = func.GetBlocks();
int n = (int)blocks.size();
std::vector<int> depth(n, 0);
std::unordered_map<int, int> label_to_idx;
for (int i = 0; i < n; ++i)
if (blocks[i]) label_to_idx[blocks[i]->GetLabelId()] = i;
struct Edge { int src; int dst; };
std::vector<Edge> back_edges;
for (int i = 0; i < n; ++i)
{
if (!blocks[i]) continue;
for (auto &inst : blocks[i]->GetInstructions())
{
auto opcode = inst.GetOpcode();
int target_label = -1;
if (opcode == Opcode::Br && !inst.GetOperands().empty() &&
inst.GetOperands()[0].GetKind() == Operand::Kind::Label)
target_label = inst.GetOperands()[0].GetLabel();
else if (opcode == Opcode::CondBr && inst.GetOperands().size() >= 2 &&
inst.GetOperands()[1].GetKind() == Operand::Kind::Label)
target_label = inst.GetOperands()[1].GetLabel();
if (target_label < 0) continue;
auto it = label_to_idx.find(target_label);
if (it != label_to_idx.end() && (int)it->second <= i)
back_edges.push_back({i, (int)it->second});
}
}
for (auto &be : back_edges)
{
int header = be.dst;
std::unordered_set<int> body;
std::queue<int> q;
q.push(be.src);
while (!q.empty())
{
int cur = q.front(); q.pop();
if (cur == header || body.count(cur)) continue;
body.insert(cur);
if (cur > 0 && !body.count(cur - 1)) q.push(cur - 1);
for (int p = 0; p < n; ++p)
{
if (!blocks[p]) continue;
for (auto &inst : blocks[p]->GetInstructions())
{
int tgt = -1;
if (inst.GetOpcode() == Opcode::Br && !inst.GetOperands().empty() &&
inst.GetOperands()[0].GetKind() == Operand::Kind::Label)
tgt = inst.GetOperands()[0].GetLabel();
else if (inst.GetOpcode() == Opcode::CondBr &&
inst.GetOperands().size() >= 2 &&
inst.GetOperands()[1].GetKind() == Operand::Kind::Label)
tgt = inst.GetOperands()[1].GetLabel();
auto it2 = label_to_idx.find(tgt);
if (it2 != label_to_idx.end() && (int)it2->second == cur && !body.count(p))
q.push(p);
}
}
}
body.insert(header);
int max_existing = 0;
for (int b : body) max_existing = std::max(max_existing, depth[b]);
for (int b : body) depth[b] = std::max(depth[b], max_existing + 1);
}
return depth;
}
// ---- Spill Weight ----
void ComputeSpillWeights(std::vector<LiveInterval> &intervals,
const std::vector<int> &block_depth,
const std::vector<int> &pos_to_block)
{
for (auto &li : intervals)
{
float w = 0.0f;
for (auto &use : li.uses)
{
int block = (use.pos >= 0 && use.pos < (int)pos_to_block.size())
? pos_to_block[use.pos] : 0;
int d = (block >= 0 && block < (int)block_depth.size())
? block_depth[block] : 0;
float mult = std::pow(10.0f, (float)d);
if (use.is_def) mult *= 0.5f;
w += mult;
}
li.spill_weight = w / li.Length();
}
}
// ---- Copy Hints ----
void PropagateCopyHints(std::vector<LiveInterval> &intervals,
MachineFunction &func)
{
for (auto &block : func.GetBlocks())
{
if (!block) continue;
for (auto &inst : block->GetInstructions())
{
if (inst.GetOpcode() != Opcode::MovReg) continue;
auto &ops = inst.GetOperands();
if (ops.size() < 2) continue;
if (ops[0].GetKind() != Operand::Kind::VReg) continue;
if (ops[1].GetKind() != Operand::Kind::VReg) continue;
int dst = ops[0].GetVRegId();
int src = ops[1].GetVRegId();
if (dst < 0 || dst >= (int)intervals.size()) continue;
if (src < 0 || src >= (int)intervals.size()) continue;
if (intervals[src].IsAllocated())
intervals[dst].hint_reg = intervals[src].assigned_reg;
else if (intervals[dst].IsAllocated())
intervals[src].hint_reg = intervals[dst].assigned_reg;
else if (intervals[src].hint_reg >= 0)
intervals[dst].hint_reg = intervals[src].hint_reg;
}
}
}
// ---- TryAssign / TryAnyFreeReg ----
bool TryAssign(LiveInterval &li, LiveRegMatrix &m, int hint)
{
if (hint < 0) return false;
if (IsCallerSavedGP(hint) && li.SegmentCrossesCall()) return false;
if (!m.CheckInterference(li, hint))
{
m.Assign(&li, hint);
li.assigned_reg = hint;
return true;
}
return false;
}
bool TryAnyFreeReg(LiveInterval &li, LiveRegMatrix &m)
{
int n = 0;
const int *regs = GetRegList(li.reg_class, n);
for (int i = 0; i < n; ++i)
{
int r = regs[i];
if (IsCallerSavedGP(r) && li.SegmentCrossesCall()) continue;
if (!m.CheckInterference(li, r))
{
m.Assign(&li, r);
li.assigned_reg = r;
return true;
}
}
return false;
}
// ---- TryEvict ----
bool TryEvict(LiveInterval &li, LiveRegMatrix &m,
std::vector<LiveInterval *> &heap,
const SpillWeightCmp &cmp)
{
int best_reg = -1;
float min_weight = 1e9f;
LiveInterval *victim = nullptr;
int n = 0;
const int *regs = GetRegList(li.reg_class, n);
for (int i = 0; i < n; ++i)
{
int r = regs[i];
if (IsCallerSavedGP(r) && li.SegmentCrossesCall()) continue;
auto *conflict = m.GetConflict(li, r);
if (!conflict)
{
m.Assign(&li, r);
li.assigned_reg = r;
return true;
}
if (conflict->spill_weight < min_weight)
{
min_weight = conflict->spill_weight;
best_reg = r;
victim = conflict;
}
}
if (best_reg < 0 || !victim) return false;
m.Unassign(victim);
victim->assigned_reg = -1;
victim->generation++;
heap.push_back(victim);
std::push_heap(heap.begin(), heap.end(), cmp);
m.Assign(&li, best_reg);
li.assigned_reg = best_reg;
return true;
}
// ---- CreateChild ----
bool CreateChild(const LiveInterval &parent, int start_pos, int end_pos,
LiveInterval &child)
{
child = LiveInterval();
child.reg_class = parent.reg_class;
child.generation = parent.generation + 1;
child.hint_reg = -1;
child.assigned_reg = -1;
child.valnos = parent.valnos;
for (auto &seg : parent.segments)
{
if (seg.end < start_pos || seg.start > end_pos) continue;
Segment clipped = seg;
clipped.start = std::max(seg.start, start_pos);
clipped.end = std::min(seg.end, end_pos);
child.segments.push_back(clipped);
}
for (auto &use : parent.uses)
if (start_pos <= use.pos && use.pos <= end_pos)
child.uses.push_back(use);
return !child.uses.empty();
}
// ---- FindBestSplitPos ----
int FindBestSplitPos(const LiveInterval &li, LiveRegMatrix &m)
{
for (int i = (int)li.uses.size() - 2; i >= 0; --i)
{
int end_pos = li.uses[i].pos;
int hot_start = li.FirstUsePos();
int n = 0;
const int *regs = GetRegList(li.reg_class, n);
for (int r_idx = 0; r_idx < n; ++r_idx)
{
int r = regs[r_idx];
if (IsCallerSavedGP(r) && li.SegmentCrossesCall()) continue;
if (!m.CheckInterferenceRange(hot_start, end_pos, r))
return end_pos;
}
}
return -1;
}
// ---- TrySplit ----
bool TrySplit(LiveInterval &li, LiveRegMatrix &m,
std::vector<LiveInterval *> &heap,
std::vector<LiveInterval> &intervals,
const std::vector<int> &block_depth,
const std::vector<int> &pos_to_block,
std::vector<LiveInterval *> &spilled,
MachineFunction &func,
const SpillWeightCmp &cmp)
{
int split_pos = FindBestSplitPos(li, m);
if (split_pos < 0) return false;
LiveInterval hot;
if (!CreateChild(li, li.FirstUsePos(), split_pos, hot))
return false;
hot.vreg = li.vreg;
LiveInterval cold;
CreateChild(li, split_pos + 1, li.LastUsePos(), cold);
cold.vreg = func.CreateVReg(li.vreg_class);
cold.generation = li.generation + 1;
float w = 0.0f;
for (auto &use : cold.uses)
{
int blk = (use.pos >= 0 && use.pos < (int)pos_to_block.size())
? pos_to_block[use.pos] : 0;
int d = (blk >= 0 && blk < (int)block_depth.size())
? block_depth[blk] : 0;
float mult = std::pow(10.0f, (float)d);
if (use.is_def) mult *= 0.5f;
w += mult;
}
cold.spill_weight = w / cold.Length();
intervals.push_back(std::move(cold));
LiveInterval &cold_ref = intervals.back();
if (TryAnyFreeReg(hot, m))
{
li.assigned_reg = hot.assigned_reg;
li.segments = std::move(hot.segments);
li.uses = std::move(hot.uses);
}
else
{
li.assigned_reg = -2;
spilled.push_back(&li);
}
if (!TryAnyFreeReg(cold_ref, m))
{
heap.push_back(&cold_ref);
std::push_heap(heap.begin(), heap.end(), cmp);
}
return true;
}
} // anonymous namespace
// ---- LiveRegMatrix 方法namespace mir 内,不在匿名命名空间中)----
void LiveRegMatrix::Init(int num_regs)
{ reg_assignments_.assign(num_regs, {}); }
void LiveRegMatrix::Assign(LiveInterval *li, int phys_reg)
{
if (phys_reg < 0 || phys_reg >= (int)reg_assignments_.size()) return;
reg_assignments_[phys_reg].push_back(li);
}
void LiveRegMatrix::Unassign(LiveInterval *li)
{
for (auto &vec : reg_assignments_)
{
auto it = std::find(vec.begin(), vec.end(), li);
if (it != vec.end()) { vec.erase(it); return; }
}
}
bool LiveRegMatrix::CheckInterference(const LiveInterval &li, int phys_reg) const
{
if (phys_reg < 0 || phys_reg >= (int)reg_assignments_.size()) return true;
for (auto *other : reg_assignments_[phys_reg])
{
if (other->vreg == li.vreg) continue;
for (auto &sa : li.segments)
for (auto &sb : other->segments)
if (sa.Overlaps(sb)) return true;
}
return false;
}
LiveInterval *LiveRegMatrix::GetConflict(const LiveInterval &li,
int phys_reg) const
{
if (phys_reg < 0 || phys_reg >= (int)reg_assignments_.size()) return nullptr;
for (auto *other : reg_assignments_[phys_reg])
{
if (other->vreg == li.vreg) continue;
for (auto &sa : li.segments)
for (auto &sb : other->segments)
if (sa.Overlaps(sb)) return other;
}
return nullptr;
}
bool LiveRegMatrix::CheckInterferenceRange(int start, int end,
int phys_reg) const
{
if (phys_reg < 0 || phys_reg >= (int)reg_assignments_.size()) return true;
Segment range; range.start = start; range.end = end;
for (auto *other : reg_assignments_[phys_reg])
for (auto &sb : other->segments)
if (range.Overlaps(sb)) return true;
return false;
}
// ---- 对外入口 ----
void RunGreedyRegAlloc(MachineFunction &function);
void RunGreedyRegAlloc(MachineModule &module);
static void AllocateRegistersForFunction(MachineFunction &function)
{
if (function.GetNumVRegs() == 0) return;
// ---- 阶段 0活跃分析 + 预处理 ----
auto raw = ComputeInstLiveness(function);
auto intervals = EnhanceIntervals(raw, function);
auto &blocks = function.GetBlocks();
std::vector<int> pos_to_block;
std::vector<int> block_start_pos(blocks.size(), -1);
int global = 0;
for (int bi = 0; bi < (int)blocks.size(); ++bi)
{
if (!blocks[bi]) continue;
block_start_pos[bi] = global;
int cnt = (int)blocks[bi]->GetInstructions().size();
for (int j = 0; j < cnt; ++j) pos_to_block.push_back(bi);
global += cnt;
}
auto block_depth = AnalyzeLoopDepth(function);
ComputeSpillWeights(intervals, block_depth, pos_to_block);
PropagateCopyHints(intervals, function);
intervals.reserve(function.GetNumVRegs() * 4);
SpillWeightCmp cmp;
std::vector<LiveInterval *> spilled;
// ---- 阶段 1分配循环 ----
for (int round = 0; round < MAX_ROUNDS; ++round)
{
spilled.clear();
for (auto rc : {RegClass::GPR32, RegClass::FPR32})
{
// 构建堆:所有有效且未 split 的 vreg
std::vector<LiveInterval *> heap;
for (auto &li : intervals)
{
if (li.vreg < 0) continue;
if (li.reg_class == rc && !li.IsSplit())
heap.push_back(&li);
}
// 新轮次:重置所有 vreg 的分配状态
for (auto *p : heap) p->assigned_reg = -1;
std::make_heap(heap.begin(), heap.end(), cmp);
LiveRegMatrix matrix;
matrix.Init(32);
while (!heap.empty())
{
std::pop_heap(heap.begin(), heap.end(), cmp);
LiveInterval *li = heap.back();
heap.pop_back();
if (li->IsAllocated() || li->IsSplit()) continue;
// 尝试分配(按优先级)
if (TryAssign(*li, matrix, li->hint_reg)) continue;
if (TryAnyFreeReg(*li, matrix)) continue;
if (rc == RegClass::GPR32 && TryEvict(*li, matrix, heap, cmp)) continue;
if (TrySplit(*li, matrix, heap, intervals,
block_depth, pos_to_block, spilled, function, cmp)) continue;
li->assigned_reg = -2;
spilled.push_back(li);
}
}
if (spilled.empty()) break;
// ---- 溢出重写 ----
for (auto *li : spilled)
{
if (li->spill_slot < 0) li->spill_slot = li->vreg;
// 反向遍历 uses
for (int u = (int)li->uses.size() - 1; u >= 0; --u)
{
auto &use = li->uses[u];
int blk = pos_to_block[use.pos];
int local = use.pos - block_start_pos[blk];
if (use.is_def)
{
// 定义点后插入 StoreStack
blocks[blk]->InsertInst(local + 1,
MachineInstr(Opcode::StoreStack,
{Operand::VReg(li->vreg, li->vreg_class),
Operand::FrameIndex(li->spill_slot)}));
}
else
{
// 使用点前插入 LoadStack
int new_vreg = function.CreateVReg(li->vreg_class);
blocks[blk]->InsertInst(local,
MachineInstr(Opcode::LoadStack,
{Operand::VReg(new_vreg, li->vreg_class),
Operand::FrameIndex(li->spill_slot)}));
blocks[blk]->ReplaceVReg(local + 1, li->vreg, new_vreg);
}
}
}
// ---- 重新分析(每轮全新分配,不保留 prev_assigned----
raw = ComputeInstLiveness(function);
intervals = EnhanceIntervals(raw, function);
if (function.GetNumVRegs() > (int)intervals.size())
intervals.resize(function.GetNumVRegs());
// 重建位置映射(指令数已变)
pos_to_block.clear();
block_start_pos.assign(blocks.size(), -1);
global = 0;
for (int bi = 0; bi < (int)blocks.size(); ++bi)
{
if (!blocks[bi]) continue;
block_start_pos[bi] = global;
int cnt = (int)blocks[bi]->GetInstructions().size();
for (int j = 0; j < cnt; ++j) pos_to_block.push_back(bi);
global += cnt;
}
ComputeSpillWeights(intervals, block_depth, pos_to_block);
PropagateCopyHints(intervals, function);
}
// ---- 最终vreg → PhysReg 重写 ----
for (auto &block : blocks)
{
if (!block) continue;
for (auto &inst : block->GetInstructions())
{
for (auto &op : inst.GetOperands())
{
if (op.GetKind() != Operand::Kind::VReg) continue;
int vreg = op.GetVRegId();
int phys = -1;
if (vreg >= 0 && vreg < (int)intervals.size())
phys = intervals[vreg].assigned_reg;
if (phys < 0) phys = 48; // 兜底 X16应对未分配 vreg
op = Operand::Reg(static_cast<PhysReg>(phys));
}
}
}
}
void RunGreedyRegAlloc(MachineFunction &function)
{ AllocateRegistersForFunction(function); }
void RunGreedyRegAlloc(MachineModule &module)
{
for (auto &func : module.GetFunctions())
if (func) RunGreedyRegAlloc(*func);
}
} // namespace mir

@ -429,10 +429,119 @@ namespace mir
li.start = start;
li.end = end;
li.vreg_class = func.GetVRegClass(v);
li.reg_class = ToRegClass(li.vreg_class);
li.assigned_reg = -1;
li.hint_reg = -1;
li.generation = 0;
intervals.push_back(li);
}
return intervals;
}
namespace
{
// 全局指令位置 → 块索引 + 局部指令索引
struct GlobalPosInfo { int block_idx; int local_idx; };
} // anonymous namespace
std::vector<LiveInterval> EnhanceIntervals(
const std::vector<LiveInterval> &raw,
MachineFunction &function)
{
std::vector<LiveInterval> result = raw;
auto &blocks = function.GetBlocks();
// ---- 构建 pos → block 映射 + block_start_pos ----
std::vector<int> pos_to_block;
std::vector<int> block_start_pos(blocks.size(), -1);
int global = 0;
for (int bi = 0; bi < (int)blocks.size(); ++bi)
{
if (!blocks[bi]) continue;
block_start_pos[bi] = global;
int cnt = (int)blocks[bi]->GetInstructions().size();
for (int j = 0; j < cnt; ++j)
pos_to_block.push_back(bi);
global += cnt;
}
// ---- Pass A收集 VNInfo + UsePosition正向扫描----
for (int bi = 0; bi < (int)blocks.size(); ++bi)
{
if (!blocks[bi]) continue;
auto &insts = blocks[bi]->GetInstructions();
int base = block_start_pos[bi];
for (int j = 0; j < (int)insts.size(); ++j)
{
int pos = base + j;
const auto &inst = insts[j];
int def_vreg;
std::vector<int> use_vregs;
ExtractDefUse(inst, def_vreg, use_vregs);
if (def_vreg >= 0 && def_vreg < (int)result.size())
{
auto &li = result[def_vreg];
VNInfo vn;
vn.id = (int)li.valnos.size();
vn.def_pos = pos;
vn.def_opcode = inst.GetOpcode();
li.valnos.push_back(vn);
li.uses.push_back({pos, true, vn.id, inst.GetOpcode()});
}
for (int u : use_vregs)
{
if (u < 0 || u >= (int)result.size()) continue;
auto &li = result[u];
int vn_id = li.valnos.empty() ? 0 : (int)li.valnos.size() - 1;
li.uses.push_back({pos, false, vn_id, inst.GetOpcode()});
}
}
}
// ---- Pass B构建初始 segments单段 [first_use, last_use]----
for (auto &li : result)
{
if (li.uses.empty()) continue;
Segment seg;
seg.start = li.FirstUsePos();
seg.end = li.LastUsePos();
seg.vn_id = 0;
seg.crosses_call = false;
li.segments.push_back(seg);
}
// ---- Pass C标记 crosses_call ----
for (int bi = 0; bi < (int)blocks.size(); ++bi)
{
if (!blocks[bi]) continue;
auto &insts = blocks[bi]->GetInstructions();
int base = block_start_pos[bi];
for (int j = 0; j < (int)insts.size(); ++j)
{
if (insts[j].GetOpcode() != Opcode::Call) continue;
int call_pos = base + j;
for (auto &li : result)
{
for (auto &seg : li.segments)
{
if (seg.Contains(call_pos))
{
seg.crosses_call = true;
break;
}
}
}
}
}
return result;
}
} // namespace mir

@ -711,13 +711,16 @@ namespace mir
namespace mir
{
#if 0
void RunLinearScanRegAlloc(MachineFunction &func)
{
if (func.GetNumVRegs() == 0)
return;
RunLinearScan(func);
}
#endif
#if 0
void RunLinearScanRegAlloc(MachineModule &module)
{
for (auto &function : module.GetFunctions())
@ -726,5 +729,6 @@ namespace mir
RunLinearScanRegAlloc(*function);
}
}
#endif
} // namespace mir

@ -15,4 +15,22 @@ namespace mir
return instructions_.back();
}
void MachineBasicBlock::InsertInst(int local_idx, MachineInstr inst)
{
if (local_idx < 0 || local_idx > (int)instructions_.size()) return;
instructions_.insert(instructions_.begin() + local_idx, std::move(inst));
}
void MachineBasicBlock::ReplaceVReg(int local_idx, int old_vreg,
int new_vreg)
{
if (local_idx < 0 || local_idx >= (int)instructions_.size()) return;
for (auto &op : instructions_[local_idx].GetOperands())
{
if (op.GetKind() == Operand::Kind::VReg &&
op.GetVRegId() == old_vreg)
op = Operand::VReg(new_vreg, op.GetVRegClass());
}
}
} // namespace mir

@ -1,3 +1,4 @@
#if 0
#include "mir/MIR.h"
#include <algorithm>
@ -554,28 +555,6 @@ namespace mir
}
}
// 保守修复:对包含大量 vreg 定义的 block强制 def 间全干涉
// 防止 spill reload vreg 的短活区间导致错误寄存器复用。
// 阈值从 20 提高到 200——低阈值与 MAX_SPILL_ROUNDS=1 组合在
// 09_BFS 等函数上导致过度干涉,图着色产生错误分配。
for (size_t bi = 0; bi < blocks.size(); ++bi)
{
std::set<int> block_defs;
for (const auto &inst : blocks[bi]->GetInstructions())
{
auto du = GetInstDefUse(inst, function);
for (int d : du.defs)
if (d >= 0 && IsGPClass(function.GetVRegClass(d)))
block_defs.insert(d);
}
if (block_defs.size() > 200)
{
for (int u : block_defs)
for (int v : block_defs)
if (u != v)
graph.AddEdge(u, v);
}
}
}
static void BuildInterferenceForFP(
@ -1424,12 +1403,21 @@ namespace mir
if (function.GetNumVRegs() == 0)
return;
// 大 vreg 数函数回退到线性扫描:块级活跃分析在密集 def-use
// 模式下导致 spill 重写后干涉边缺失。线性扫描的指令级精确
// 区间能正确处理(如 30_many_dimensions
if (function.GetNumVRegs() > 100)
{
RunLinearScanRegAlloc(function);
return;
}
// 限制 spill 轮次为 1block-level liveness 下多轮 spill 创建
// 的 reload vreg 与保守修复block_defs 全干涉)交互,产生错误
// 寄存器分配。1 轮足够——循环外 RewriteWithAllocation 用 scratch
// 寄存器处理所有剩余 spill。修复04_arr_defn3/05_arr_defn4 段错误
// 及 09_BFS bad_alloc。
const int MAX_SPILL_ROUNDS = 1;
const int MAX_SPILL_ROUNDS = (function.GetNumVRegs() > 120) ? 3 : 10;
for (int round = 0; round < MAX_SPILL_ROUNDS; ++round)
{
// 构建 VReg → 定义指令映射(用于再物化判断)
@ -1942,3 +1930,4 @@ namespace mir
}
} // namespace mir
#endif

Loading…
Cancel
Save