You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nudt-compiler-cpp/src/mir/RegAlloc.cpp

1219 lines
37 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "mir/MIR.h"
#include <algorithm>
#include <queue>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "utils/Log.h"
namespace mir
{
namespace
{
// Maps a physical register to a small integer id.
// W<n> and X<n> both map to n (the same architectural GP register);
// S<n> maps to 100 + n so FP ids never collide with GP ids.
// Returns -1 for any register outside those three ranges.
static int PhysRegToNumber(PhysReg reg)
{
    const int code = static_cast<int>(reg);
    auto within = [reg](PhysReg lo, PhysReg hi) { return lo <= reg && reg <= hi; };
    if (within(PhysReg::W0, PhysReg::W30))
        return code - static_cast<int>(PhysReg::W0);
    if (within(PhysReg::X0, PhysReg::X30))
        return code - static_cast<int>(PhysReg::X0);
    if (within(PhysReg::S0, PhysReg::S31))
        return 100 + (code - static_cast<int>(PhysReg::S0));
    return -1;
}
// Inverse of the allocator's register numbering: turns register number `num`
// into the concrete PhysReg for a vreg of class `vc`.
// Float -> S<num>, Ptr -> X<num> (64-bit), anything else -> W<num> (32-bit).
static PhysReg NumberToPhysReg(int num, VRegClass vc)
{
    PhysReg base = PhysReg::W0;
    if (vc == VRegClass::Float)
        base = PhysReg::S0;
    else if (vc == VRegClass::Ptr)
        base = PhysReg::X0;
    return static_cast<PhysReg>(static_cast<int>(base) + num);
}
// True for any general-purpose register view: W0..W30 or X0..X30.
static bool IsGPReg(PhysReg reg)
{
    const bool is_w = PhysReg::W0 <= reg && reg <= PhysReg::W30;
    const bool is_x = PhysReg::X0 <= reg && reg <= PhysReg::X30;
    return is_w || is_x;
}
// True for single-precision FP registers S0..S31.
static bool IsFPReg(PhysReg reg)
{
    return !(reg < PhysReg::S0 || reg > PhysReg::S31);
}
// Register numbers (the N in wN/xN or sN) the allocator may hand out, in
// preference order. GP numbers not listed here (e.g. 0-7, 13, 14, 16-18) are
// never assigned to a vreg; 14 is what PickGPScratchReg tries first as a
// spill scratch register.
static constexpr int GP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 15, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28};
static const int GP_NUM_ALLOCATABLE = static_cast<int>(sizeof(GP_ALLOCATABLE) / sizeof(GP_ALLOCATABLE[0]));
static constexpr int FP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
static const int FP_NUM_ALLOCATABLE = static_cast<int>(sizeof(FP_ALLOCATABLE) / sizeof(FP_ALLOCATABLE[0]));
// FP registers the interference builder treats as clobbered by a call.
// NOTE(review): AAPCS64 makes v8-v15 callee-saved (low 64 bits) and v0-v7 /
// v16-v31 caller-saved. This table, and the 16..31 callee-saved FP range used
// in AllocateRegistersForFunction, are the reverse of that. The two agree with
// each other, so calls between functions compiled here are consistent, but an
// FP value kept in s16-s31 across a call into AAPCS64 code (e.g. a C runtime
// library) can be clobbered. Both places must be changed together.
static constexpr int FP_CALLER_SAVED[] = {8, 9, 10, 11, 12, 13, 14, 15};
static const int FP_NUM_CALLER_SAVED = static_cast<int>(sizeof(FP_CALLER_SAVED) / sizeof(FP_CALLER_SAVED[0]));
// GP registers x0-x17, treated as clobbered by a call.
static constexpr int GP_CALLER_SAVED[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
static const int GP_NUM_CALLER_SAVED = static_cast<int>(sizeof(GP_CALLER_SAVED) / sizeof(GP_CALLER_SAVED[0]));
// Per-instruction summary used by liveness, interference and spill rewriting:
// the vreg ids the instruction writes (defs) and reads (uses), plus whether it
// is a call or a terminator. GetInstDefUse returns defs/uses sorted and
// de-duplicated.
struct InstDefUse
{
    std::vector<int> defs;          // vreg ids written
    std::vector<int> uses;          // vreg ids read
    bool is_call{false};            // set for Opcode::Call
    bool is_terminator{false};      // set for Opcode::CondBr and Opcode::Ret
};
// Computes which virtual registers `inst` defines and uses.
//
// Only operands of kind VReg are reported. Operand positions follow the MIR
// convention visible in the cases below: for value-producing opcodes ops[0] is
// the destination and the remaining register operands are sources.
// Opcodes not listed explicitly fall into `default`, which conservatively
// treats every vreg operand as a use (it never reports a def).
//
// Returns: defs and uses sorted and de-duplicated; is_call / is_terminator set
// for Call and for CondBr / Ret respectively.
static InstDefUse GetInstDefUse(const MachineInstr &inst, MachineFunction & /*function*/)
{
    InstDefUse result;
    const auto opcode = inst.GetOpcode();
    const auto &ops = inst.GetOperands();
    const size_t n = ops.size();

    // Record ops[i] as a def/use when it is a vreg. Callers check i < n first.
    auto def_at = [&](size_t i)
    {
        if (ops[i].GetKind() == Operand::Kind::VReg)
            result.defs.push_back(ops[i].GetVRegId());
    };
    auto use_at = [&](size_t i)
    {
        if (ops[i].GetKind() == Operand::Kind::VReg)
            result.uses.push_back(ops[i].GetVRegId());
    };
    auto use_all = [&]()
    {
        for (size_t i = 0; i < n; ++i)
            use_at(i);
    };

    switch (opcode)
    {
    case Opcode::Prologue:
    case Opcode::Epilogue:
    case Opcode::Br:
        break;

    // dst only.
    case Opcode::MovImm:
    case Opcode::LoadStack:
    case Opcode::LoadGlobal:
    case Opcode::LoadGlobalAddr:
    case Opcode::LoadStackAddr:
    case Opcode::CSet:
        if (n >= 1)
            def_at(0);
        break;

    // Stored value only.
    case Opcode::StoreStack:
    case Opcode::StoreGlobal:
        if (n >= 1)
            use_at(0);
        break;

    // dst, base [, index] -- the optional third operand is the index register
    // of a scaled addressing mode.
    case Opcode::LoadMem:
        if (n >= 2)
        {
            def_at(0);
            use_at(1);
            if (n >= 3)
                use_at(2);
        }
        break;

    // value, base [, index] -- all read.
    case Opcode::StoreMem:
        if (n >= 2)
        {
            use_at(0);
            use_at(1);
            if (n >= 3)
                use_at(2);
        }
        break;

    // dst, src1, src2. The shift amount of the shift opcodes is read too when
    // it is a register; previously it was ignored, so liveness missed it even
    // though the rewrite still replaced it with an allocated register.
    case Opcode::AddRR:
    case Opcode::SubRR:
    case Opcode::MulRR:
    case Opcode::DivRR:
    case Opcode::ModRR:
    case Opcode::AndRR:
    case Opcode::OrRR:
    case Opcode::XorRR:
    case Opcode::FAddRR:
    case Opcode::FSubRR:
    case Opcode::FMulRR:
    case Opcode::FDivRR:
    case Opcode::ShlRR:
    case Opcode::ShrRR:
    case Opcode::AsrRR:
    case Opcode::Asr64RR:
    case Opcode::Csel:
    case Opcode::Csneg:
    case Opcode::Smull:
        if (n >= 3)
        {
            def_at(0);
            use_at(1);
            use_at(2);
        }
        break;

    // dst, src1, src2, src3.
    case Opcode::Msub:
        if (n >= 4)
        {
            def_at(0);
            use_at(1);
            use_at(2);
            use_at(3);
        }
        break;

    // dst, src. For MovReg either side may be a physical register; only the
    // vreg side(s) are reported.
    case Opcode::Uxtw:
    case Opcode::Sxtw:
    case Opcode::Scvtf:
    case Opcode::FCvtzs:
    case Opcode::FMovWS:
    case Opcode::NegRR:
    case Opcode::MovReg:
        if (n >= 2)
        {
            def_at(0);
            use_at(1);
        }
        break;

    // Compares only read their operands (result goes to the flags).
    case Opcode::CmpRR:
    case Opcode::CmpImm:
    case Opcode::FCmpRR:
        if (n >= 2)
        {
            use_at(0);
            use_at(1);
        }
        break;

    case Opcode::CondBr:
        result.is_terminator = true;
        break;

    case Opcode::Call:
        result.is_call = true;
        use_all();
        break;

    case Opcode::Ret:
        use_all();
        result.is_terminator = true;
        break;

    default:
        // NOTE(review): an unlisted opcode that writes a vreg would be
        // modelled as a pure use here.
        use_all();
        break;
    }

    auto remove_dups = [](std::vector<int> &v)
    {
        std::sort(v.begin(), v.end());
        v.erase(std::unique(v.begin(), v.end()), v.end());
    };
    remove_dups(result.defs);
    remove_dups(result.uses);
    return result;
}
// Liveness sets for one basic block, all holding vreg ids.
//   use      - read in the block before any write in the block
//   def      - written somewhere in the block
//   live_in  - live on entry   (use  U  (live_out \ def))
//   live_out - live on exit    (union of successors' live_in)
struct BlockLiveness
{
    std::unordered_set<int> live_in{};
    std::unordered_set<int> live_out{};
    std::unordered_set<int> def{};
    std::unordered_set<int> use{};
};
// Computes per-block vreg liveness with the classic backward dataflow:
//   live_out(B) = U live_in(S) over successors S
//   live_in(B)  = use(B) U (live_out(B) \ def(B))
//
// Successors are (a) the Label target of every Br (ops[0]) and CondBr (ops[1])
// found in the block, and (b) the next block in layout order when the block
// does not end in an unconditional Br or a Ret (e.g. it ends in a lone CondBr,
// which falls through on AArch64). Adding the fall-through edge can only grow
// the live sets, so it is always sound.
//
// Successor lists are computed once, and blocks are visited in reverse layout
// order, which typically converges in fewer passes for a backward problem.
static std::vector<BlockLiveness> ComputeBlockLiveness(MachineFunction &function)
{
    auto &blocks = function.GetBlocks();
    const size_t num_blocks = blocks.size();
    std::vector<BlockLiveness> bl(num_blocks);

    std::unordered_map<int, size_t> label_to_block;
    for (size_t i = 0; i < num_blocks; ++i)
        label_to_block[blocks[i]->GetLabelId()] = i;

    // Local use/def sets and CFG successors, both invariant during the fixpoint.
    std::vector<std::vector<size_t>> successors(num_blocks);
    for (size_t i = 0; i < num_blocks; ++i)
    {
        const auto &insts = blocks[i]->GetInstructions();
        for (const auto &inst : insts)
        {
            auto du = GetInstDefUse(inst, function);
            for (int u : du.uses)
            {
                // Upward-exposed use: read before any write in this block.
                if (bl[i].def.count(u) == 0)
                    bl[i].use.insert(u);
            }
            for (int d : du.defs)
                bl[i].def.insert(d);

            const auto opcode = inst.GetOpcode();
            if (opcode != Opcode::Br && opcode != Opcode::CondBr)
                continue;
            const size_t label_pos = (opcode == Opcode::Br) ? 0 : 1;
            const auto &ops = inst.GetOperands();
            if (ops.size() > label_pos && ops[label_pos].GetKind() == Operand::Kind::Label)
            {
                auto it = label_to_block.find(ops[label_pos].GetLabel());
                if (it != label_to_block.end())
                    successors[i].push_back(it->second);
            }
        }

        bool ends_unconditionally = false;
        if (!insts.empty())
        {
            const auto last = insts.back().GetOpcode();
            ends_unconditionally = (last == Opcode::Br || last == Opcode::Ret);
        }
        if (!ends_unconditionally && i + 1 < num_blocks)
            successors[i].push_back(i + 1);
    }

    bool changed = true;
    while (changed)
    {
        changed = false;
        for (size_t k = num_blocks; k-- > 0;)
        {
            std::unordered_set<int> new_out;
            for (size_t s : successors[k])
                new_out.insert(bl[s].live_in.begin(), bl[s].live_in.end());
            if (new_out != bl[k].live_out)
            {
                bl[k].live_out = std::move(new_out);
                changed = true;
            }

            std::unordered_set<int> new_in = bl[k].use;
            for (int v : bl[k].live_out)
            {
                if (bl[k].def.count(v) == 0)
                    new_in.insert(v);
            }
            if (new_in != bl[k].live_in)
            {
                bl[k].live_in = std::move(new_in);
                changed = true;
            }
        }
    }
    return bl;
}
// Undirected interference graph over integer node ids.
// Non-negative ids are vregs; negative ids are pseudo-nodes (the builders use
// -1000-r / -2000-r for "register r is clobbered by a call").
// Each undirected edge is stored once as a packed 64-bit key: the smaller id
// in the high 32 bits, the larger id in the low 32 bits.
struct InterferenceGraph
{
    std::unordered_set<int> nodes;
    std::unordered_set<long long> edges;
    std::unordered_map<int, int> degree;

    // Order-independent key for the edge {u, v}.
    static long long MakeEdgeKey(int u, int v)
    {
        const int lo = std::min(u, v);
        const int hi = std::max(u, v);
        const unsigned long long high_bits =
            static_cast<unsigned long long>(static_cast<unsigned int>(lo)) << 32;
        const unsigned long long low_bits = static_cast<unsigned int>(hi);
        return static_cast<long long>(high_bits | low_bits);
    }

    void AddNode(int v) { nodes.insert(v); }

    // Adds {u, v} unless it is a self-loop or already present; degrees of
    // both endpoints are bumped only on first insertion.
    void AddEdge(int u, int v)
    {
        if (u == v)
            return;
        if (edges.insert(MakeEdgeKey(u, v)).second)
        {
            ++degree[u];
            ++degree[v];
        }
    }

    // Number of distinct neighbours recorded for v (0 if none).
    int GetDegree(int v) const
    {
        const auto found = degree.find(v);
        return found == degree.end() ? 0 : found->second;
    }
};
// Int and Ptr vregs live in general-purpose registers; everything else does not.
static bool IsGPClass(VRegClass vc)
{
    switch (vc)
    {
    case VRegClass::Int:
    case VRegClass::Ptr:
        return true;
    default:
        return false;
    }
}
// Builds the interference graph for general-purpose (Int/Ptr) vregs.
//
// - Every GP vreg becomes a node.
// - Vregs that are simultaneously live at a block's exit interfere pairwise.
// - Walking each block backwards, a def interferes with every other GP vreg
//   live after it.
// - At a call, every GP vreg in the live set (after adding the call's own
//   uses) interferes with a pseudo-node -1000-r for each allocatable register
//   r that is also in GP_CALLER_SAVED; ColorGraph pre-colors those
//   pseudo-nodes with r, keeping such vregs out of call-clobbered registers.
static void BuildInterferenceForGP(
    MachineFunction &function,
    const std::vector<BlockLiveness> &block_liveness,
    const std::vector<int> &allocatable_regs,
    InterferenceGraph &graph)
{
    auto &blocks = function.GetBlocks();
    for (int vreg = 0; vreg < function.GetNumVRegs(); ++vreg)
    {
        if (IsGPClass(function.GetVRegClass(vreg)))
            graph.AddNode(vreg);
    }

    // Allocatable registers a call clobbers, in allocatable order. This does
    // not depend on the instruction, so compute it once instead of re-scanning
    // GP_CALLER_SAVED for every live vreg at every call.
    std::vector<int> call_clobbered;
    const int *const caller_saved_end = GP_CALLER_SAVED + GP_NUM_CALLER_SAVED;
    for (int c : allocatable_regs)
    {
        if (std::find(GP_CALLER_SAVED, caller_saved_end, c) != caller_saved_end)
            call_clobbered.push_back(c);
    }

    for (size_t bi = 0; bi < blocks.size(); ++bi)
    {
        auto &block = blocks[bi];
        std::unordered_set<int> live = block_liveness[bi].live_out;

        std::vector<int> gp_live;
        for (int v : live)
        {
            if (IsGPClass(function.GetVRegClass(v)))
                gp_live.push_back(v);
        }
        for (size_t i = 0; i < gp_live.size(); ++i)
        {
            for (size_t j = i + 1; j < gp_live.size(); ++j)
                graph.AddEdge(gp_live[i], gp_live[j]);
        }

        const auto &insts = block->GetInstructions();
        for (auto it = insts.rbegin(); it != insts.rend(); ++it)
        {
            auto du = GetInstDefUse(*it, function);
            for (int d : du.defs)
            {
                if (!IsGPClass(function.GetVRegClass(d)))
                    continue;
                for (int v : live)
                {
                    if (v != d && IsGPClass(function.GetVRegClass(v)))
                        graph.AddEdge(d, v);
                }
                live.erase(d);
            }
            for (int u : du.uses)
            {
                if (IsGPClass(function.GetVRegClass(u)))
                    live.insert(u);
            }
            if (!du.is_call)
                continue;
            for (int v : live)
            {
                if (!IsGPClass(function.GetVRegClass(v)))
                    continue;
                for (int c : call_clobbered)
                {
                    int call_clobber_node = -1000 - c;
                    graph.AddNode(call_clobber_node);
                    graph.AddEdge(v, call_clobber_node);
                }
            }
        }
    }
}
// Builds the interference graph for Float vregs.
//
// Same scheme as the GP builder: every Float vreg is a node; Float vregs live
// together at a block exit interfere pairwise; walking backwards, each Float
// def interferes with every other Float vreg live after it. At a call, every
// Float vreg in the live set (after adding the call's own uses) gets an edge
// to the pseudo-node -2000-r for each r in FP_CALLER_SAVED.
// `allocatable_regs` is accepted for symmetry with the GP builder but unused.
static void BuildInterferenceForFP(
    MachineFunction &function,
    const std::vector<BlockLiveness> &block_liveness,
    const std::vector<int> &allocatable_regs,
    InterferenceGraph &graph)
{
    auto &all_blocks = function.GetBlocks();
    auto is_float = [&function](int vreg)
    {
        return function.GetVRegClass(vreg) == VRegClass::Float;
    };

    for (int id = 0; id < function.GetNumVRegs(); ++id)
    {
        if (is_float(id))
            graph.AddNode(id);
    }

    for (size_t b = 0; b < all_blocks.size(); ++b)
    {
        std::unordered_set<int> live_now = block_liveness[b].live_out;

        std::vector<int> float_out;
        for (int v : live_now)
        {
            if (is_float(v))
                float_out.push_back(v);
        }
        for (size_t x = 0; x < float_out.size(); ++x)
        {
            for (size_t y = x + 1; y < float_out.size(); ++y)
                graph.AddEdge(float_out[x], float_out[y]);
        }

        const auto &body = all_blocks[b]->GetInstructions();
        for (auto rit = body.rbegin(); rit != body.rend(); ++rit)
        {
            const InstDefUse info = GetInstDefUse(*rit, function);
            for (int d : info.defs)
            {
                if (!is_float(d))
                    continue;
                for (int v : live_now)
                {
                    if (v != d && is_float(v))
                        graph.AddEdge(d, v);
                }
                live_now.erase(d);
            }
            for (int u : info.uses)
            {
                if (is_float(u))
                    live_now.insert(u);
            }
            if (!info.is_call)
                continue;
            for (int v : live_now)
            {
                if (!is_float(v))
                    continue;
                for (int reg : FP_CALLER_SAVED)
                {
                    const int clobber = -2000 - reg;
                    graph.AddNode(clobber);
                    graph.AddEdge(v, clobber);
                }
            }
        }
    }
}
// Output of ColorGraph for one register class.
struct GraphColoringResult
{
    // vreg id -> register number (e.g. 9 for x9/w9, 20 for s20).
    std::unordered_map<int, int> assignment{};
    // vreg ids that received no register and must live in memory.
    std::set<int> spilled{};
};
// Chaitin/Briggs-style graph coloring with K = allocatable_regs.size() colors.
//
// Nodes with id >= 0 are vregs to color. Negative ids in [-1999, -1001] and
// [-2999, -2001] are call-clobber pseudo-nodes pre-colored with register
// -(id + 1000) / -(id + 2000).
//
// Simplify: repeatedly remove vregs with fewer than K remaining neighbours
// and push them on a stack. When every remaining vreg has degree >= K, the
// highest-degree one is removed as a *potential* spill and also pushed
// (optimistic coloring). Previously it was committed to memory right away,
// even though its neighbours often end up sharing colors.
//
// Select: pop the stack and give each vreg the color of one of its copy
// partners (from `copy_edges`) when that color is free, otherwise the first
// free register in `allocatable_regs` order. A vreg with no free color is
// reported in `spilled`.
static GraphColoringResult ColorGraph(
    InterferenceGraph &graph,
    const std::vector<int> &allocatable_regs,
    MachineFunction & /*function*/,
    const std::unordered_map<int, std::vector<int>> &copy_edges = {})
{
    const int K = static_cast<int>(allocatable_regs.size());
    GraphColoringResult result;

    // Adjacency lists decoded from the packed edge keys (high 32 bits = smaller id).
    std::unordered_map<int, std::unordered_set<int>> adj;
    for (auto key : graph.edges)
    {
        int u = static_cast<int>(key >> 32);
        int v = static_cast<int>(key & 0xFFFFFFFF);
        adj[u].insert(v);
        adj[v].insert(u);
    }

    std::unordered_map<int, int> precolored;
    for (int v : graph.nodes)
    {
        if (v >= -1999 && v < -1000)
            precolored[v] = -(v + 1000);
        else if (v >= -2999 && v < -2000)
            precolored[v] = -(v + 2000);
    }

    // Initial degree of each vreg counts every neighbour, pre-colored ones included.
    std::unordered_set<int> remaining;
    std::unordered_map<int, int> degree;
    for (int v : graph.nodes)
    {
        if (v < 0)
            continue;
        remaining.insert(v);
        int deg = 0;
        for (int n : adj[v])
        {
            if (graph.nodes.count(n))
                deg++;
        }
        degree[v] = deg;
    }

    std::vector<int> simplify_worklist;
    for (int v : remaining)
    {
        if (degree[v] < K)
            simplify_worklist.push_back(v);
    }

    // Removes v from the working graph and queues neighbours whose degree just
    // dropped below K.
    auto remove_node = [&](int v)
    {
        remaining.erase(v);
        for (int n : adj[v])
        {
            if (remaining.count(n))
            {
                degree[n]--;
                if (degree[n] == K - 1)
                    simplify_worklist.push_back(n);
            }
        }
    };

    std::vector<int> stack;
    while (!remaining.empty())
    {
        while (!simplify_worklist.empty())
        {
            int v = simplify_worklist.back();
            simplify_worklist.pop_back();
            if (!remaining.count(v))
                continue;
            stack.push_back(v);
            remove_node(v);
        }
        if (remaining.empty())
            break;

        // Every remaining vreg has degree >= K: choose the highest-degree one
        // as a potential spill, but still try to color it in select.
        int spill_candidate = -1;
        int max_degree = -1;
        for (int v : remaining)
        {
            if (degree[v] > max_degree)
            {
                max_degree = degree[v];
                spill_candidate = v;
            }
        }
        if (spill_candidate < 0)
            break;
        stack.push_back(spill_candidate);
        remove_node(spill_candidate);
    }

    std::unordered_map<int, int> colored = precolored;
    while (!stack.empty())
    {
        int v = stack.back();
        stack.pop_back();

        std::unordered_set<int> used_colors;
        for (int n : adj[v])
        {
            auto it = colored.find(n);
            if (it != colored.end())
                used_colors.insert(it->second);
        }

        int assigned_color = -1;
        // Biased coloring: prefer the color already given to a copy partner so
        // the MovReg between them can become a same-register move.
        auto copy_it = copy_edges.find(v);
        if (copy_it != copy_edges.end())
        {
            for (int neighbor : copy_it->second)
            {
                auto nit = colored.find(neighbor);
                if (nit != colored.end() && used_colors.find(nit->second) == used_colors.end())
                {
                    assigned_color = nit->second;
                    break;
                }
            }
        }
        if (assigned_color < 0)
        {
            for (int c : allocatable_regs)
            {
                if (used_colors.find(c) == used_colors.end())
                {
                    assigned_color = c;
                    break;
                }
            }
        }

        if (assigned_color >= 0)
        {
            colored[v] = assigned_color;
            result.assignment[v] = assigned_color;
        }
        else
        {
            result.spilled.insert(v);
        }
    }
    return result;
}
// Chooses a GP register number to hold a spilled vreg for the duration of
// one instruction.
//
// A candidate is rejected if it is already in `used_scratch` or if it is the
// assigned register of any def (except `skip_vreg`) or any use of `du`.
// Register 14 is tried first, then GP_ALLOCATABLE in order; if everything is
// rejected, GP_ALLOCATABLE[0] is returned anyway.
// NOTE(review): registers held by vregs that are live across the instruction
// but not mentioned by it are not checked here.
static int PickGPScratchReg(
    const std::set<int> &used_scratch,
    const InstDefUse &du,
    const std::unordered_map<int, int> &gp_assignment,
    int skip_vreg = -1)
{
    auto holds = [&gp_assignment](int vreg, int reg)
    {
        const auto found = gp_assignment.find(vreg);
        return found != gp_assignment.end() && found->second == reg;
    };
    auto occupied_by_inst = [&](int reg)
    {
        for (int d : du.defs)
        {
            if (d != skip_vreg && holds(d, reg))
                return true;
        }
        for (int u : du.uses)
        {
            if (holds(u, reg))
                return true;
        }
        return false;
    };

    const int kPreferred = 14;
    if (used_scratch.count(kPreferred) == 0 && !occupied_by_inst(kPreferred))
        return kPreferred;

    for (int reg : GP_ALLOCATABLE)
    {
        if (used_scratch.count(reg) == 0 && !occupied_by_inst(reg))
            return reg;
    }
    return GP_ALLOCATABLE[0];
}
// FP counterpart of PickGPScratchReg: returns the first register number in
// FP_ALLOCATABLE that is not in `used_scratch` and is not the assigned
// register of any def (other than `skip_vreg`) or use of `du`. Falls back to
// FP_ALLOCATABLE[0] if every candidate is taken.
static int PickFPScratchReg(
    const std::set<int> &used_scratch,
    const InstDefUse &du,
    const std::unordered_map<int, int> &fp_assignment,
    int skip_vreg = -1)
{
    auto holds = [&fp_assignment](int vreg, int reg)
    {
        const auto found = fp_assignment.find(vreg);
        return found != fp_assignment.end() && found->second == reg;
    };

    for (int reg : FP_ALLOCATABLE)
    {
        if (used_scratch.count(reg) != 0)
            continue;
        const bool def_clash = std::any_of(
            du.defs.begin(), du.defs.end(),
            [&](int d) { return d != skip_vreg && holds(d, reg); });
        if (def_clash)
            continue;
        const bool use_clash = std::any_of(
            du.uses.begin(), du.uses.end(),
            [&](int u) { return holds(u, reg); });
        if (!use_clash)
            return reg;
    }
    return FP_ALLOCATABLE[0];
}
// Replaces every vreg operand with its physical register and materialises
// spills.
//
// - Assigned vregs: the operand becomes NumberToPhysReg(assignment, class).
// - Spilled vregs: each gets a fresh frame slot (8 bytes for Ptr, else 4).
//   For each instruction, a spilled use is reloaded into a scratch register
//   (LoadStack) before the instruction; a spilled def is written to a scratch
//   register and stored back (StoreStack) right after it.
//
// The register stored back is the exact scratch register chosen for that
// vreg in this instruction. The previous version re-discovered it by scanning
// the last emitted instruction for the first register of the right class,
// which silently depended on the def being operand 0, inspected the wrong
// instruction after a previous StoreStack, and fell back to W0.
static void RewriteWithAllocation(
    MachineFunction &function,
    const std::unordered_map<int, int> &gp_assignment,
    const std::unordered_map<int, int> &fp_assignment,
    const std::set<int> &spilled)
{
    std::unordered_map<int, int> spill_slots;
    for (int v : spilled)
    {
        int size = (function.GetVRegClass(v) == VRegClass::Ptr) ? 8 : 4;
        spill_slots[v] = function.CreateFrameIndex(size);
    }

    for (auto &block : function.GetBlocks())
    {
        std::vector<MachineInstr> new_insts;
        for (auto &inst : block->GetInstructions())
        {
            auto du = GetInstDefUse(inst, function);
            std::set<int> used_scratch_gp;
            std::set<int> used_scratch_fp;
            // Spilled vreg -> scratch register holding it in this instruction.
            std::unordered_map<int, PhysReg> scratch_of;

            auto pick_scratch = [&](int vreg, VRegClass vc, int skip_vreg) -> PhysReg
            {
                int num = -1;
                if (vc == VRegClass::Float)
                {
                    num = PickFPScratchReg(used_scratch_fp, du, fp_assignment, skip_vreg);
                    used_scratch_fp.insert(num);
                }
                else
                {
                    num = PickGPScratchReg(used_scratch_gp, du, gp_assignment, skip_vreg);
                    used_scratch_gp.insert(num);
                }
                const PhysReg reg = NumberToPhysReg(num, vc);
                scratch_of[vreg] = reg;
                return reg;
            };

            // Reload spilled uses. All occurrences (including a def of the
            // same vreg) are redirected to the reload register.
            for (int u : du.uses)
            {
                if (!spilled.count(u))
                    continue;
                const VRegClass vc = function.GetVRegClass(u);
                const PhysReg reload_reg = pick_scratch(u, vc, -1);
                new_insts.push_back(
                    MachineInstr(Opcode::LoadStack,
                                 {Operand::Reg(reload_reg), Operand::FrameIndex(spill_slots[u])}));
                for (auto &op : inst.GetOperands())
                {
                    if (op.GetKind() == Operand::Kind::VReg && op.GetVRegId() == u)
                        const_cast<Operand &>(op) = Operand::Reg(reload_reg);
                }
            }

            // Rewrite the remaining vreg operands.
            for (auto &op : inst.GetOperands())
            {
                if (op.GetKind() != Operand::Kind::VReg)
                    continue;
                const int vreg_id = op.GetVRegId();
                const VRegClass actual_vc = function.GetVRegClass(vreg_id);
                const auto &assignment =
                    (actual_vc == VRegClass::Float) ? fp_assignment : gp_assignment;
                int reg_num = -1;
                auto it = assignment.find(vreg_id);
                if (it != assignment.end())
                    reg_num = it->second;

                if (reg_num >= 0)
                {
                    const_cast<Operand &>(op) = Operand::Reg(NumberToPhysReg(reg_num, actual_vc));
                }
                else if (spilled.count(vreg_id))
                {
                    // A spilled def-only vreg: reuse its scratch if this vreg
                    // already appeared in an earlier operand.
                    auto known = scratch_of.find(vreg_id);
                    const PhysReg reg = (known != scratch_of.end())
                                            ? known->second
                                            : pick_scratch(vreg_id, actual_vc, vreg_id);
                    const_cast<Operand &>(op) = Operand::Reg(reg);
                }
            }

            new_insts.push_back(std::move(const_cast<MachineInstr &>(inst)));

            // Store spilled defs back to their slots.
            for (int d : du.defs)
            {
                if (!spilled.count(d))
                    continue;
                auto it = scratch_of.find(d);
                if (it == scratch_of.end())
                    continue;
                new_insts.push_back(
                    MachineInstr(Opcode::StoreStack,
                                 {Operand::Reg(it->second), Operand::FrameIndex(spill_slots[d])}));
            }
        }
        block->GetInstructions() = std::move(new_insts);
    }
}
static void AllocateRegistersForFunction(MachineFunction &function)
{
if (function.GetNumVRegs() == 0)
return;
const int MAX_SPILL_ROUNDS = 10;
for (int round = 0; round < MAX_SPILL_ROUNDS; ++round)
{
auto block_liveness = ComputeBlockLiveness(function);
// 收集 copy edges (MovReg 连接的 vreg 对)
std::unordered_map<int, std::vector<int>> gp_copy_edges;
std::unordered_map<int, std::vector<int>> fp_copy_edges;
for (auto &block : function.GetBlocks())
{
for (auto &inst : block->GetInstructions())
{
if (inst.GetOpcode() == Opcode::MovReg)
{
const auto &ops = inst.GetOperands();
if (ops.size() >= 2 &&
ops[0].GetKind() == Operand::Kind::VReg &&
ops[1].GetKind() == Operand::Kind::VReg)
{
int dst = ops[0].GetVRegId();
int src = ops[1].GetVRegId();
if (function.GetVRegClass(dst) == function.GetVRegClass(src))
{
auto &edges = IsGPClass(function.GetVRegClass(dst))
? gp_copy_edges : fp_copy_edges;
edges[dst].push_back(src);
edges[src].push_back(dst);
}
}
}
}
}
std::vector<int> gp_alloc(GP_ALLOCATABLE, GP_ALLOCATABLE + GP_NUM_ALLOCATABLE);
std::vector<int> fp_alloc(FP_ALLOCATABLE, FP_ALLOCATABLE + FP_NUM_ALLOCATABLE);
InterferenceGraph gp_graph, fp_graph;
BuildInterferenceForGP(function, block_liveness, gp_alloc, gp_graph);
BuildInterferenceForFP(function, block_liveness, fp_alloc, fp_graph);
auto gp_result = ColorGraph(gp_graph, gp_alloc, function, gp_copy_edges);
auto fp_result = ColorGraph(fp_graph, fp_alloc, function, fp_copy_edges);
if (gp_result.spilled.empty() && fp_result.spilled.empty())
{
std::unordered_map<int, int> gp_assign = gp_result.assignment;
std::unordered_map<int, int> fp_assign = fp_result.assignment;
for (const auto &pair : gp_assign)
{
if (pair.second >= 19 && pair.second <= 28)
{
function.AddCalleeSavedReg(NumberToPhysReg(pair.second, VRegClass::Ptr));
}
}
for (const auto &pair : fp_assign)
{
if (pair.second >= 16 && pair.second <= 31)
{
function.AddCalleeSavedReg(NumberToPhysReg(pair.second, VRegClass::Float));
}
}
RewriteWithAllocation(function, gp_assign, fp_assign, {});
return;
}
std::set<int> all_spilled = gp_result.spilled;
for (int v : fp_result.spilled)
all_spilled.insert(v);
std::unordered_map<int, int> spill_slots;
for (int v : all_spilled)
{
int size = (function.GetVRegClass(v) == VRegClass::Ptr) ? 8 : 4;
spill_slots[v] = function.CreateFrameIndex(size);
}
for (auto &block : function.GetBlocks())
{
std::vector<MachineInstr> new_insts;
for (auto &inst : block->GetInstructions())
{
auto du = GetInstDefUse(inst, function);
for (int u : du.uses)
{
if (all_spilled.count(u))
{
VRegClass vc = function.GetVRegClass(u);
int new_vreg = function.CreateVReg(vc);
new_insts.push_back(
MachineInstr(Opcode::LoadStack,
{Operand::VReg(new_vreg, vc), Operand::FrameIndex(spill_slots[u])}));
for (auto &op : inst.GetOperands())
{
if (op.GetKind() == Operand::Kind::VReg && op.GetVRegId() == u)
{
const_cast<Operand &>(op) = Operand::VReg(new_vreg, vc);
}
}
}
}
new_insts.push_back(std::move(const_cast<MachineInstr &>(inst)));
for (int d : du.defs)
{
if (all_spilled.count(d))
{
VRegClass vc = function.GetVRegClass(d);
int new_vreg = function.CreateVReg(vc);
auto &last = new_insts.back();
for (auto &op : last.GetOperands())
{
if (op.GetKind() == Operand::Kind::VReg && op.GetVRegId() == d)
{
const_cast<Operand &>(op) = Operand::VReg(new_vreg, vc);
}
}
new_insts.push_back(
MachineInstr(Opcode::StoreStack,
{Operand::VReg(new_vreg, vc), Operand::FrameIndex(spill_slots[d])}));
}
}
}
block->GetInstructions() = std::move(new_insts);
}
}
auto block_liveness = ComputeBlockLiveness(function);
// 收集 copy edges
std::unordered_map<int, std::vector<int>> fp_copy_edges_fb;
std::unordered_map<int, std::vector<int>> gp_copy_edges_fb;
for (auto &block : function.GetBlocks())
for (auto &inst : block->GetInstructions())
if (inst.GetOpcode() == Opcode::MovReg)
{
const auto &ops = inst.GetOperands();
if (ops.size() >= 2 && ops[0].GetKind() == Operand::Kind::VReg &&
ops[1].GetKind() == Operand::Kind::VReg)
{
int dst = ops[0].GetVRegId(), src = ops[1].GetVRegId();
if (function.GetVRegClass(dst) == function.GetVRegClass(src))
{
auto &e = IsGPClass(function.GetVRegClass(dst)) ? gp_copy_edges_fb : fp_copy_edges_fb;
e[dst].push_back(src);
e[src].push_back(dst);
}
}
}
std::vector<int> gp_alloc(GP_ALLOCATABLE, GP_ALLOCATABLE + GP_NUM_ALLOCATABLE);
std::vector<int> fp_alloc(FP_ALLOCATABLE, FP_ALLOCATABLE + FP_NUM_ALLOCATABLE);
InterferenceGraph gp_graph, fp_graph;
BuildInterferenceForGP(function, block_liveness, gp_alloc, gp_graph);
BuildInterferenceForFP(function, block_liveness, fp_alloc, fp_graph);
auto gp_result = ColorGraph(gp_graph, gp_alloc, function, gp_copy_edges_fb);
auto fp_result = ColorGraph(fp_graph, fp_alloc, function, fp_copy_edges_fb);
std::set<int> all_spilled = gp_result.spilled;
for (int v : fp_result.spilled)
all_spilled.insert(v);
std::unordered_map<int, int> gp_assign = gp_result.assignment;
std::unordered_map<int, int> fp_assign = fp_result.assignment;
for (const auto &pair : gp_assign)
{
if (pair.second >= 19 && pair.second <= 28)
{
function.AddCalleeSavedReg(NumberToPhysReg(pair.second, VRegClass::Ptr));
}
}
for (const auto &pair : fp_assign)
{
if (pair.second >= 16 && pair.second <= 31)
{
function.AddCalleeSavedReg(NumberToPhysReg(pair.second, VRegClass::Float));
}
}
RewriteWithAllocation(function, gp_assign, fp_assign, all_spilled);
}
} // namespace
// Public entry point: runs register allocation on a single machine function,
// replacing its virtual registers with physical ones in place.
void RunRegAlloc(MachineFunction &function) { AllocateRegistersForFunction(function); }
// Public entry point: runs register allocation on every (non-null) function
// of the module, one function at a time.
void RunRegAlloc(MachineModule &module)
{
    for (auto &fn : module.GetFunctions())
    {
        if (!fn)
            continue;
        RunRegAlloc(*fn);
    }
}
} // namespace mir