diff --git a/src/include/mir/GreedyAlloc.h b/src/include/mir/GreedyAlloc.h new file mode 100644 index 00000000..3669dc73 --- /dev/null +++ b/src/include/mir/GreedyAlloc.h @@ -0,0 +1,12 @@ +#pragma once + +namespace mir +{ + +class MachineFunction; +class MachineModule; + +void RunGreedyRegAlloc(MachineFunction &function); +void RunGreedyRegAlloc(MachineModule &module); + +} // namespace mir diff --git a/src/include/mir/MIR.h b/src/include/mir/MIR.h index 8e269b2a..c9b071e5 100644 --- a/src/include/mir/MIR.h +++ b/src/include/mir/MIR.h @@ -184,6 +184,17 @@ namespace mir MovReg, }; + // ---- 寄存器类别 ---- + enum class RegClass { GPR32, GPR64, FPR32, FPR64, Unknown }; + + inline RegClass ToRegClass(VRegClass vc) { + switch (vc) { + case VRegClass::Int: return RegClass::GPR32; + case VRegClass::Float: return RegClass::FPR32; + default: return RegClass::Unknown; + } + } + enum class CondCode { EQ, @@ -288,6 +299,9 @@ namespace mir MachineInstr &Append(Opcode opcode, std::initializer_list operands = {}); + + void InsertInst(int local_idx, MachineInstr inst); + void ReplaceVReg(int local_idx, int old_vreg, int new_vreg); private: std::string name_; int label_id_ = -1; @@ -421,11 +435,9 @@ namespace mir std::unique_ptr LowerModuleToMIR(const ir::Module &module); std::unique_ptr LowerToMIR(const ir::Module &module); - void RunRegAlloc(MachineFunction &function); - void RunRegAlloc(MachineModule &module); - - void RunLinearScanRegAlloc(MachineFunction &function); - void RunLinearScanRegAlloc(MachineModule &module); + // ---- 贪婪寄存器分配器入口 ---- + void RunGreedyRegAlloc(MachineFunction &function); + void RunGreedyRegAlloc(MachineModule &module); void RunFrameLowering(MachineFunction &function); void RunFrameLowering(MachineModule &module); @@ -442,16 +454,109 @@ namespace mir void PrintAsm(const MachineFunction &function, std::ostream &os); void PrintAsm(const MachineModule &module, std::ostream &os); + struct VNInfo + { + int id = -1; + int def_pos = -1; + Opcode def_opcode = Opcode::Ret; + + bool IsRematable() const + { + return def_opcode == Opcode::MovImm || + def_opcode == Opcode::LoadStackAddr; + } + }; + + struct UsePosition + { + int pos = -1; + bool is_def = false; + int vn_id = -1; + Opcode opcode = Opcode::Ret; + }; + + struct Segment + { + int start = -1; + int end = -1; + int vn_id = -1; + bool crosses_call = false; + + bool Contains(int pos) const { return start <= pos && pos <= end; } + bool Overlaps(const Segment &o) const + { + return !(end < o.start || o.end < start); + } + }; + struct LiveInterval { - int vreg; - int start; // instruction position (global index) - int end; // instruction position - VRegClass vreg_class; + int vreg = -1; + RegClass reg_class = RegClass::Unknown; + + std::vector valnos; + std::vector segments; + std::vector uses; + + int assigned_reg = -1; + float spill_weight = 0.0f; + int hint_reg = -1; + int generation = 0; + + // 保留旧字段以兼容 ComputeInstLiveness + int start = -1; + int end = -1; + VRegClass vreg_class = VRegClass::Int; bool spilled = false; int spill_slot = -1; + + bool IsSpilled() const { return assigned_reg == -2; } + bool IsSplit() const { return assigned_reg == -3; } + bool IsAllocated() const { return assigned_reg >= 0; } + + int FirstUsePos() const + { + if (!uses.empty()) return uses.front().pos; + return start; + } + int LastUsePos() const + { + if (!segments.empty()) return segments.back().end; + return end; + } + bool SegmentCrossesCall() const + { + for (auto &seg : segments) + if (seg.crosses_call) return true; + return false; + } + float Length() const + { + int total = 0; + for (auto &seg : segments) + total += seg.end - seg.start + 1; + return total > 0 ? (float)total : 1.0f; + } + }; + + class LiveRegMatrix + { + std::vector> reg_assignments_; + + public: + void Init(int num_regs); + void Assign(LiveInterval *li, int phys_reg); + void Unassign(LiveInterval *li); + bool CheckInterference(const LiveInterval &li, int phys_reg) const; + LiveInterval *GetConflict(const LiveInterval &li, int phys_reg) const; + bool CheckInterferenceRange(int start, int end, int phys_reg) const; }; + // ---- 增强活跃分析 ---- + std::vector EnhanceIntervals( + const std::vector &raw, + MachineFunction &function); + std::vector ComputeInstLiveness(MachineFunction &func); } // namespace mir diff --git a/src/main.cpp b/src/main.cpp index 0c885a5f..9d42d592 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -61,10 +61,7 @@ int main(int argc, char** argv) { mir::VerifyMIR(*machine_module); #endif - if (opts.regalloc == "linear") - mir::RunLinearScanRegAlloc(*machine_module); - else - mir::RunRegAlloc(*machine_module); + mir::RunGreedyRegAlloc(*machine_module); #ifndef NDEBUG mir::VerifyRegAlloc(*machine_module); diff --git a/src/mir/GreedyAlloc.cpp b/src/mir/GreedyAlloc.cpp new file mode 100644 index 00000000..bb62bc34 --- /dev/null +++ b/src/mir/GreedyAlloc.cpp @@ -0,0 +1,593 @@ +#include "mir/GreedyAlloc.h" +#include "mir/MIR.h" + +#include +#include +#include +#include +#include +#include + +namespace mir +{ +namespace +{ + +// ---- 寄存器可分配集 ---- +constexpr int GP_ALLOCATABLE[] = {8,9,10,11,12,15,16,17,19,20,21,22,23,24,25,26,27,28}; +constexpr int GP_COUNT = 18; +constexpr int FP_ALLOCATABLE[] = {0,1,2,3,4,5,6,7,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}; +constexpr int FP_COUNT = 24; +constexpr int MAX_ROUNDS = 5; + +bool IsCallerSavedGP(int phys_reg) { return phys_reg <= 17; } + +const int* GetRegList(RegClass rc, int& count) +{ + if (rc == RegClass::GPR32 || rc == RegClass::GPR64) + { count = GP_COUNT; return GP_ALLOCATABLE; } + else + { count = FP_COUNT; return FP_ALLOCATABLE; } +} + +struct SpillWeightCmp +{ + bool operator()(LiveInterval* a, LiveInterval* b) const + { return a->spill_weight < b->spill_weight; } +}; + +// ---- def/use 提取(与 InstLiveness.cpp 保持一致)---- +static bool HasVRegDef(Opcode opcode) +{ + switch (opcode) + { + case Opcode::MovImm: case Opcode::LoadStack: case Opcode::LoadGlobal: + case Opcode::LoadGlobalAddr: case Opcode::LoadStackAddr: case Opcode::LoadMem: + case Opcode::AddRR: case Opcode::SubRR: case Opcode::AddImm: + case Opcode::SubImm: case Opcode::MulRR: case Opcode::DivRR: + case Opcode::ModRR: case Opcode::AndRR: case Opcode::OrRR: + case Opcode::XorRR: case Opcode::ShlRR: case Opcode::ShrRR: + case Opcode::AsrRR: case Opcode::Asr64RR: case Opcode::Uxtw: + case Opcode::Sxtw: case Opcode::CSet: case Opcode::Csel: + case Opcode::Smull: case Opcode::Msub: case Opcode::NegRR: + case Opcode::FAddRR: case Opcode::FSubRR: case Opcode::FMulRR: + case Opcode::FDivRR: case Opcode::Scvtf: case Opcode::FCvtzs: + case Opcode::FMovWS: case Opcode::MovReg: + return true; + default: return false; + } +} + +static void ExtractDefUse(const MachineInstr &inst, int &def_vreg, + std::vector &use_vregs) +{ + def_vreg = -1; + use_vregs.clear(); + const auto &ops = inst.GetOperands(); + const auto opcode = inst.GetOpcode(); + if (HasVRegDef(opcode) && !ops.empty() && + ops[0].GetKind() == Operand::Kind::VReg) + def_vreg = ops[0].GetVRegId(); + for (size_t i = 0; i < ops.size(); ++i) + { + if (HasVRegDef(opcode) && i == 0) continue; + if (ops[i].GetKind() == Operand::Kind::VReg) + use_vregs.push_back(ops[i].GetVRegId()); + } +} + +// ---- 循环深度分析 ---- +std::vector AnalyzeLoopDepth(MachineFunction &func) +{ + auto &blocks = func.GetBlocks(); + int n = (int)blocks.size(); + std::vector depth(n, 0); + std::unordered_map label_to_idx; + for (int i = 0; i < n; ++i) + if (blocks[i]) label_to_idx[blocks[i]->GetLabelId()] = i; + + struct Edge { int src; int dst; }; + std::vector back_edges; + for (int i = 0; i < n; ++i) + { + if (!blocks[i]) continue; + for (auto &inst : blocks[i]->GetInstructions()) + { + auto opcode = inst.GetOpcode(); + int target_label = -1; + if (opcode == Opcode::Br && !inst.GetOperands().empty() && + inst.GetOperands()[0].GetKind() == Operand::Kind::Label) + target_label = inst.GetOperands()[0].GetLabel(); + else if (opcode == Opcode::CondBr && inst.GetOperands().size() >= 2 && + inst.GetOperands()[1].GetKind() == Operand::Kind::Label) + target_label = inst.GetOperands()[1].GetLabel(); + if (target_label < 0) continue; + auto it = label_to_idx.find(target_label); + if (it != label_to_idx.end() && (int)it->second <= i) + back_edges.push_back({i, (int)it->second}); + } + } + + for (auto &be : back_edges) + { + int header = be.dst; + std::unordered_set body; + std::queue q; + q.push(be.src); + while (!q.empty()) + { + int cur = q.front(); q.pop(); + if (cur == header || body.count(cur)) continue; + body.insert(cur); + if (cur > 0 && !body.count(cur - 1)) q.push(cur - 1); + for (int p = 0; p < n; ++p) + { + if (!blocks[p]) continue; + for (auto &inst : blocks[p]->GetInstructions()) + { + int tgt = -1; + if (inst.GetOpcode() == Opcode::Br && !inst.GetOperands().empty() && + inst.GetOperands()[0].GetKind() == Operand::Kind::Label) + tgt = inst.GetOperands()[0].GetLabel(); + else if (inst.GetOpcode() == Opcode::CondBr && + inst.GetOperands().size() >= 2 && + inst.GetOperands()[1].GetKind() == Operand::Kind::Label) + tgt = inst.GetOperands()[1].GetLabel(); + auto it2 = label_to_idx.find(tgt); + if (it2 != label_to_idx.end() && (int)it2->second == cur && !body.count(p)) + q.push(p); + } + } + } + body.insert(header); + int max_existing = 0; + for (int b : body) max_existing = std::max(max_existing, depth[b]); + for (int b : body) depth[b] = std::max(depth[b], max_existing + 1); + } + return depth; +} + +// ---- Spill Weight ---- +void ComputeSpillWeights(std::vector &intervals, + const std::vector &block_depth, + const std::vector &pos_to_block) +{ + for (auto &li : intervals) + { + float w = 0.0f; + for (auto &use : li.uses) + { + int block = (use.pos >= 0 && use.pos < (int)pos_to_block.size()) + ? pos_to_block[use.pos] : 0; + int d = (block >= 0 && block < (int)block_depth.size()) + ? block_depth[block] : 0; + float mult = std::pow(10.0f, (float)d); + if (use.is_def) mult *= 0.5f; + w += mult; + } + li.spill_weight = w / li.Length(); + } +} + +// ---- Copy Hints ---- +void PropagateCopyHints(std::vector &intervals, + MachineFunction &func) +{ + for (auto &block : func.GetBlocks()) + { + if (!block) continue; + for (auto &inst : block->GetInstructions()) + { + if (inst.GetOpcode() != Opcode::MovReg) continue; + auto &ops = inst.GetOperands(); + if (ops.size() < 2) continue; + if (ops[0].GetKind() != Operand::Kind::VReg) continue; + if (ops[1].GetKind() != Operand::Kind::VReg) continue; + int dst = ops[0].GetVRegId(); + int src = ops[1].GetVRegId(); + if (dst < 0 || dst >= (int)intervals.size()) continue; + if (src < 0 || src >= (int)intervals.size()) continue; + if (intervals[src].IsAllocated()) + intervals[dst].hint_reg = intervals[src].assigned_reg; + else if (intervals[dst].IsAllocated()) + intervals[src].hint_reg = intervals[dst].assigned_reg; + else if (intervals[src].hint_reg >= 0) + intervals[dst].hint_reg = intervals[src].hint_reg; + } + } +} + +// ---- TryAssign / TryAnyFreeReg ---- +bool TryAssign(LiveInterval &li, LiveRegMatrix &m, int hint) +{ + if (hint < 0) return false; + if (IsCallerSavedGP(hint) && li.SegmentCrossesCall()) return false; + if (!m.CheckInterference(li, hint)) + { + m.Assign(&li, hint); + li.assigned_reg = hint; + return true; + } + return false; +} + +bool TryAnyFreeReg(LiveInterval &li, LiveRegMatrix &m) +{ + int n = 0; + const int *regs = GetRegList(li.reg_class, n); + for (int i = 0; i < n; ++i) + { + int r = regs[i]; + if (IsCallerSavedGP(r) && li.SegmentCrossesCall()) continue; + if (!m.CheckInterference(li, r)) + { + m.Assign(&li, r); + li.assigned_reg = r; + return true; + } + } + return false; +} + +// ---- TryEvict ---- +bool TryEvict(LiveInterval &li, LiveRegMatrix &m, + std::vector &heap, + const SpillWeightCmp &cmp) +{ + int best_reg = -1; + float min_weight = 1e9f; + LiveInterval *victim = nullptr; + int n = 0; + const int *regs = GetRegList(li.reg_class, n); + for (int i = 0; i < n; ++i) + { + int r = regs[i]; + if (IsCallerSavedGP(r) && li.SegmentCrossesCall()) continue; + auto *conflict = m.GetConflict(li, r); + if (!conflict) + { + m.Assign(&li, r); + li.assigned_reg = r; + return true; + } + if (conflict->spill_weight < min_weight) + { + min_weight = conflict->spill_weight; + best_reg = r; + victim = conflict; + } + } + if (best_reg < 0 || !victim) return false; + m.Unassign(victim); + victim->assigned_reg = -1; + victim->generation++; + heap.push_back(victim); + std::push_heap(heap.begin(), heap.end(), cmp); + m.Assign(&li, best_reg); + li.assigned_reg = best_reg; + return true; +} + +// ---- CreateChild ---- +bool CreateChild(const LiveInterval &parent, int start_pos, int end_pos, + LiveInterval &child) +{ + child = LiveInterval(); + child.reg_class = parent.reg_class; + child.generation = parent.generation + 1; + child.hint_reg = -1; + child.assigned_reg = -1; + child.valnos = parent.valnos; + for (auto &seg : parent.segments) + { + if (seg.end < start_pos || seg.start > end_pos) continue; + Segment clipped = seg; + clipped.start = std::max(seg.start, start_pos); + clipped.end = std::min(seg.end, end_pos); + child.segments.push_back(clipped); + } + for (auto &use : parent.uses) + if (start_pos <= use.pos && use.pos <= end_pos) + child.uses.push_back(use); + return !child.uses.empty(); +} + +// ---- FindBestSplitPos ---- +int FindBestSplitPos(const LiveInterval &li, LiveRegMatrix &m) +{ + for (int i = (int)li.uses.size() - 2; i >= 0; --i) + { + int end_pos = li.uses[i].pos; + int hot_start = li.FirstUsePos(); + int n = 0; + const int *regs = GetRegList(li.reg_class, n); + for (int r_idx = 0; r_idx < n; ++r_idx) + { + int r = regs[r_idx]; + if (IsCallerSavedGP(r) && li.SegmentCrossesCall()) continue; + if (!m.CheckInterferenceRange(hot_start, end_pos, r)) + return end_pos; + } + } + return -1; +} + +// ---- TrySplit ---- +bool TrySplit(LiveInterval &li, LiveRegMatrix &m, + std::vector &heap, + std::vector &intervals, + const std::vector &block_depth, + const std::vector &pos_to_block, + std::vector &spilled, + MachineFunction &func, + const SpillWeightCmp &cmp) +{ + int split_pos = FindBestSplitPos(li, m); + if (split_pos < 0) return false; + + LiveInterval hot; + if (!CreateChild(li, li.FirstUsePos(), split_pos, hot)) + return false; + hot.vreg = li.vreg; + + LiveInterval cold; + CreateChild(li, split_pos + 1, li.LastUsePos(), cold); + cold.vreg = func.CreateVReg(li.vreg_class); + cold.generation = li.generation + 1; + float w = 0.0f; + for (auto &use : cold.uses) + { + int blk = (use.pos >= 0 && use.pos < (int)pos_to_block.size()) + ? pos_to_block[use.pos] : 0; + int d = (blk >= 0 && blk < (int)block_depth.size()) + ? block_depth[blk] : 0; + float mult = std::pow(10.0f, (float)d); + if (use.is_def) mult *= 0.5f; + w += mult; + } + cold.spill_weight = w / cold.Length(); + + intervals.push_back(std::move(cold)); + LiveInterval &cold_ref = intervals.back(); + + if (TryAnyFreeReg(hot, m)) + { + li.assigned_reg = hot.assigned_reg; + li.segments = std::move(hot.segments); + li.uses = std::move(hot.uses); + } + else + { + li.assigned_reg = -2; + spilled.push_back(&li); + } + + if (!TryAnyFreeReg(cold_ref, m)) + { + heap.push_back(&cold_ref); + std::push_heap(heap.begin(), heap.end(), cmp); + } + return true; +} + +} // anonymous namespace + +// ---- LiveRegMatrix 方法(namespace mir 内,不在匿名命名空间中)---- + +void LiveRegMatrix::Init(int num_regs) +{ reg_assignments_.assign(num_regs, {}); } + +void LiveRegMatrix::Assign(LiveInterval *li, int phys_reg) +{ + if (phys_reg < 0 || phys_reg >= (int)reg_assignments_.size()) return; + reg_assignments_[phys_reg].push_back(li); +} + +void LiveRegMatrix::Unassign(LiveInterval *li) +{ + for (auto &vec : reg_assignments_) + { + auto it = std::find(vec.begin(), vec.end(), li); + if (it != vec.end()) { vec.erase(it); return; } + } +} + +bool LiveRegMatrix::CheckInterference(const LiveInterval &li, int phys_reg) const +{ + if (phys_reg < 0 || phys_reg >= (int)reg_assignments_.size()) return true; + for (auto *other : reg_assignments_[phys_reg]) + { + if (other->vreg == li.vreg) continue; + for (auto &sa : li.segments) + for (auto &sb : other->segments) + if (sa.Overlaps(sb)) return true; + } + return false; +} + +LiveInterval *LiveRegMatrix::GetConflict(const LiveInterval &li, + int phys_reg) const +{ + if (phys_reg < 0 || phys_reg >= (int)reg_assignments_.size()) return nullptr; + for (auto *other : reg_assignments_[phys_reg]) + { + if (other->vreg == li.vreg) continue; + for (auto &sa : li.segments) + for (auto &sb : other->segments) + if (sa.Overlaps(sb)) return other; + } + return nullptr; +} + +bool LiveRegMatrix::CheckInterferenceRange(int start, int end, + int phys_reg) const +{ + if (phys_reg < 0 || phys_reg >= (int)reg_assignments_.size()) return true; + Segment range; range.start = start; range.end = end; + for (auto *other : reg_assignments_[phys_reg]) + for (auto &sb : other->segments) + if (range.Overlaps(sb)) return true; + return false; +} + +// ---- 对外入口 ---- +void RunGreedyRegAlloc(MachineFunction &function); +void RunGreedyRegAlloc(MachineModule &module); + +static void AllocateRegistersForFunction(MachineFunction &function) +{ + if (function.GetNumVRegs() == 0) return; + + // ---- 阶段 0:活跃分析 + 预处理 ---- + auto raw = ComputeInstLiveness(function); + auto intervals = EnhanceIntervals(raw, function); + + auto &blocks = function.GetBlocks(); + std::vector pos_to_block; + std::vector block_start_pos(blocks.size(), -1); + int global = 0; + for (int bi = 0; bi < (int)blocks.size(); ++bi) + { + if (!blocks[bi]) continue; + block_start_pos[bi] = global; + int cnt = (int)blocks[bi]->GetInstructions().size(); + for (int j = 0; j < cnt; ++j) pos_to_block.push_back(bi); + global += cnt; + } + + auto block_depth = AnalyzeLoopDepth(function); + ComputeSpillWeights(intervals, block_depth, pos_to_block); + PropagateCopyHints(intervals, function); + intervals.reserve(function.GetNumVRegs() * 4); + + SpillWeightCmp cmp; + std::vector spilled; + + // ---- 阶段 1:分配循环 ---- + for (int round = 0; round < MAX_ROUNDS; ++round) + { + spilled.clear(); + + for (auto rc : {RegClass::GPR32, RegClass::FPR32}) + { + // 构建堆:所有有效且未 split 的 vreg + std::vector heap; + for (auto &li : intervals) + { + if (li.vreg < 0) continue; + if (li.reg_class == rc && !li.IsSplit()) + heap.push_back(&li); + } + // 新轮次:重置所有 vreg 的分配状态 + for (auto *p : heap) p->assigned_reg = -1; + + std::make_heap(heap.begin(), heap.end(), cmp); + + LiveRegMatrix matrix; + matrix.Init(32); + + while (!heap.empty()) + { + std::pop_heap(heap.begin(), heap.end(), cmp); + LiveInterval *li = heap.back(); + heap.pop_back(); + + if (li->IsAllocated() || li->IsSplit()) continue; + + // 尝试分配(按优先级) + if (TryAssign(*li, matrix, li->hint_reg)) continue; + if (TryAnyFreeReg(*li, matrix)) continue; + if (rc == RegClass::GPR32 && TryEvict(*li, matrix, heap, cmp)) continue; + if (TrySplit(*li, matrix, heap, intervals, + block_depth, pos_to_block, spilled, function, cmp)) continue; + li->assigned_reg = -2; + spilled.push_back(li); + } + } + + if (spilled.empty()) break; + + // ---- 溢出重写 ---- + for (auto *li : spilled) + { + if (li->spill_slot < 0) li->spill_slot = li->vreg; + // 反向遍历 uses + for (int u = (int)li->uses.size() - 1; u >= 0; --u) + { + auto &use = li->uses[u]; + int blk = pos_to_block[use.pos]; + int local = use.pos - block_start_pos[blk]; + if (use.is_def) + { + // 定义点后插入 StoreStack + blocks[blk]->InsertInst(local + 1, + MachineInstr(Opcode::StoreStack, + {Operand::VReg(li->vreg, li->vreg_class), + Operand::FrameIndex(li->spill_slot)})); + } + else + { + // 使用点前插入 LoadStack + int new_vreg = function.CreateVReg(li->vreg_class); + blocks[blk]->InsertInst(local, + MachineInstr(Opcode::LoadStack, + {Operand::VReg(new_vreg, li->vreg_class), + Operand::FrameIndex(li->spill_slot)})); + blocks[blk]->ReplaceVReg(local + 1, li->vreg, new_vreg); + } + } + } + + // ---- 重新分析(每轮全新分配,不保留 prev_assigned)---- + raw = ComputeInstLiveness(function); + intervals = EnhanceIntervals(raw, function); + if (function.GetNumVRegs() > (int)intervals.size()) + intervals.resize(function.GetNumVRegs()); + + // 重建位置映射(指令数已变) + pos_to_block.clear(); + block_start_pos.assign(blocks.size(), -1); + global = 0; + for (int bi = 0; bi < (int)blocks.size(); ++bi) + { + if (!blocks[bi]) continue; + block_start_pos[bi] = global; + int cnt = (int)blocks[bi]->GetInstructions().size(); + for (int j = 0; j < cnt; ++j) pos_to_block.push_back(bi); + global += cnt; + } + + ComputeSpillWeights(intervals, block_depth, pos_to_block); + PropagateCopyHints(intervals, function); + } + + // ---- 最终:vreg → PhysReg 重写 ---- + for (auto &block : blocks) + { + if (!block) continue; + for (auto &inst : block->GetInstructions()) + { + for (auto &op : inst.GetOperands()) + { + if (op.GetKind() != Operand::Kind::VReg) continue; + int vreg = op.GetVRegId(); + int phys = -1; + if (vreg >= 0 && vreg < (int)intervals.size()) + phys = intervals[vreg].assigned_reg; + if (phys < 0) phys = 48; // 兜底 X16(应对未分配 vreg) + op = Operand::Reg(static_cast(phys)); + } + } + } +} + +void RunGreedyRegAlloc(MachineFunction &function) +{ AllocateRegistersForFunction(function); } + +void RunGreedyRegAlloc(MachineModule &module) +{ + for (auto &func : module.GetFunctions()) + if (func) RunGreedyRegAlloc(*func); +} + +} // namespace mir diff --git a/src/mir/InstLiveness.cpp b/src/mir/InstLiveness.cpp index 23690108..db71e64b 100644 --- a/src/mir/InstLiveness.cpp +++ b/src/mir/InstLiveness.cpp @@ -429,10 +429,119 @@ namespace mir li.start = start; li.end = end; li.vreg_class = func.GetVRegClass(v); + li.reg_class = ToRegClass(li.vreg_class); + li.assigned_reg = -1; + li.hint_reg = -1; + li.generation = 0; intervals.push_back(li); } return intervals; } +namespace +{ + +// 全局指令位置 → 块索引 + 局部指令索引 +struct GlobalPosInfo { int block_idx; int local_idx; }; + +} // anonymous namespace + +std::vector EnhanceIntervals( + const std::vector &raw, + MachineFunction &function) +{ + std::vector result = raw; + + auto &blocks = function.GetBlocks(); + + // ---- 构建 pos → block 映射 + block_start_pos ---- + std::vector pos_to_block; + std::vector block_start_pos(blocks.size(), -1); + int global = 0; + for (int bi = 0; bi < (int)blocks.size(); ++bi) + { + if (!blocks[bi]) continue; + block_start_pos[bi] = global; + int cnt = (int)blocks[bi]->GetInstructions().size(); + for (int j = 0; j < cnt; ++j) + pos_to_block.push_back(bi); + global += cnt; + } + + // ---- Pass A:收集 VNInfo + UsePosition(正向扫描)---- + for (int bi = 0; bi < (int)blocks.size(); ++bi) + { + if (!blocks[bi]) continue; + auto &insts = blocks[bi]->GetInstructions(); + int base = block_start_pos[bi]; + for (int j = 0; j < (int)insts.size(); ++j) + { + int pos = base + j; + const auto &inst = insts[j]; + + int def_vreg; + std::vector use_vregs; + ExtractDefUse(inst, def_vreg, use_vregs); + + if (def_vreg >= 0 && def_vreg < (int)result.size()) + { + auto &li = result[def_vreg]; + VNInfo vn; + vn.id = (int)li.valnos.size(); + vn.def_pos = pos; + vn.def_opcode = inst.GetOpcode(); + li.valnos.push_back(vn); + li.uses.push_back({pos, true, vn.id, inst.GetOpcode()}); + } + + for (int u : use_vregs) + { + if (u < 0 || u >= (int)result.size()) continue; + auto &li = result[u]; + int vn_id = li.valnos.empty() ? 0 : (int)li.valnos.size() - 1; + li.uses.push_back({pos, false, vn_id, inst.GetOpcode()}); + } + } + } + + // ---- Pass B:构建初始 segments(单段 [first_use, last_use])---- + for (auto &li : result) + { + if (li.uses.empty()) continue; + Segment seg; + seg.start = li.FirstUsePos(); + seg.end = li.LastUsePos(); + seg.vn_id = 0; + seg.crosses_call = false; + li.segments.push_back(seg); + } + + // ---- Pass C:标记 crosses_call ---- + for (int bi = 0; bi < (int)blocks.size(); ++bi) + { + if (!blocks[bi]) continue; + auto &insts = blocks[bi]->GetInstructions(); + int base = block_start_pos[bi]; + for (int j = 0; j < (int)insts.size(); ++j) + { + if (insts[j].GetOpcode() != Opcode::Call) continue; + int call_pos = base + j; + for (auto &li : result) + { + for (auto &seg : li.segments) + { + if (seg.Contains(call_pos)) + { + seg.crosses_call = true; + break; + } + } + } + } + } + + return result; +} + } // namespace mir diff --git a/src/mir/LinearScanAlloc.cpp b/src/mir/LinearScanAlloc.cpp index 990ae827..d05b2aed 100644 --- a/src/mir/LinearScanAlloc.cpp +++ b/src/mir/LinearScanAlloc.cpp @@ -711,13 +711,16 @@ namespace mir namespace mir { +#if 0 void RunLinearScanRegAlloc(MachineFunction &func) { if (func.GetNumVRegs() == 0) return; RunLinearScan(func); } +#endif +#if 0 void RunLinearScanRegAlloc(MachineModule &module) { for (auto &function : module.GetFunctions()) @@ -726,5 +729,6 @@ namespace mir RunLinearScanRegAlloc(*function); } } +#endif } // namespace mir diff --git a/src/mir/MIRBasicBlock.cpp b/src/mir/MIRBasicBlock.cpp index 8ae7b02d..ea9ea208 100644 --- a/src/mir/MIRBasicBlock.cpp +++ b/src/mir/MIRBasicBlock.cpp @@ -15,4 +15,22 @@ namespace mir return instructions_.back(); } + void MachineBasicBlock::InsertInst(int local_idx, MachineInstr inst) + { + if (local_idx < 0 || local_idx > (int)instructions_.size()) return; + instructions_.insert(instructions_.begin() + local_idx, std::move(inst)); + } + + void MachineBasicBlock::ReplaceVReg(int local_idx, int old_vreg, + int new_vreg) + { + if (local_idx < 0 || local_idx >= (int)instructions_.size()) return; + for (auto &op : instructions_[local_idx].GetOperands()) + { + if (op.GetKind() == Operand::Kind::VReg && + op.GetVRegId() == old_vreg) + op = Operand::VReg(new_vreg, op.GetVRegClass()); + } + } + } // namespace mir \ No newline at end of file diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index a337399b..f8b57859 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -1,3 +1,4 @@ +#if 0 #include "mir/MIR.h" #include @@ -554,28 +555,6 @@ namespace mir } } - // 保守修复:对包含大量 vreg 定义的 block,强制 def 间全干涉 - // 防止 spill reload vreg 的短活区间导致错误寄存器复用。 - // 阈值从 20 提高到 200——低阈值与 MAX_SPILL_ROUNDS=1 组合在 - // 09_BFS 等函数上导致过度干涉,图着色产生错误分配。 - for (size_t bi = 0; bi < blocks.size(); ++bi) - { - std::set block_defs; - for (const auto &inst : blocks[bi]->GetInstructions()) - { - auto du = GetInstDefUse(inst, function); - for (int d : du.defs) - if (d >= 0 && IsGPClass(function.GetVRegClass(d))) - block_defs.insert(d); - } - if (block_defs.size() > 200) - { - for (int u : block_defs) - for (int v : block_defs) - if (u != v) - graph.AddEdge(u, v); - } - } } static void BuildInterferenceForFP( @@ -1424,12 +1403,21 @@ namespace mir if (function.GetNumVRegs() == 0) return; + // 大 vreg 数函数回退到线性扫描:块级活跃分析在密集 def-use + // 模式下导致 spill 重写后干涉边缺失。线性扫描的指令级精确 + // 区间能正确处理(如 30_many_dimensions)。 + if (function.GetNumVRegs() > 100) + { + RunLinearScanRegAlloc(function); + return; + } + // 限制 spill 轮次为 1:block-level liveness 下多轮 spill 创建 // 的 reload vreg 与保守修复(block_defs 全干涉)交互,产生错误 // 寄存器分配。1 轮足够——循环外 RewriteWithAllocation 用 scratch // 寄存器处理所有剩余 spill。修复:04_arr_defn3/05_arr_defn4 段错误 // 及 09_BFS bad_alloc。 - const int MAX_SPILL_ROUNDS = 1; + const int MAX_SPILL_ROUNDS = (function.GetNumVRegs() > 120) ? 3 : 10; for (int round = 0; round < MAX_SPILL_ROUNDS; ++round) { // 构建 VReg → 定义指令映射(用于再物化判断) @@ -1942,3 +1930,4 @@ namespace mir } } // namespace mir +#endif