You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/mir/RegAlloc.cpp

1308 lines
41 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "mir/MIR.h"
#include <algorithm>
#include <limits>
#include <queue>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "utils/Log.h"
namespace mir
{
namespace
{
// Maps a physical register to a flat number: W0-W30 and X0-X30 both map to
// 0-30 (same architectural register), S0-S31 map to 100-131, anything else -1.
static int PhysRegToNumber(PhysReg reg)
{
    // Offset of `reg` inside the closed range [first, last], or -1 if outside.
    auto offset_in = [reg](PhysReg first, PhysReg last) -> int {
        if (reg < first || reg > last)
            return -1;
        return static_cast<int>(reg) - static_cast<int>(first);
    };
    int index = offset_in(PhysReg::W0, PhysReg::W30);
    if (index >= 0)
        return index;
    index = offset_in(PhysReg::X0, PhysReg::X30);
    if (index >= 0)
        return index;
    index = offset_in(PhysReg::S0, PhysReg::S31);
    return index >= 0 ? index + 100 : -1;
}
// Inverse of the per-bank numbering: register `num` of the bank matching `vc`.
// Float -> S bank, Ptr -> X (64-bit) bank, every other class -> W (32-bit) bank.
static PhysReg NumberToPhysReg(int num, VRegClass vc)
{
    PhysReg bank_base = PhysReg::W0;
    switch (vc)
    {
    case VRegClass::Float:
        bank_base = PhysReg::S0;
        break;
    case VRegClass::Ptr:
        bank_base = PhysReg::X0;
        break;
    default:
        break;
    }
    return static_cast<PhysReg>(static_cast<int>(bank_base) + num);
}
// True for any general-purpose register view (W0-W30 or X0-X30).
static bool IsGPReg(PhysReg reg)
{
    if (reg >= PhysReg::W0 && reg <= PhysReg::W30)
        return true;
    return reg >= PhysReg::X0 && reg <= PhysReg::X30;
}
// True for the single-precision registers S0-S31.
static bool IsFPReg(PhysReg reg)
{
    return !(reg < PhysReg::S0 || reg > PhysReg::S31);
}
// Register-number tables used by the allocator (numbers index the W/X or S bank).
// Each *_NUM_* count is derived from its array so the two can never drift apart.

// GP registers the colorer may assign. x13, x14 and x16-x18 are deliberately
// absent — presumably reserved as scratch / platform registers (TODO confirm).
static const int GP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 15, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28};
static const int GP_NUM_ALLOCATABLE = static_cast<int>(sizeof(GP_ALLOCATABLE) / sizeof(GP_ALLOCATABLE[0]));

// FP registers the colorer may assign (s8-s31).
static const int FP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
static const int FP_NUM_ALLOCATABLE = static_cast<int>(sizeof(FP_ALLOCATABLE) / sizeof(FP_ALLOCATABLE[0]));

// FP registers treated as clobbered by a call.
// NOTE(review): AAPCS64 makes v8-v15 (low 64 bits) callee-saved and v0-v7 /
// v16-v31 caller-saved. This table uses the opposite split for s8-s31. Verify
// it for calls into code not compiled by this compiler (e.g. runtime/libc).
static const int FP_CALLER_SAVED[] = {8, 9, 10, 11, 12, 13, 14, 15};
static const int FP_NUM_CALLER_SAVED = static_cast<int>(sizeof(FP_CALLER_SAVED) / sizeof(FP_CALLER_SAVED[0]));

// GP registers treated as clobbered by a call (x0-x17).
static const int GP_CALLER_SAVED[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
static const int GP_NUM_CALLER_SAVED = static_cast<int>(sizeof(GP_CALLER_SAVED) / sizeof(GP_CALLER_SAVED[0]));
// Virtual registers written/read by one machine instruction, plus control-flow flags.
struct InstDefUse
{
    std::vector<int> defs;       // vreg ids written by the instruction (sorted, unique once filled)
    std::vector<int> uses;       // vreg ids read by the instruction (sorted, unique once filled)
    bool is_call{false};         // instruction is a Call
    bool is_terminator{false};   // instruction ends the block (CondBr / Ret)
};
// Classifies the virtual-register operands of `inst` into defs and uses.
//
// Only operands of kind VReg are reported. Opcodes without a dedicated case
// conservatively treat every VReg operand as a use. Both result vectors are
// sorted and de-duplicated. `is_call` is set for Call; `is_terminator` for
// CondBr and Ret.
static InstDefUse GetInstDefUse(const MachineInstr &inst, MachineFunction & /*function*/)
{
    InstDefUse result;
    const auto opcode = inst.GetOpcode();
    const auto &ops = inst.GetOperands();
    const size_t num_ops = ops.size();

    auto is_vreg = [&ops](size_t i) { return ops[i].GetKind() == Operand::Kind::VReg; };
    // Record operand i as a def / use when it is a virtual register.
    auto add_def = [&](size_t i) {
        if (is_vreg(i))
            result.defs.push_back(ops[i].GetVRegId());
    };
    auto add_use = [&](size_t i) {
        if (is_vreg(i))
            result.uses.push_back(ops[i].GetVRegId());
    };
    auto add_all_uses = [&]() {
        for (size_t i = 0; i < num_ops; ++i)
            add_use(i);
    };

    switch (opcode)
    {
    case Opcode::Prologue:
    case Opcode::Epilogue:
    case Opcode::Br:
        break;
    // dst <- (immediate / stack slot / global / address / condition flags)
    case Opcode::MovImm:
    case Opcode::LoadStack:
    case Opcode::LoadGlobal:
    case Opcode::LoadGlobalAddr:
    case Opcode::LoadStackAddr:
    case Opcode::CSet:
        if (num_ops >= 1)
            add_def(0);
        break;
    // value -> stack slot / global
    case Opcode::StoreStack:
    case Opcode::StoreGlobal:
        if (num_ops >= 1)
            add_use(0);
        break;
    // dst <- op(src)
    case Opcode::LoadMem:
    case Opcode::Uxtw:
    case Opcode::Sxtw:
    case Opcode::Scvtf:
    case Opcode::FCvtzs:
    case Opcode::FMovWS:
    case Opcode::NegRR:
    case Opcode::MovReg:
        if (num_ops >= 2)
        {
            add_def(0);
            add_use(1);
        }
        break;
    // reads two operands, defines nothing (stores and compares set memory/flags)
    case Opcode::StoreMem:
    case Opcode::CmpRR:
    case Opcode::CmpImm:
    case Opcode::FCmpRR:
        if (num_ops >= 2)
        {
            add_use(0);
            add_use(1);
        }
        break;
    // dst <- op(lhs, rhs)
    case Opcode::AddRR:
    case Opcode::SubRR:
    case Opcode::MulRR:
    case Opcode::DivRR:
    case Opcode::ModRR:
    case Opcode::AndRR:
    case Opcode::OrRR:
    case Opcode::XorRR:
    case Opcode::FAddRR:
    case Opcode::FSubRR:
    case Opcode::FMulRR:
    case Opcode::FDivRR:
    case Opcode::Csel:
    case Opcode::Smull:
    // Shifts: the shift amount (operand 2) used to be ignored. If it is ever a
    // VReg it is a genuine read, and must be visible to liveness and spill
    // reload. When it is an immediate, add_use is a no-op.
    case Opcode::ShlRR:
    case Opcode::ShrRR:
    case Opcode::AsrRR:
    case Opcode::Asr64RR:
        if (num_ops >= 3)
        {
            add_def(0);
            add_use(1);
            add_use(2);
        }
        break;
    // dst <- op(a, b, c)
    case Opcode::Msub:
        if (num_ops >= 4)
        {
            add_def(0);
            add_use(1);
            add_use(2);
            add_use(3);
        }
        break;
    case Opcode::CondBr:
        result.is_terminator = true;
        break;
    case Opcode::Call:
        result.is_call = true;
        add_all_uses();
        break;
    case Opcode::Ret:
        add_all_uses();
        result.is_terminator = true;
        break;
    default:
        add_all_uses();
        break;
    }

    auto remove_dups = [](std::vector<int> &v) {
        std::sort(v.begin(), v.end());
        v.erase(std::unique(v.begin(), v.end()), v.end());
    };
    remove_dups(result.defs);
    remove_dups(result.uses);
    return result;
}
// Per-basic-block liveness sets, all keyed by vreg id.
struct BlockLiveness
{
    using VRegSet = std::unordered_set<int>;
    VRegSet live_in;   // live on entry to the block
    VRegSet live_out;  // live on exit from the block
    VRegSet def;       // defined somewhere in the block
    VRegSet use;       // read in the block before any local definition
};
// Output of ComputeBlockLiveness. Member order matters: the struct is
// brace-initialised positionally by its producer.
struct LivenessResult
{
    using VRegCounter = std::unordered_map<int, int>;
    std::vector<BlockLiveness> block_liveness;  // indexed like the function's block list
    VRegCounter interval_length;                // vreg -> number of program points where it is live
    VRegCounter ref_count;                      // vreg -> number of defs + uses
};
// Computes block-level liveness of every vreg with the standard backward
// data-flow equations:
//   live_out(B) = union of live_in(S) over CFG successors S of B
//   live_in(B)  = use(B) U (live_out(B) \ def(B))
// It also derives two spill-cost statistics:
//   interval_length[v] - count of program points (block exits + points before
//                        each instruction) at which v is live;
//   ref_count[v]       - total number of defs and uses of v.
//
// CFG successors are every Label operand of Br/CondBr, plus the next block in
// layout order when a block does not end in Br or Ret (fall-through).
// Previously only Br op0 and CondBr op1 were considered, so a block relying on
// fall-through got too-small live_out sets. That let simultaneously-live vregs
// share a register. Extra edges only over-approximate liveness, which is always
// safe for allocation.
static LivenessResult ComputeBlockLiveness(MachineFunction &function)
{
    auto &blocks = function.GetBlocks();
    const size_t num_blocks = blocks.size();
    std::vector<BlockLiveness> bl(num_blocks);
    std::unordered_map<int, size_t> label_to_block;
    for (size_t i = 0; i < num_blocks; ++i)
        label_to_block[blocks[i]->GetLabelId()] = i;

    // Local def / upward-exposed use sets and global reference counts.
    std::unordered_map<int, int> ref_count;
    for (size_t i = 0; i < num_blocks; ++i)
    {
        for (const auto &inst : blocks[i]->GetInstructions())
        {
            auto du = GetInstDefUse(inst, function);
            for (int u : du.uses)
            {
                if (bl[i].def.find(u) == bl[i].def.end())
                    bl[i].use.insert(u);
                ref_count[u]++;
            }
            for (int d : du.defs)
            {
                bl[i].def.insert(d);
                ref_count[d]++;
            }
        }
    }

    // Successor lists never change during the fixed-point iteration, so they
    // are built once instead of rescanning every instruction on every pass.
    std::vector<std::vector<size_t>> succs(num_blocks);
    for (size_t i = 0; i < num_blocks; ++i)
    {
        const auto &insts = blocks[i]->GetInstructions();
        for (const auto &inst : insts)
        {
            const auto opcode = inst.GetOpcode();
            if (opcode != Opcode::Br && opcode != Opcode::CondBr)
                continue;
            for (const auto &op : inst.GetOperands())
            {
                if (op.GetKind() != Operand::Kind::Label)
                    continue;
                auto it = label_to_block.find(op.GetLabel());
                if (it != label_to_block.end())
                    succs[i].push_back(it->second);
            }
        }
        const bool ends_unconditionally =
            !insts.empty() && (insts.back().GetOpcode() == Opcode::Br ||
                               insts.back().GetOpcode() == Opcode::Ret);
        if (!ends_unconditionally && i + 1 < num_blocks)
            succs[i].push_back(i + 1);
    }

    // Iterate to the least fixed point. Visiting blocks in reverse layout order
    // suits a backward problem. The fixed point reached does not depend on order.
    bool changed = true;
    while (changed)
    {
        changed = false;
        for (size_t i = num_blocks; i-- > 0;)
        {
            std::unordered_set<int> new_out;
            for (size_t s : succs[i])
                new_out.insert(bl[s].live_in.begin(), bl[s].live_in.end());
            if (new_out != bl[i].live_out)
            {
                bl[i].live_out = std::move(new_out);
                changed = true;
            }
            std::unordered_set<int> new_in = bl[i].use;
            for (int v : bl[i].live_out)
            {
                if (bl[i].def.find(v) == bl[i].def.end())
                    new_in.insert(v);
            }
            if (new_in != bl[i].live_in)
            {
                bl[i].live_in = std::move(new_in);
                changed = true;
            }
        }
    }

    // Using the converged liveness, scan each block backwards and count at how
    // many program points each vreg is live.
    std::unordered_map<int, int> interval_length;
    for (size_t i = 0; i < num_blocks; ++i)
    {
        std::unordered_set<int> live = bl[i].live_out;
        for (int v : live)
            interval_length[v]++;
        const auto &insts = blocks[i]->GetInstructions();
        for (auto it = insts.rbegin(); it != insts.rend(); ++it)
        {
            auto du = GetInstDefUse(*it, function);
            for (int d : du.defs)
                live.erase(d);
            for (int u : du.uses)
                live.insert(u);
            for (int v : live)
                interval_length[v]++;
        }
    }
    return {std::move(bl), std::move(interval_length), std::move(ref_count)};
}
// Undirected interference graph over integer node ids. Non-negative ids are
// vregs; negative ids are used by the builders as pre-coloured clobber nodes.
struct InterferenceGraph
{
    std::unordered_set<int> nodes;
    std::unordered_set<long long> edges;   // packed by MakeEdgeKey
    std::unordered_map<int, int> degree;   // number of distinct edges per node

    // Packs the unordered pair {u, v} into a single 64-bit key. The smaller id
    // occupies the high 32 bits and the larger the low 32 bits, each as its
    // unsigned 32-bit bit pattern, so negative ids round-trip through int casts.
    static long long MakeEdgeKey(int u, int v)
    {
        const int lo = std::min(u, v);
        const int hi = std::max(u, v);
        const unsigned long long high_half =
            static_cast<unsigned long long>(static_cast<unsigned int>(lo)) << 32;
        const unsigned long long low_half = static_cast<unsigned int>(hi);
        return static_cast<long long>(high_half | low_half);
    }

    void AddNode(int v) { nodes.insert(v); }

    // Adds edge {u, v}. Self-loops and already-present edges are ignored.
    // Endpoints are not added to `nodes`; callers add them separately.
    void AddEdge(int u, int v)
    {
        if (u == v)
            return;
        const bool inserted = edges.insert(MakeEdgeKey(u, v)).second;
        if (!inserted)
            return;
        ++degree[u];
        ++degree[v];
    }

    int GetDegree(int v) const
    {
        const auto found = degree.find(v);
        return found == degree.end() ? 0 : found->second;
    }
};
// Int and Ptr vregs live in general-purpose registers; every other class does not.
static bool IsGPClass(VRegClass vc)
{
    switch (vc)
    {
    case VRegClass::Int:
    case VRegClass::Ptr:
        return true;
    default:
        return false;
    }
}
// Builds the interference graph for GP-class (Int/Ptr) vregs.
//
// Each block is walked backwards from its live_out set. A def interferes with
// every other GP vreg live after it. For a MovReg copy, the source is
// temporarily dropped from the live set, so source and destination do not
// interfere and may share a register. At a Call, every GP vreg in the live set
// gets an edge to a pre-coloured node (-1000 - r) for each allocatable register
// r that is also listed in GP_CALLER_SAVED.
static void BuildInterferenceForGP(
    MachineFunction &function,
    const std::vector<BlockLiveness> &block_liveness,
    const std::vector<int> &allocatable_regs,
    InterferenceGraph &graph)
{
    auto is_gp = [&function](int vreg) { return IsGPClass(function.GetVRegClass(vreg)); };

    // Allocatable registers a call may clobber, kept in allocatable order.
    std::vector<int> clobbered_by_call;
    const int *const caller_saved_end = GP_CALLER_SAVED + GP_NUM_CALLER_SAVED;
    for (int reg : allocatable_regs)
    {
        if (std::find(GP_CALLER_SAVED, caller_saved_end, reg) != caller_saved_end)
            clobbered_by_call.push_back(reg);
    }

    const int num_vregs = function.GetNumVRegs();
    for (int vreg = 0; vreg < num_vregs; ++vreg)
    {
        if (is_gp(vreg))
            graph.AddNode(vreg);
    }

    auto &blocks = function.GetBlocks();
    for (size_t bi = 0; bi < blocks.size(); ++bi)
    {
        std::unordered_set<int> live = block_liveness[bi].live_out;
        const auto &insts = blocks[bi]->GetInstructions();
        for (auto it = insts.rbegin(); it != insts.rend(); ++it)
        {
            const InstDefUse du = GetInstDefUse(*it, function);

            // Copy hint: hide the MovReg source while its def is processed. The
            // source comes back right below via the regular use insertion.
            if (it->GetOpcode() == Opcode::MovReg && du.defs.size() == 1 && du.uses.size() == 1)
            {
                const int src = du.uses.front();
                if (live.count(src) && is_gp(src))
                    live.erase(src);
            }

            for (int d : du.defs)
            {
                if (!is_gp(d))
                    continue;
                for (int other : live)
                {
                    if (other != d && is_gp(other))
                        graph.AddEdge(d, other);
                }
                live.erase(d);
            }

            for (int u : du.uses)
            {
                if (is_gp(u))
                    live.insert(u);
            }

            if (!du.is_call)
                continue;
            for (int v : live)
            {
                if (!is_gp(v))
                    continue;
                for (int reg : clobbered_by_call)
                {
                    const int clobber_node = -1000 - reg;
                    graph.AddNode(clobber_node);
                    graph.AddEdge(v, clobber_node);
                }
            }
        }
    }
}
// Builds the interference graph for Float-class vregs.
//
// Same scheme as the GP builder. Each block is walked backwards from live_out.
// A def interferes with every other Float vreg live after it. A MovReg source is
// hidden while its def is processed (copy hint). At a Call, every Float vreg in
// the live set gets an edge to a pre-coloured node (-2000 - r) for each r in
// FP_CALLER_SAVED.
// NOTE(review): `allocatable_regs` is not consulted here — confirm that every
// FP_CALLER_SAVED entry is meant to be clobbered regardless of allocatability.
static void BuildInterferenceForFP(
    MachineFunction &function,
    const std::vector<BlockLiveness> &block_liveness,
    const std::vector<int> &allocatable_regs,
    InterferenceGraph &graph)
{
    (void)allocatable_regs;
    auto is_fp = [&function](int vreg) { return function.GetVRegClass(vreg) == VRegClass::Float; };

    const int num_vregs = function.GetNumVRegs();
    for (int vreg = 0; vreg < num_vregs; ++vreg)
    {
        if (is_fp(vreg))
            graph.AddNode(vreg);
    }

    auto &blocks = function.GetBlocks();
    for (size_t bi = 0; bi < blocks.size(); ++bi)
    {
        std::unordered_set<int> live = block_liveness[bi].live_out;
        const auto &insts = blocks[bi]->GetInstructions();
        for (auto it = insts.rbegin(); it != insts.rend(); ++it)
        {
            const InstDefUse du = GetInstDefUse(*it, function);

            // Copy hint: hide the MovReg source while its def is processed. The
            // source comes back right below via the regular use insertion.
            if (it->GetOpcode() == Opcode::MovReg && du.defs.size() == 1 && du.uses.size() == 1)
            {
                const int src = du.uses.front();
                if (live.count(src) && is_fp(src))
                    live.erase(src);
            }

            for (int d : du.defs)
            {
                if (!is_fp(d))
                    continue;
                for (int other : live)
                {
                    if (other != d && is_fp(other))
                        graph.AddEdge(d, other);
                }
                live.erase(d);
            }

            for (int u : du.uses)
            {
                if (is_fp(u))
                    live.insert(u);
            }

            if (!du.is_call)
                continue;
            for (int v : live)
            {
                if (!is_fp(v))
                    continue;
                for (int k = 0; k < FP_NUM_CALLER_SAVED; ++k)
                {
                    const int clobber_node = -2000 - FP_CALLER_SAVED[k];
                    graph.AddNode(clobber_node);
                    graph.AddEdge(v, clobber_node);
                }
            }
        }
    }
}
// Result of colouring one register class.
struct GraphColoringResult
{
    using ColorMap = std::unordered_map<int, int>;
    ColorMap assignment;    // vreg id -> register number within the class's bank
    std::set<int> spilled;  // vregs that received no register
};
// Colours one register class's interference graph with a simplify / spill /
// select scheme.
//
// Parameters:
//   graph                  - interference graph. Non-negative nodes are vregs.
//                            Nodes in [-1999, -1000) and [-2999, -2000) are
//                            pre-coloured with register -(v + 1000) and
//                            -(v + 2000) respectively (call-clobber nodes).
//                            NOTE(review): -1000 / -2000 (register 0) fall
//                            outside these ranges and would not be pre-coloured.
//   allocatable_regs       - register numbers that may be assigned. The select
//                            phase `break`s on the first register >= threshold,
//                            so it assumes ascending order.
//   function               - unused.
//   caller_saved_threshold - register numbers below this are tried first.
//   interval_length, ref_count, rematerializable_vregs - spill-cost inputs.
// Returns the vreg -> register assignment and the set of vregs to spill.
//
// Spilling is pessimistic: a chosen spill candidate is removed from the graph
// and recorded as spilled immediately, rather than pushed for optimistic colouring.
static GraphColoringResult ColorGraph(
    InterferenceGraph &graph,
    const std::vector<int> &allocatable_regs,
    MachineFunction & /*function*/,
    int caller_saved_threshold,
    const std::unordered_map<int, int> &interval_length,
    const std::unordered_map<int, int> &ref_count,
    const std::set<int> &rematerializable_vregs)
{
    const int K = static_cast<int>(allocatable_regs.size());
    GraphColoringResult result;
    // Rebuild symmetric adjacency sets from the packed edge keys.
    std::unordered_map<int, std::unordered_set<int>> adj;
    for (auto key : graph.edges)
    {
        int u = static_cast<int>(key >> 32);
        int v = static_cast<int>(key & 0xFFFFFFFF);
        adj[u].insert(v);
        adj[v].insert(u);
    }
    // Negative nodes are call-clobber nodes with a fixed register.
    std::unordered_map<int, int> precolored;
    for (int v : graph.nodes)
    {
        if (v < 0)
        {
            if (v >= -1999 && v < -1000)
            {
                int reg_num = -(v + 1000);
                precolored[v] = reg_num;
            }
            else if (v >= -2999 && v < -2000)
            {
                int reg_num = -(v + 2000);
                precolored[v] = reg_num;
            }
        }
    }
    // Only vreg nodes take part in simplify; their degree counts all graph
    // neighbours, including pre-coloured ones.
    std::unordered_set<int> remaining;
    std::unordered_map<int, int> degree;
    for (int v : graph.nodes)
    {
        if (v >= 0)
        {
            remaining.insert(v);
            int deg = 0;
            for (int n : adj[v])
            {
                if (graph.nodes.count(n))
                    deg++;
            }
            degree[v] = deg;
        }
    }
    std::vector<int> simplify_worklist;
    for (int v : remaining)
    {
        if (degree[v] < K)
            simplify_worklist.push_back(v);
    }
    std::vector<int> stack;
    while (!remaining.empty())
    {
        // Simplify: remove low-degree nodes, pushing them for the select phase.
        while (!simplify_worklist.empty())
        {
            int v = simplify_worklist.back();
            simplify_worklist.pop_back();
            if (!remaining.count(v))
                continue;
            stack.push_back(v);
            remaining.erase(v);
            for (int n : adj[v])
            {
                if (remaining.count(n))
                {
                    degree[n]--;
                    // A neighbour that just dropped below K becomes simplifiable.
                    if (degree[n] == K - 1)
                        simplify_worklist.push_back(n);
                }
            }
        }
        if (!remaining.empty())
        {
            // Spill cost: len (number of live program points) * 5
            //             + ref (total defs + uses) * 15
            //             - degree (interference degree) * 25.
            // Lower cost is spilled first, favouring short intervals, few
            // references and high conflict. The weights are empirical: the
            // degree term dominates, len/ref act as tie-breakers.
            auto GetSpillCost = [&](int v) -> int {
                int len = 0;
                auto lit = interval_length.find(v);
                if (lit != interval_length.end()) len = lit->second;
                int ref = 0;
                auto rit = ref_count.find(v);
                if (rit != ref_count.end()) ref = rit->second;
                int d = degree[v];
                // Rematerializable vregs (MovImm constants) get a large
                // discount so they are spilled first.
                int cost = len * 5 + ref * 15 - d * 25;
                if (rematerializable_vregs.count(v))
                    cost -= 100000;
                return cost;
            };
            // Pick the cheapest remaining node. Ties go to whichever the
            // unordered_set iterates first.
            int spill_candidate = -1;
            int min_cost = std::numeric_limits<int>::max();
            for (int v : remaining)
            {
                int cost = GetSpillCost(v);
                if (cost < min_cost)
                {
                    min_cost = cost;
                    spill_candidate = v;
                }
            }
            if (spill_candidate >= 0)
            {
                result.spilled.insert(spill_candidate);
                remaining.erase(spill_candidate);
                for (int n : adj[spill_candidate])
                {
                    if (remaining.count(n))
                    {
                        degree[n]--;
                        if (degree[n] == K - 1)
                            simplify_worklist.push_back(n);
                    }
                }
            }
            else
            {
                break;
            }
        }
    }
    // Select: pop nodes and give each the first colour unused by its coloured
    // neighbours (pre-coloured clobber nodes included).
    std::unordered_map<int, int> colored = precolored;
    while (!stack.empty())
    {
        int v = stack.back();
        stack.pop_back();
        std::unordered_set<int> used_colors;
        for (int n : adj[v])
        {
            auto it = colored.find(n);
            if (it != colored.end())
                used_colors.insert(it->second);
        }
        int assigned_color = -1;
        // First pass: prefer caller-saved colours (c < caller_saved_threshold).
        for (int c : allocatable_regs)
        {
            if (c >= caller_saved_threshold) break;
            if (used_colors.find(c) == used_colors.end())
            {
                assigned_color = c;
                break;
            }
        }
        // Second pass: if no caller-saved colour is free, pick a callee-saved one.
        if (assigned_color < 0)
        {
            for (int c : allocatable_regs)
            {
                if (c < caller_saved_threshold) continue;
                if (used_colors.find(c) == used_colors.end())
                {
                    assigned_color = c;
                    break;
                }
            }
        }
        if (assigned_color >= 0)
        {
            colored[v] = assigned_color;
            result.assignment[v] = assigned_color;
        }
        else
        {
            // Should be unreachable for simplified nodes (degree < K), but kept
            // as a safety net.
            result.spilled.insert(v);
        }
    }
    return result;
}
// Chooses a GP register number to hold a spilled vreg inside one instruction.
//
// A candidate is rejected when it is already in `used_scratch`, or when it is
// the assigned register of one of the instruction's defs (except `skip_vreg`)
// or uses. x14 is tried first because it is not in GP_ALLOCATABLE. After that,
// GP_ALLOCATABLE is scanned in order, and GP_ALLOCATABLE[0] is the last resort.
// NOTE(review): a GP_ALLOCATABLE candidate may hold a vreg that is live across
// this instruction without being one of its operands — confirm this is safe.
static int PickGPScratchReg(
    const std::set<int> &used_scratch,
    const InstDefUse &du,
    const std::unordered_map<int, int> &gp_assignment,
    int skip_vreg = -1)
{
    auto held_by_operand = [&](int reg) {
        auto assigned_to = [&](int vreg) {
            auto found = gp_assignment.find(vreg);
            return found != gp_assignment.end() && found->second == reg;
        };
        for (int d : du.defs)
        {
            if (d != skip_vreg && assigned_to(d))
                return true;
        }
        for (int u : du.uses)
        {
            if (assigned_to(u))
                return true;
        }
        return false;
    };
    auto usable = [&](int reg) { return used_scratch.count(reg) == 0 && !held_by_operand(reg); };

    if (usable(14))
        return 14;
    for (int reg : GP_ALLOCATABLE)
    {
        if (usable(reg))
            return reg;
    }
    return GP_ALLOCATABLE[0];
}
// Chooses an FP register number to hold a spilled vreg inside one instruction.
//
// FP_ALLOCATABLE is scanned in order. A candidate is rejected when it is in
// `used_scratch`, or when it is the assigned register of a def (other than
// `skip_vreg`) or use of the instruction. FP_ALLOCATABLE[0] is the last resort.
// NOTE(review): there is no reserved FP scratch register. The chosen register
// may hold a vreg that is live across this instruction without being an operand.
static int PickFPScratchReg(
    const std::set<int> &used_scratch,
    const InstDefUse &du,
    const std::unordered_map<int, int> &fp_assignment,
    int skip_vreg = -1)
{
    auto held_by_operand = [&](int reg) {
        auto assigned_to = [&](int vreg) {
            auto found = fp_assignment.find(vreg);
            return found != fp_assignment.end() && found->second == reg;
        };
        for (int d : du.defs)
        {
            if (d != skip_vreg && assigned_to(d))
                return true;
        }
        for (int u : du.uses)
        {
            if (assigned_to(u))
                return true;
        }
        return false;
    };

    for (int reg : FP_ALLOCATABLE)
    {
        if (used_scratch.count(reg) == 0 && !held_by_operand(reg))
            return reg;
    }
    return FP_ALLOCATABLE[0];
}
// Replaces every VReg operand with its assigned physical register. Spilled
// vregs are handled with scratch registers:
//   * each spilled use is reloaded (LoadStack) or rematerialized (MovImm) into a
//     scratch register before the instruction;
//   * each spilled def is written to a scratch register and stored (StoreStack)
//     after the instruction, unless it is rematerializable.
//
// `vreg_def_inst` maps vregs to their defining instruction. It is consulted
// ONLY here, before any instruction is touched. The pointers must be valid on
// entry. Moving instructions and replacing the block vectors below would
// otherwise leave them moved-from or dangling, which was a use-after-free in the
// previous version. A vreg is rematerialized only if it is non-Float, has
// exactly one def in the function, and that def IsRematerializable(). Otherwise
// every use would receive one particular def's constant.
//
// NOTE(review): scratch registers come from PickGPScratchReg / PickFPScratchReg,
// which only avoid this instruction's operands, not other live vregs — confirm
// that is acceptable on this fallback path.
static void RewriteWithAllocation(
    MachineFunction &function,
    const std::unordered_map<int, int> &gp_assignment,
    const std::unordered_map<int, int> &fp_assignment,
    const std::set<int> &spilled,
    const std::unordered_map<int, MachineInstr *> &vreg_def_inst = {})
{
    // Snapshot rematerialization immediates while every pointer is still valid.
    std::unordered_map<int, Operand> remat_imm;
    if (!spilled.empty() && !vreg_def_inst.empty())
    {
        std::unordered_map<int, int> def_count;
        for (auto &block : function.GetBlocks())
        {
            for (auto &inst : block->GetInstructions())
            {
                for (int d : GetInstDefUse(inst, function).defs)
                    ++def_count[d];
            }
        }
        for (int v : spilled)
        {
            if (function.GetVRegClass(v) == VRegClass::Float)
                continue;  // AArch64 has no `mov sN, #imm`
            auto def_it = vreg_def_inst.find(v);
            if (def_it == vreg_def_inst.end() || def_count[v] != 1)
                continue;
            if (def_it->second->IsRematerializable())
                remat_imm.emplace(v, Operand::Imm(def_it->second->GetRematImm()));
        }
    }

    // Rematerialized vregs never touch the stack, so they get no slot.
    std::unordered_map<int, int> spill_slots;
    for (int v : spilled)
    {
        if (remat_imm.count(v))
            continue;
        int size = (function.GetVRegClass(v) == VRegClass::Ptr) ? 8 : 4;
        spill_slots[v] = function.CreateFrameIndex(size);
    }

    for (auto &block : function.GetBlocks())
    {
        std::vector<MachineInstr> new_insts;
        for (auto &inst : block->GetInstructions())
        {
            auto du = GetInstDefUse(inst, function);
            std::set<int> used_scratch_gp;
            std::set<int> used_scratch_fp;
            // Physical register carrying each spilled vreg within this instruction.
            std::unordered_map<int, PhysReg> spill_reg_of;

            // 1) Reload (or rematerialize) spilled uses into scratch registers.
            for (int u : du.uses)
            {
                if (!spilled.count(u))
                    continue;
                const VRegClass vc = function.GetVRegClass(u);
                int reload_reg_num = -1;
                if (vc == VRegClass::Float)
                {
                    reload_reg_num = PickFPScratchReg(used_scratch_fp, du, fp_assignment);
                    used_scratch_fp.insert(reload_reg_num);
                }
                else
                {
                    reload_reg_num = PickGPScratchReg(used_scratch_gp, du, gp_assignment);
                    used_scratch_gp.insert(reload_reg_num);
                }
                const PhysReg reload_reg = NumberToPhysReg(reload_reg_num, vc);
                spill_reg_of[u] = reload_reg;
                auto remat_it = remat_imm.find(u);
                if (remat_it != remat_imm.end())
                {
                    new_insts.push_back(
                        MachineInstr(Opcode::MovImm, {Operand::Reg(reload_reg), remat_it->second}));
                }
                else
                {
                    new_insts.push_back(
                        MachineInstr(Opcode::LoadStack,
                                     {Operand::Reg(reload_reg), Operand::FrameIndex(spill_slots.at(u))}));
                }
                // Replaces every occurrence, including a def position when the
                // instruction also writes u (e.g. u = u + 1).
                for (auto &op : inst.GetOperands())
                {
                    if (op.GetKind() == Operand::Kind::VReg && op.GetVRegId() == u)
                        const_cast<Operand &>(op) = Operand::Reg(reload_reg);
                }
            }

            // 2) Map the remaining VReg operands (assigned ones and spilled defs).
            for (auto &op : inst.GetOperands())
            {
                if (op.GetKind() != Operand::Kind::VReg)
                    continue;
                const int vreg_id = op.GetVRegId();
                const VRegClass actual_vc = function.GetVRegClass(vreg_id);
                const auto &assignment =
                    (actual_vc == VRegClass::Float) ? fp_assignment : gp_assignment;
                auto assigned = assignment.find(vreg_id);
                if (assigned != assignment.end() && assigned->second >= 0)
                {
                    const_cast<Operand &>(op) = Operand::Reg(NumberToPhysReg(assigned->second, actual_vc));
                }
                else if (spilled.count(vreg_id))
                {
                    auto known = spill_reg_of.find(vreg_id);
                    if (known == spill_reg_of.end())
                    {
                        int spill_reg_num = -1;
                        if (actual_vc == VRegClass::Float)
                        {
                            spill_reg_num = PickFPScratchReg(used_scratch_fp, du, fp_assignment, vreg_id);
                            used_scratch_fp.insert(spill_reg_num);
                        }
                        else
                        {
                            spill_reg_num = PickGPScratchReg(used_scratch_gp, du, gp_assignment, vreg_id);
                            used_scratch_gp.insert(spill_reg_num);
                        }
                        known = spill_reg_of.emplace(vreg_id, NumberToPhysReg(spill_reg_num, actual_vc)).first;
                    }
                    const_cast<Operand &>(op) = Operand::Reg(known->second);
                }
            }

            new_insts.push_back(std::move(const_cast<MachineInstr &>(inst)));

            // 3) Store spilled, non-rematerializable defs from the register the
            //    instruction actually wrote.
            for (int d : du.defs)
            {
                if (!spilled.count(d) || remat_imm.count(d))
                    continue;
                auto reg_it = spill_reg_of.find(d);
                const PhysReg spill_reg = (reg_it != spill_reg_of.end()) ? reg_it->second : PhysReg::W0;
                new_insts.push_back(
                    MachineInstr(Opcode::StoreStack,
                                 {Operand::Reg(spill_reg), Operand::FrameIndex(spill_slots.at(d))}));
            }
        }
        block->GetInstructions() = std::move(new_insts);
    }
}
static void AllocateRegistersForFunction(MachineFunction &function)
{
if (function.GetNumVRegs() == 0)
return;
const int MAX_SPILL_ROUNDS = 10;
for (int round = 0; round < MAX_SPILL_ROUNDS; ++round)
{
// 构建 VReg → 定义指令映射(用于再物化判断)
std::unordered_map<int, MachineInstr *> vreg_def_inst;
for (auto &block : function.GetBlocks())
{
for (auto &inst : block->GetInstructions())
{
auto du = GetInstDefUse(inst, function);
for (int d : du.defs)
{
vreg_def_inst[d] = &inst;
}
}
}
auto liveness = ComputeBlockLiveness(function);
// 构建可再物化 vreg 集合MovImm 常量)
std::set<int> rematerializable_vregs;
for (const auto &pair : vreg_def_inst)
{
if (pair.second->IsRematerializable())
rematerializable_vregs.insert(pair.first);
}
std::vector<int> gp_alloc(GP_ALLOCATABLE, GP_ALLOCATABLE + GP_NUM_ALLOCATABLE);
std::vector<int> fp_alloc(FP_ALLOCATABLE, FP_ALLOCATABLE + FP_NUM_ALLOCATABLE);
InterferenceGraph gp_graph, fp_graph;
BuildInterferenceForGP(function, liveness.block_liveness, gp_alloc, gp_graph);
BuildInterferenceForFP(function, liveness.block_liveness, fp_alloc, fp_graph);
auto gp_result = ColorGraph(gp_graph, gp_alloc, function, 19,
liveness.interval_length, liveness.ref_count,
rematerializable_vregs);
auto fp_result = ColorGraph(fp_graph, fp_alloc, function, 16,
liveness.interval_length, liveness.ref_count,
rematerializable_vregs);
if (gp_result.spilled.empty() && fp_result.spilled.empty())
{
std::unordered_map<int, int> gp_assign = gp_result.assignment;
std::unordered_map<int, int> fp_assign = fp_result.assignment;
for (const auto &pair : gp_assign)
{
if (pair.second >= 19 && pair.second <= 28)
{
function.AddCalleeSavedReg(NumberToPhysReg(pair.second, VRegClass::Ptr));
}
}
for (const auto &pair : fp_assign)
{
if (pair.second >= 16 && pair.second <= 31)
{
function.AddCalleeSavedReg(NumberToPhysReg(pair.second, VRegClass::Float));
}
}
RewriteWithAllocation(function, gp_assign, fp_assign, {});
return;
}
std::set<int> all_spilled = gp_result.spilled;
for (int v : fp_result.spilled)
all_spilled.insert(v);
std::unordered_map<int, int> spill_slots;
for (int v : all_spilled)
{
int size = (function.GetVRegClass(v) == VRegClass::Ptr) ? 8 : 4;
spill_slots[v] = function.CreateFrameIndex(size);
}
for (auto &block : function.GetBlocks())
{
std::vector<MachineInstr> new_insts;
for (auto &inst : block->GetInstructions())
{
auto du = GetInstDefUse(inst, function);
for (int u : du.uses)
{
if (all_spilled.count(u))
{
VRegClass vc = function.GetVRegClass(u);
int new_vreg = function.CreateVReg(vc);
// 检查是否可再物化仅GP寄存器支持AArch64不支持mov sN,#imm
auto def_it = vreg_def_inst.find(u);
if (vc != VRegClass::Float &&
def_it != vreg_def_inst.end() && def_it->second->IsRematerializable())
{
// 再物化:直接生成 MovImm
new_insts.push_back(
MachineInstr(Opcode::MovImm,
{Operand::VReg(new_vreg, vc), Operand::Imm(def_it->second->GetRematImm())}));
}
else
{
// 常规:从栈加载
new_insts.push_back(
MachineInstr(Opcode::LoadStack,
{Operand::VReg(new_vreg, vc), Operand::FrameIndex(spill_slots[u])}));
}
for (auto &op : inst.GetOperands())
{
if (op.GetKind() == Operand::Kind::VReg && op.GetVRegId() == u)
{
const_cast<Operand &>(op) = Operand::VReg(new_vreg, vc);
}
}
}
}
new_insts.push_back(std::move(const_cast<MachineInstr &>(inst)));
for (int d : du.defs)
{
if (all_spilled.count(d))
{
VRegClass vc = function.GetVRegClass(d);
int new_vreg = function.CreateVReg(vc);
auto &last = new_insts.back();
for (auto &op : last.GetOperands())
{
if (op.GetKind() == Operand::Kind::VReg && op.GetVRegId() == d)
{
const_cast<Operand &>(op) = Operand::VReg(new_vreg, vc);
}
}
// 可再物化变量:不需要 StoreStackuse 点会重新生成 MovImm仅GP
auto def_it = vreg_def_inst.find(d);
bool can_remat = (vc != VRegClass::Float) &&
(def_it != vreg_def_inst.end() && def_it->second->IsRematerializable());
if (!can_remat)
{
new_insts.push_back(
MachineInstr(Opcode::StoreStack,
{Operand::VReg(new_vreg, vc), Operand::FrameIndex(spill_slots[d])}));
}
}
}
}
block->GetInstructions() = std::move(new_insts);
}
}
// 循环外:构建 VReg → 定义指令映射(用于再物化判断)
std::unordered_map<int, MachineInstr *> vreg_def_inst;
for (auto &block : function.GetBlocks())
{
for (auto &inst : block->GetInstructions())
{
auto du = GetInstDefUse(inst, function);
for (int d : du.defs)
{
vreg_def_inst[d] = &inst;
}
}
}
auto liveness = ComputeBlockLiveness(function);
// 构建可再物化 vreg 集合
std::set<int> rematerializable_vregs;
for (const auto &pair : vreg_def_inst)
{
if (pair.second->IsRematerializable())
rematerializable_vregs.insert(pair.first);
}
std::vector<int> gp_alloc(GP_ALLOCATABLE, GP_ALLOCATABLE + GP_NUM_ALLOCATABLE);
std::vector<int> fp_alloc(FP_ALLOCATABLE, FP_ALLOCATABLE + FP_NUM_ALLOCATABLE);
InterferenceGraph gp_graph, fp_graph;
BuildInterferenceForGP(function, liveness.block_liveness, gp_alloc, gp_graph);
BuildInterferenceForFP(function, liveness.block_liveness, fp_alloc, fp_graph);
auto gp_result = ColorGraph(gp_graph, gp_alloc, function, 19,
liveness.interval_length, liveness.ref_count,
rematerializable_vregs);
auto fp_result = ColorGraph(fp_graph, fp_alloc, function, 16,
liveness.interval_length, liveness.ref_count,
rematerializable_vregs);
std::set<int> all_spilled = gp_result.spilled;
for (int v : fp_result.spilled)
all_spilled.insert(v);
std::unordered_map<int, int> gp_assign = gp_result.assignment;
std::unordered_map<int, int> fp_assign = fp_result.assignment;
for (const auto &pair : gp_assign)
{
if (pair.second >= 19 && pair.second <= 28)
{
function.AddCalleeSavedReg(NumberToPhysReg(pair.second, VRegClass::Ptr));
}
}
for (const auto &pair : fp_assign)
{
if (pair.second >= 16 && pair.second <= 31)
{
function.AddCalleeSavedReg(NumberToPhysReg(pair.second, VRegClass::Float));
}
}
RewriteWithAllocation(function, gp_assign, fp_assign, all_spilled, vreg_def_inst);
}
} // namespace
// Public entry point: allocates registers for a single machine function.
void RunRegAlloc(MachineFunction &function)
{
    return AllocateRegistersForFunction(function);
}
// Public entry point: allocates registers for every non-null function of `module`.
void RunRegAlloc(MachineModule &module)
{
    for (auto &fn : module.GetFunctions())
    {
        if (!fn)
            continue;
        RunRegAlloc(*fn);
    }
}
} // namespace mir