// nudt-compiler-cpp/src/mir/RegAlloc.cpp
// (Repository viewer metadata at time of capture: 1213 lines, 37 KiB.)

#include "mir/MIR.h"
#include <algorithm>
#include <queue>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "utils/Log.h"
namespace mir
{
namespace
{
/// Converts a physical register into a small integer.
///
/// Registers in the W0..W30 and X0..X30 enum ranges yield their offset from
/// W0 / X0 respectively; registers in S0..S31 yield their offset from S0
/// plus 100. Any other register yields -1.
static int PhysRegToNumber(PhysReg reg)
{
    const int code = static_cast<int>(reg);
    const auto within = [reg](PhysReg first, PhysReg last)
    {
        return !(reg < first) && !(reg > last);
    };

    if (within(PhysReg::W0, PhysReg::W30))
        return code - static_cast<int>(PhysReg::W0);
    if (within(PhysReg::X0, PhysReg::X30))
        return code - static_cast<int>(PhysReg::X0);
    if (within(PhysReg::S0, PhysReg::S31))
        return code - static_cast<int>(PhysReg::S0) + 100;
    return -1;
}
/// Builds the physical register with index `num` in the bank selected by `vc`.
///
/// Float selects the bank starting at S0, Ptr the bank starting at X0, and
/// every other class the bank starting at W0. `num` is not range-checked.
static PhysReg NumberToPhysReg(int num, VRegClass vc)
{
    PhysReg bank_base = PhysReg::W0;
    if (vc == VRegClass::Float)
        bank_base = PhysReg::S0;
    else if (vc == VRegClass::Ptr)
        bank_base = PhysReg::X0;
    return static_cast<PhysReg>(static_cast<int>(bank_base) + num);
}
/// True when `reg` lies in the W0..W30 or the X0..X30 enum range.
static bool IsGPReg(PhysReg reg)
{
    if (reg >= PhysReg::W0 && reg <= PhysReg::W30)
        return true;
    return reg >= PhysReg::X0 && reg <= PhysReg::X30;
}
/// True when `reg` lies in the S0..S31 enum range.
static bool IsFPReg(PhysReg reg)
{
    if (reg < PhysReg::S0)
        return false;
    return reg <= PhysReg::S31;
}
// Register-number tables used by the allocator. Every *_NUM_* constant is
// derived from its array so the two can never drift apart.

// General-purpose register numbers handed out by graph colouring (also the
// candidate list for extra spill scratch registers). Register 14 is absent
// here; PickGPScratchReg tries it first as a scratch register.
static constexpr int GP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 15, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28};
static constexpr int GP_NUM_ALLOCATABLE =
    static_cast<int>(sizeof(GP_ALLOCATABLE) / sizeof(GP_ALLOCATABLE[0]));

// Floating-point register numbers handed out by graph colouring.
static constexpr int FP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
static constexpr int FP_NUM_ALLOCATABLE =
    static_cast<int>(sizeof(FP_ALLOCATABLE) / sizeof(FP_ALLOCATABLE[0]));

// Floating-point register numbers treated as clobbered by every call.
// NOTE(review): AAPCS64 lists v8-v15 as callee-saved and v16-v31 as
// caller-saved, which is the opposite of this table and of the 16..31
// callee-saved range used in AllocateRegistersForFunction. The two places are
// consistent with each other, so code compiled by this backend agrees with
// itself, but calls into externally compiled code may not preserve 16..31.
// Any change must update both places together - confirm against the ABI.
static constexpr int FP_CALLER_SAVED[] = {8, 9, 10, 11, 12, 13, 14, 15};
static constexpr int FP_NUM_CALLER_SAVED =
    static_cast<int>(sizeof(FP_CALLER_SAVED) / sizeof(FP_CALLER_SAVED[0]));

// General-purpose register numbers treated as clobbered by every call.
static constexpr int GP_CALLER_SAVED[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
static constexpr int GP_NUM_CALLER_SAVED =
    static_cast<int>(sizeof(GP_CALLER_SAVED) / sizeof(GP_CALLER_SAVED[0]));
/// Virtual-register effects of one machine instruction, as reported by
/// GetInstDefUse.
struct InstDefUse
{
    std::vector<int> defs;      // ids of virtual registers the instruction writes
    std::vector<int> uses;      // ids of virtual registers the instruction reads
    bool is_call{false};        // set for Call instructions
    bool is_terminator{false};  // set for CondBr and Ret instructions
};
/// Classifies the virtual-register operands of `inst` into definitions and
/// uses and flags calls / terminators.
///
/// Only operands of kind VReg are reported; physical registers, immediates,
/// labels and frame indices are ignored. For opcodes with a fixed operand
/// layout nothing is reported when the instruction has fewer operands than
/// that layout needs. Both result lists are sorted and free of duplicates.
///
/// @param inst  instruction to inspect.
/// @return      defs/uses plus the is_call and is_terminator flags.
static InstDefUse GetInstDefUse(const MachineInstr &inst, MachineFunction & /*function*/)
{
    InstDefUse result;
    const auto &ops = inst.GetOperands();

    const auto def_at = [&](size_t i)
    {
        if (ops[i].GetKind() == Operand::Kind::VReg)
            result.defs.push_back(ops[i].GetVRegId());
    };
    const auto use_at = [&](size_t i)
    {
        if (ops[i].GetKind() == Operand::Kind::VReg)
            result.uses.push_back(ops[i].GetVRegId());
    };
    // Layout "dst, src...": ops[0] is written, ops[1..count-1] are read.
    const auto def_then_uses = [&](size_t count)
    {
        if (ops.size() < count)
            return;
        def_at(0);
        for (size_t i = 1; i < count; ++i)
            use_at(i);
    };
    // Layout "src...": the first `count` operands are all read.
    const auto uses_only = [&](size_t count)
    {
        if (ops.size() < count)
            return;
        for (size_t i = 0; i < count; ++i)
            use_at(i);
    };
    // Every virtual-register operand is read.
    const auto all_uses = [&]()
    {
        for (size_t i = 0; i < ops.size(); ++i)
            use_at(i);
    };

    switch (inst.GetOpcode())
    {
    case Opcode::Prologue:
    case Opcode::Epilogue:
    case Opcode::Br:
        break;

    // dst only
    case Opcode::MovImm:
    case Opcode::LoadStack:
    case Opcode::LoadGlobal:
    case Opcode::LoadGlobalAddr:
    case Opcode::LoadStackAddr:
    case Opcode::CSet:
        def_then_uses(1);
        break;

    // src only
    case Opcode::StoreStack:
    case Opcode::StoreGlobal:
        uses_only(1);
        break;

    // dst, src
    case Opcode::LoadMem:
    case Opcode::Uxtw:
    case Opcode::Sxtw:
    case Opcode::Scvtf:
    case Opcode::FCvtzs:
    case Opcode::FMovWS:
    case Opcode::NegRR:
    case Opcode::MovReg:
        def_then_uses(2);
        break;

    // src, src
    case Opcode::StoreMem:
    case Opcode::CmpRR:
    case Opcode::CmpImm:
    case Opcode::FCmpRR:
        uses_only(2);
        break;

    // dst, src, src
    case Opcode::AddRR:
    case Opcode::SubRR:
    case Opcode::MulRR:
    case Opcode::DivRR:
    case Opcode::ModRR:
    case Opcode::AndRR:
    case Opcode::OrRR:
    case Opcode::XorRR:
    case Opcode::FAddRR:
    case Opcode::FSubRR:
    case Opcode::FMulRR:
    case Opcode::FDivRR:
    case Opcode::Csel:
    case Opcode::Csneg:
    case Opcode::Smull:
    // Shifts: the shift amount (ops[2]) is now recorded as a use when it is
    // a virtual register; previously it was ignored, which would hide the
    // read from liveness. An immediate shift amount is unaffected.
    case Opcode::ShlRR:
    case Opcode::ShrRR:
    case Opcode::AsrRR:
    case Opcode::Asr64RR:
        def_then_uses(3);
        break;

    // dst, src, src, src
    case Opcode::Msub:
        def_then_uses(4);
        break;

    case Opcode::CondBr:
        result.is_terminator = true;
        break;

    case Opcode::Call:
        result.is_call = true;
        all_uses();
        break;

    case Opcode::Ret:
        all_uses();
        result.is_terminator = true;
        break;

    default:
        // Unknown opcode: conservatively treat every virtual register as read.
        all_uses();
        break;
    }

    const auto sort_unique = [](std::vector<int> &ids)
    {
        std::sort(ids.begin(), ids.end());
        ids.erase(std::unique(ids.begin(), ids.end()), ids.end());
    };
    sort_unique(result.defs);
    sort_unique(result.uses);
    return result;
}
/// Per-basic-block liveness sets over virtual-register ids, filled in by
/// ComputeBlockLiveness.
struct BlockLiveness
{
    using VRegSet = std::unordered_set<int>;

    VRegSet live_in;   // registers live on entry to the block
    VRegSet live_out;  // registers live on exit from the block
    VRegSet def;       // registers written somewhere in the block
    VRegSet use;       // registers read before any write in the block
};
/// Computes def/use and live-in/live-out sets for every block of `function`.
///
/// Successor edges come only from label operands of branches: operand 0 of a
/// Br and operand 1 of a CondBr, when that operand is a Label naming a block
/// of this function. The edges are resolved once up front instead of being
/// re-derived from the instruction stream on every fixpoint iteration.
/// NOTE(review): fall-through into the next block is not modelled, so this
/// presumably relies on every block ending in explicit branches - confirm.
///
/// @return one BlockLiveness per block, indexed like function.GetBlocks().
static std::vector<BlockLiveness> ComputeBlockLiveness(MachineFunction &function)
{
    auto &blocks = function.GetBlocks();
    const size_t num_blocks = blocks.size();
    std::vector<BlockLiveness> bl(num_blocks);

    std::unordered_map<int, size_t> label_to_block;
    for (size_t i = 0; i < num_blocks; ++i)
        label_to_block[blocks[i]->GetLabelId()] = i;

    // One pass per block: local def/use sets and the successor list.
    std::vector<std::vector<size_t>> successors(num_blocks);
    for (size_t i = 0; i < num_blocks; ++i)
    {
        BlockLiveness &info = bl[i];
        for (const auto &inst : blocks[i]->GetInstructions())
        {
            const auto du = GetInstDefUse(inst, function);
            for (int u : du.uses)
            {
                // Only reads that are not preceded by a write in this block.
                if (info.def.find(u) == info.def.end())
                    info.use.insert(u);
            }
            for (int d : du.defs)
                info.def.insert(d);

            size_t label_index = 0;
            if (inst.GetOpcode() == Opcode::Br)
                label_index = 0;
            else if (inst.GetOpcode() == Opcode::CondBr)
                label_index = 1;
            else
                continue;

            const auto &ops = inst.GetOperands();
            if (ops.size() <= label_index ||
                ops[label_index].GetKind() != Operand::Kind::Label)
                continue;
            const auto target = label_to_block.find(ops[label_index].GetLabel());
            if (target != label_to_block.end())
                successors[i].push_back(target->second);
        }
    }

    // Iterate the dataflow equations until nothing changes:
    //   out(B) = union of in(S) over successors S
    //   in(B)  = use(B) + (out(B) - def(B))
    bool changed = true;
    while (changed)
    {
        changed = false;
        for (size_t i = 0; i < num_blocks; ++i)
        {
            BlockLiveness &info = bl[i];

            std::unordered_set<int> new_out;
            for (size_t succ : successors[i])
                new_out.insert(bl[succ].live_in.begin(), bl[succ].live_in.end());
            if (new_out != info.live_out)
            {
                info.live_out = std::move(new_out);
                changed = true;
            }

            std::unordered_set<int> new_in = info.use;
            for (int v : info.live_out)
            {
                if (info.def.find(v) == info.def.end())
                    new_in.insert(v);
            }
            if (new_in != info.live_in)
            {
                info.live_in = std::move(new_in);
                changed = true;
            }
        }
    }
    return bl;
}
/// Undirected interference graph over integer node ids.
///
/// Edges are stored as 64-bit keys (see MakeEdgeKey) so that each unordered
/// pair is recorded once; `degree` counts the distinct edges per endpoint.
/// AddEdge does not register its endpoints in `nodes`.
struct InterferenceGraph
{
    std::unordered_set<int> nodes;
    std::unordered_set<long long> edges;
    std::unordered_map<int, int> degree;

    /// Packs an unordered pair into one key: the smaller id occupies the
    /// upper 32 bits and the larger id the lower 32 bits, each as its
    /// unsigned 32-bit bit pattern.
    static long long MakeEdgeKey(int u, int v)
    {
        const int smaller = (u > v) ? v : u;
        const int larger = (u > v) ? u : v;
        const unsigned long long upper =
            static_cast<unsigned long long>(static_cast<unsigned int>(smaller)) << 32;
        return static_cast<long long>(upper) | static_cast<unsigned int>(larger);
    }

    /// Registers `v` as a node of the graph.
    void AddNode(int v) { nodes.insert(v); }

    /// Adds the edge {u, v}; self-loops and already present edges are ignored.
    void AddEdge(int u, int v)
    {
        if (u == v)
            return;
        const bool is_new = edges.insert(MakeEdgeKey(u, v)).second;
        if (!is_new)
            return;
        ++degree[u];
        ++degree[v];
    }

    /// Number of distinct edges touching `v`, or 0 if it has none.
    int GetDegree(int v) const
    {
        const auto found = degree.find(v);
        if (found == degree.end())
            return 0;
        return found->second;
    }
};
/// True for the register classes allocated from the general-purpose bank
/// (Int and Ptr).
static bool IsGPClass(VRegClass vc)
{
    if (vc == VRegClass::Int)
        return true;
    return vc == VRegClass::Ptr;
}
/// Builds the interference graph for general-purpose (Int/Ptr) virtual
/// registers.
///
/// Every GP virtual register becomes a node. Registers that are live out of
/// a block interfere pairwise; walking each block backwards, every GP
/// definition interferes with every other GP register live at that point.
/// At each call, every GP register live just before the call (after the
/// call's own operands have been added) gets an edge to a pseudo node
/// `-1000 - r` for each register `r` that is both in `allocatable_regs` and
/// in GP_CALLER_SAVED, which keeps it out of those registers during colouring.
///
/// @param function         function being allocated.
/// @param block_liveness   per-block liveness, indexed like GetBlocks().
/// @param allocatable_regs GP register numbers available for colouring.
/// @param graph            graph to populate.
static void BuildInterferenceForGP(
    MachineFunction &function,
    const std::vector<BlockLiveness> &block_liveness,
    const std::vector<int> &allocatable_regs,
    InterferenceGraph &graph)
{
    const auto is_gp = [&function](int vreg)
    {
        return IsGPClass(function.GetVRegClass(vreg));
    };

    // The set of call-clobber pseudo nodes does not depend on the call site
    // or on the live register, so compute it once (in allocatable order).
    std::vector<int> call_clobber_nodes;
    const int *const caller_saved_end = GP_CALLER_SAVED + GP_NUM_CALLER_SAVED;
    for (int c : allocatable_regs)
    {
        if (std::find(GP_CALLER_SAVED, caller_saved_end, c) != caller_saved_end)
            call_clobber_nodes.push_back(-1000 - c);
    }

    for (int vreg = 0; vreg < function.GetNumVRegs(); ++vreg)
    {
        if (is_gp(vreg))
            graph.AddNode(vreg);
    }

    auto &blocks = function.GetBlocks();
    for (size_t bi = 0; bi < blocks.size(); ++bi)
    {
        std::unordered_set<int> live = block_liveness[bi].live_out;

        // Everything live out of the block is simultaneously live.
        std::vector<int> gp_live;
        for (int v : live)
        {
            if (is_gp(v))
                gp_live.push_back(v);
        }
        for (size_t i = 0; i < gp_live.size(); ++i)
        {
            for (size_t j = i + 1; j < gp_live.size(); ++j)
                graph.AddEdge(gp_live[i], gp_live[j]);
        }

        const auto &insts = blocks[bi]->GetInstructions();
        for (auto it = insts.rbegin(); it != insts.rend(); ++it)
        {
            const auto du = GetInstDefUse(*it, function);

            for (int d : du.defs)
            {
                if (!is_gp(d))
                    continue;
                for (int v : live)
                {
                    if (v != d && is_gp(v))
                        graph.AddEdge(d, v);
                }
                live.erase(d);
            }
            for (int u : du.uses)
            {
                if (is_gp(u))
                    live.insert(u);
            }

            if (!du.is_call)
                continue;
            for (int v : live)
            {
                if (!is_gp(v))
                    continue;
                for (int clobber : call_clobber_nodes)
                {
                    graph.AddNode(clobber);
                    graph.AddEdge(v, clobber);
                }
            }
        }
    }
}
/// Builds the interference graph for Float virtual registers.
///
/// Mirrors the GP builder: every Float register is a node, registers live
/// out of a block interfere pairwise, and during the backward walk each
/// Float definition interferes with every other live Float register. At a
/// call, every Float register live just before it gets an edge to the pseudo
/// node `-2000 - r` for each `r` in FP_CALLER_SAVED. `allocatable_regs` is
/// accepted for symmetry with the GP builder but is not consulted.
static void BuildInterferenceForFP(
    MachineFunction &function,
    const std::vector<BlockLiveness> &block_liveness,
    const std::vector<int> &allocatable_regs,
    InterferenceGraph &graph)
{
    const auto is_fp = [&function](int vreg)
    {
        return function.GetVRegClass(vreg) == VRegClass::Float;
    };

    for (int vreg = 0; vreg < function.GetNumVRegs(); ++vreg)
    {
        if (is_fp(vreg))
            graph.AddNode(vreg);
    }

    auto &blocks = function.GetBlocks();
    for (size_t bi = 0; bi < blocks.size(); ++bi)
    {
        std::unordered_set<int> live = block_liveness[bi].live_out;

        // Pairwise interference between everything live out of the block.
        std::vector<int> fp_live;
        for (int v : live)
        {
            if (is_fp(v))
                fp_live.push_back(v);
        }
        for (size_t a = 0; a < fp_live.size(); ++a)
        {
            for (size_t b = a + 1; b < fp_live.size(); ++b)
                graph.AddEdge(fp_live[a], fp_live[b]);
        }

        const auto &insts = blocks[bi]->GetInstructions();
        for (auto it = insts.rbegin(); it != insts.rend(); ++it)
        {
            const auto du = GetInstDefUse(*it, function);

            for (int d : du.defs)
            {
                if (!is_fp(d))
                    continue;
                for (int v : live)
                {
                    if (v != d && is_fp(v))
                        graph.AddEdge(d, v);
                }
                live.erase(d);
            }
            for (int u : du.uses)
            {
                if (is_fp(u))
                    live.insert(u);
            }

            if (!du.is_call)
                continue;
            for (int v : live)
            {
                if (!is_fp(v))
                    continue;
                for (int j = 0; j < FP_NUM_CALLER_SAVED; ++j)
                {
                    const int clobber = -2000 - FP_CALLER_SAVED[j];
                    graph.AddNode(clobber);
                    graph.AddEdge(v, clobber);
                }
            }
        }
    }
}
/// Outcome of one ColorGraph run.
struct GraphColoringResult
{
    // Virtual register id -> register number chosen for it.
    std::unordered_map<int, int> assignment;
    // Virtual register ids that did not receive a register.
    std::set<int> spilled;
};
/// Colours `graph` with the register numbers in `allocatable_regs` using a
/// simplify / spill / select scheme with copy-biased colour choice.
///
/// Negative node ids are pre-coloured pseudo nodes: ids in (-2000, -1000]
/// stand for register `-(id + 1000)` and ids in (-3000, -2000] for register
/// `-(id + 2000)`. (The upper bounds are inclusive so that the pseudo nodes
/// of register 0, -1000 and -2000, are decoded too.) Non-negative ids are
/// virtual registers.
///
/// Simplify repeatedly removes nodes of degree < K onto a stack; when none is
/// left, the remaining node of highest degree is marked spilled and removed.
/// Select pops the stack and gives each node a colour unused by its coloured
/// neighbours, preferring the colour of a node it is copy-related to.
///
/// @param graph            interference graph (nodes and edges are read).
/// @param allocatable_regs colours, tried in this order.
/// @param copy_edges       virtual register -> copy-related virtual registers.
/// @return assignment for coloured registers and the set of spilled ones.
static GraphColoringResult ColorGraph(
    InterferenceGraph &graph,
    const std::vector<int> &allocatable_regs,
    MachineFunction & /*function*/,
    const std::unordered_map<int, std::vector<int>> &copy_edges = {})
{
    const int K = static_cast<int>(allocatable_regs.size());
    GraphColoringResult result;

    // Adjacency lists decoded from the packed edge keys.
    std::unordered_map<int, std::unordered_set<int>> adj;
    for (const long long key : graph.edges)
    {
        const int u = static_cast<int>(key >> 32);
        const int v = static_cast<int>(key & 0xFFFFFFFF);
        adj[u].insert(v);
        adj[v].insert(u);
    }

    // Pre-coloured pseudo nodes; virtual registers are added during select.
    std::unordered_map<int, int> colored;
    for (int v : graph.nodes)
    {
        if (v <= -1000 && v > -2000)
            colored[v] = -(v + 1000);
        else if (v <= -2000 && v > -3000)
            colored[v] = -(v + 2000);
    }

    // Initial degrees count every neighbour that is a graph node, including
    // pre-coloured ones.
    std::unordered_set<int> remaining;
    std::unordered_map<int, int> degree;
    for (int v : graph.nodes)
    {
        if (v < 0)
            continue;
        remaining.insert(v);
        int deg = 0;
        for (int n : adj[v])
        {
            if (graph.nodes.count(n))
                ++deg;
        }
        degree[v] = deg;
    }

    std::vector<int> simplify_worklist;
    for (int v : remaining)
    {
        if (degree[v] < K)
            simplify_worklist.push_back(v);
    }

    // Removes `v` from the working graph and queues neighbours whose degree
    // has just dropped below K.
    const auto detach = [&](int v)
    {
        remaining.erase(v);
        for (int n : adj[v])
        {
            if (!remaining.count(n))
                continue;
            if (--degree[n] == K - 1)
                simplify_worklist.push_back(n);
        }
    };

    std::vector<int> stack;
    while (!remaining.empty())
    {
        while (!simplify_worklist.empty())
        {
            const int v = simplify_worklist.back();
            simplify_worklist.pop_back();
            if (!remaining.count(v))
                continue;
            stack.push_back(v);
            detach(v);
        }
        if (remaining.empty())
            break;

        // Nothing is trivially colourable: spill the highest-degree node.
        int spill_candidate = -1;
        int max_degree = -1;
        for (int v : remaining)
        {
            if (degree[v] > max_degree)
            {
                max_degree = degree[v];
                spill_candidate = v;
            }
        }
        if (spill_candidate < 0)
            break;
        result.spilled.insert(spill_candidate);
        detach(spill_candidate);
    }

    while (!stack.empty())
    {
        const int v = stack.back();
        stack.pop_back();

        std::unordered_set<int> used_colors;
        for (int n : adj[v])
        {
            const auto it = colored.find(n);
            if (it != colored.end())
                used_colors.insert(it->second);
        }

        int assigned_color = -1;

        // Biased colouring: prefer the colour of a copy-related register so
        // that the connecting move can become a no-op.
        const auto copy_it = copy_edges.find(v);
        if (copy_it != copy_edges.end())
        {
            for (int neighbor : copy_it->second)
            {
                const auto nit = colored.find(neighbor);
                if (nit == colored.end())
                    continue;
                if (used_colors.find(nit->second) == used_colors.end())
                {
                    assigned_color = nit->second;
                    break;
                }
            }
        }

        if (assigned_color < 0)
        {
            for (int c : allocatable_regs)
            {
                if (used_colors.find(c) == used_colors.end())
                {
                    assigned_color = c;
                    break;
                }
            }
        }

        if (assigned_color >= 0)
        {
            colored[v] = assigned_color;
            result.assignment[v] = assigned_color;
        }
        else
        {
            result.spilled.insert(v);
        }
    }
    return result;
}
/// Chooses a general-purpose register number to use as a spill scratch
/// register for one instruction.
///
/// A register is acceptable when it is not in `used_scratch` and no def or
/// use of the instruction is assigned to it in `gp_assignment` (the def
/// `skip_vreg` is exempt from that check). Register 14 is tried first, then
/// the entries of GP_ALLOCATABLE in order; if none is acceptable the first
/// entry of GP_ALLOCATABLE is returned regardless.
static int PickGPScratchReg(
    const std::set<int> &used_scratch,
    const InstDefUse &du,
    const std::unordered_map<int, int> &gp_assignment,
    int skip_vreg = -1)
{
    const auto assigned_to = [&gp_assignment](int vreg, int reg)
    {
        const auto it = gp_assignment.find(vreg);
        return it != gp_assignment.end() && it->second == reg;
    };
    const auto touched_by_inst = [&](int reg)
    {
        for (int d : du.defs)
        {
            if (d != skip_vreg && assigned_to(d, reg))
                return true;
        }
        for (int u : du.uses)
        {
            if (assigned_to(u, reg))
                return true;
        }
        return false;
    };

    if (!used_scratch.count(14) && !touched_by_inst(14))
        return 14;

    for (int candidate : GP_ALLOCATABLE)
    {
        if (!used_scratch.count(candidate) && !touched_by_inst(candidate))
            return candidate;
    }
    return GP_ALLOCATABLE[0];
}
/// Chooses a floating-point register number to use as a spill scratch
/// register for one instruction.
///
/// Returns the first entry of FP_ALLOCATABLE that is not in `used_scratch`
/// and that no def or use of the instruction is assigned to in
/// `fp_assignment` (the def `skip_vreg` is exempt from that check). If no
/// entry qualifies, the first entry of FP_ALLOCATABLE is returned regardless.
static int PickFPScratchReg(
    const std::set<int> &used_scratch,
    const InstDefUse &du,
    const std::unordered_map<int, int> &fp_assignment,
    int skip_vreg = -1)
{
    const auto assigned_to = [&fp_assignment](int vreg, int reg)
    {
        const auto it = fp_assignment.find(vreg);
        return it != fp_assignment.end() && it->second == reg;
    };
    const auto touched_by_inst = [&](int reg)
    {
        for (int d : du.defs)
        {
            if (d != skip_vreg && assigned_to(d, reg))
                return true;
        }
        for (int u : du.uses)
        {
            if (assigned_to(u, reg))
                return true;
        }
        return false;
    };

    for (int candidate : FP_ALLOCATABLE)
    {
        if (!used_scratch.count(candidate) && !touched_by_inst(candidate))
            return candidate;
    }
    return FP_ALLOCATABLE[0];
}
/// Replaces every virtual-register operand in `function` with a physical
/// register and inserts stack traffic for the registers in `spilled`.
///
/// Each spilled register gets a frame slot (8 bytes for Ptr, 4 otherwise).
/// Per instruction:
///   1. every spilled use is reloaded with a LoadStack into a scratch
///      register emitted before the instruction, and all operands naming it
///      are rewritten to that scratch register;
///   2. remaining virtual-register operands are rewritten to their assigned
///      register, or to a scratch register if they are spilled;
///   3. for every spilled def a StoreStack is emitted after the instruction,
///      storing the first register operand of the matching bank (W0 if the
///      instruction has none).
/// NOTE(review): scratch registers are only checked against registers named
/// by the same instruction; whether they can hold another live value at that
/// point is not checked here - confirm with the callers' expectations.
///
/// @param gp_assignment  virtual register -> GP register number.
/// @param fp_assignment  virtual register -> FP register number.
/// @param spilled        virtual registers that live in stack slots.
static void RewriteWithAllocation(
    MachineFunction &function,
    const std::unordered_map<int, int> &gp_assignment,
    const std::unordered_map<int, int> &fp_assignment,
    const std::set<int> &spilled)
{
    std::unordered_map<int, int> spill_slots;
    for (int v : spilled)
    {
        const int size = (function.GetVRegClass(v) == VRegClass::Ptr) ? 8 : 4;
        spill_slots[v] = function.CreateFrameIndex(size);
    }

    for (auto &block : function.GetBlocks())
    {
        std::vector<MachineInstr> new_insts;
        for (auto &inst : block->GetInstructions())
        {
            const auto du = GetInstDefUse(inst, function);
            std::set<int> used_scratch_gp;
            std::set<int> used_scratch_fp;

            // Picks a scratch register of the right bank and reserves it for
            // the rest of this instruction.
            const auto take_scratch = [&](VRegClass vc, int skip_vreg)
            {
                if (vc == VRegClass::Float)
                {
                    const int r = PickFPScratchReg(used_scratch_fp, du, fp_assignment, skip_vreg);
                    used_scratch_fp.insert(r);
                    return r;
                }
                const int r = PickGPScratchReg(used_scratch_gp, du, gp_assignment, skip_vreg);
                used_scratch_gp.insert(r);
                return r;
            };

            // 1. Reload spilled uses.
            for (int u : du.uses)
            {
                if (!spilled.count(u))
                    continue;
                const VRegClass vc = function.GetVRegClass(u);
                const int slot = spill_slots[u];
                const PhysReg reload_reg = NumberToPhysReg(take_scratch(vc, -1), vc);
                new_insts.push_back(
                    MachineInstr(Opcode::LoadStack,
                                 {Operand::Reg(reload_reg), Operand::FrameIndex(slot)}));
                for (auto &op : inst.GetOperands())
                {
                    if (op.GetKind() == Operand::Kind::VReg && op.GetVRegId() == u)
                        const_cast<Operand &>(op) = Operand::Reg(reload_reg);
                }
            }

            // 2. Rewrite whatever is still a virtual register.
            for (auto &op : inst.GetOperands())
            {
                if (op.GetKind() != Operand::Kind::VReg)
                    continue;
                const int vreg_id = op.GetVRegId();
                const VRegClass actual_vc = function.GetVRegClass(vreg_id);
                const auto &assignment =
                    (actual_vc == VRegClass::Float) ? fp_assignment : gp_assignment;

                int reg_num = -1;
                const auto it = assignment.find(vreg_id);
                if (it != assignment.end())
                    reg_num = it->second;

                if (reg_num >= 0)
                {
                    const_cast<Operand &>(op) = Operand::Reg(NumberToPhysReg(reg_num, actual_vc));
                }
                else if (spilled.count(vreg_id))
                {
                    const int spill_reg_num = take_scratch(actual_vc, vreg_id);
                    const_cast<Operand &>(op) = Operand::Reg(NumberToPhysReg(spill_reg_num, actual_vc));
                }
            }

            new_insts.push_back(std::move(const_cast<MachineInstr &>(inst)));

            // 3. Store spilled defs back to their slots.
            for (int d : du.defs)
            {
                if (!spilled.count(d))
                    continue;
                const VRegClass vc = function.GetVRegClass(d);
                const int slot = spill_slots[d];
                const auto &last_inst = new_insts.back();
                PhysReg spill_reg = PhysReg::W0;
                for (const auto &op : last_inst.GetOperands())
                {
                    if (op.GetKind() != Operand::Kind::Reg)
                        continue;
                    const PhysReg r = op.GetReg();
                    const bool bank_matches =
                        (vc == VRegClass::Float) ? IsFPReg(r) : IsGPReg(r);
                    if (bank_matches)
                    {
                        spill_reg = r;
                        break;
                    }
                }
                new_insts.push_back(
                    MachineInstr(Opcode::StoreStack,
                                 {Operand::Reg(spill_reg), Operand::FrameIndex(slot)}));
            }
        }
        block->GetInstructions() = std::move(new_insts);
    }
}
/// Runs register allocation on one function.
///
/// Up to MAX_SPILL_ROUNDS times: compute liveness, build the GP and FP
/// interference graphs, and colour them. If nothing spilled, the assignment
/// is committed (callee-saved registers reported, operands rewritten) and the
/// function returns. Otherwise every spilled register gets a frame slot and
/// each instruction touching it is rewritten to go through a fresh
/// short-lived virtual register (LoadStack before a use, StoreStack after a
/// def), and the next round starts.
///
/// If spills remain after the last round, one more colouring is done and the
/// remaining spills are handled by RewriteWithAllocation with scratch
/// registers.
///
/// GP numbers 19..28 and FP numbers 16..31 that appear in an assignment are
/// reported through AddCalleeSavedReg.
/// NOTE(review): AAPCS64 lists v8-v15, not v16-v31, as the callee-saved FP
/// registers; this range mirrors FP_CALLER_SAVED and the two must only be
/// changed together - confirm against the ABI.
static void AllocateRegistersForFunction(MachineFunction &function)
{
    if (function.GetNumVRegs() == 0)
        return;

    using CopyEdges = std::unordered_map<int, std::vector<int>>;
    using Assignment = std::unordered_map<int, int>;

    // One full analysis + colouring pass over the current instruction stream.
    const auto color_function = [&function](GraphColoringResult &gp_result,
                                            GraphColoringResult &fp_result)
    {
        auto block_liveness = ComputeBlockLiveness(function);

        // Collect copy edges: pairs of same-class virtual registers joined by
        // a MovReg, used to bias colour selection.
        CopyEdges gp_copy_edges;
        CopyEdges fp_copy_edges;
        for (auto &block : function.GetBlocks())
        {
            for (auto &inst : block->GetInstructions())
            {
                if (inst.GetOpcode() != Opcode::MovReg)
                    continue;
                const auto &ops = inst.GetOperands();
                if (ops.size() < 2 ||
                    ops[0].GetKind() != Operand::Kind::VReg ||
                    ops[1].GetKind() != Operand::Kind::VReg)
                    continue;
                const int dst = ops[0].GetVRegId();
                const int src = ops[1].GetVRegId();
                if (function.GetVRegClass(dst) != function.GetVRegClass(src))
                    continue;
                auto &edges = IsGPClass(function.GetVRegClass(dst))
                                  ? gp_copy_edges
                                  : fp_copy_edges;
                edges[dst].push_back(src);
                edges[src].push_back(dst);
            }
        }

        std::vector<int> gp_alloc(GP_ALLOCATABLE, GP_ALLOCATABLE + GP_NUM_ALLOCATABLE);
        std::vector<int> fp_alloc(FP_ALLOCATABLE, FP_ALLOCATABLE + FP_NUM_ALLOCATABLE);
        InterferenceGraph gp_graph, fp_graph;
        BuildInterferenceForGP(function, block_liveness, gp_alloc, gp_graph);
        BuildInterferenceForFP(function, block_liveness, fp_alloc, fp_graph);
        gp_result = ColorGraph(gp_graph, gp_alloc, function, gp_copy_edges);
        fp_result = ColorGraph(fp_graph, fp_alloc, function, fp_copy_edges);
    };

    // Reports every assigned register number in [first, last] as callee-saved.
    const auto report_callee_saved = [&function](const Assignment &assignment,
                                                 int first, int last, VRegClass bank)
    {
        for (const auto &pair : assignment)
        {
            if (pair.second >= first && pair.second <= last)
                function.AddCalleeSavedReg(NumberToPhysReg(pair.second, bank));
        }
    };

    const int MAX_SPILL_ROUNDS = 10;
    for (int round = 0; round < MAX_SPILL_ROUNDS; ++round)
    {
        GraphColoringResult gp_result, fp_result;
        color_function(gp_result, fp_result);

        if (gp_result.spilled.empty() && fp_result.spilled.empty())
        {
            const Assignment gp_assign = std::move(gp_result.assignment);
            const Assignment fp_assign = std::move(fp_result.assignment);
            report_callee_saved(gp_assign, 19, 28, VRegClass::Ptr);
            report_callee_saved(fp_assign, 16, 31, VRegClass::Float);
            RewriteWithAllocation(function, gp_assign, fp_assign, {});
            return;
        }

        std::set<int> all_spilled = gp_result.spilled;
        for (int v : fp_result.spilled)
            all_spilled.insert(v);

        std::unordered_map<int, int> spill_slots;
        for (int v : all_spilled)
        {
            const int size = (function.GetVRegClass(v) == VRegClass::Ptr) ? 8 : 4;
            spill_slots[v] = function.CreateFrameIndex(size);
        }

        for (auto &block : function.GetBlocks())
        {
            std::vector<MachineInstr> new_insts;
            for (auto &inst : block->GetInstructions())
            {
                const auto du = GetInstDefUse(inst, function);

                // Spilled register -> temporary that replaced it in this
                // instruction while reloading its uses.
                std::unordered_map<int, int> reloaded;
                for (int u : du.uses)
                {
                    if (!all_spilled.count(u))
                        continue;
                    const VRegClass vc = function.GetVRegClass(u);
                    const int new_vreg = function.CreateVReg(vc);
                    new_insts.push_back(
                        MachineInstr(Opcode::LoadStack,
                                     {Operand::VReg(new_vreg, vc), Operand::FrameIndex(spill_slots[u])}));
                    for (auto &op : inst.GetOperands())
                    {
                        if (op.GetKind() == Operand::Kind::VReg && op.GetVRegId() == u)
                            const_cast<Operand &>(op) = Operand::VReg(new_vreg, vc);
                    }
                    reloaded[u] = new_vreg;
                }

                new_insts.push_back(std::move(const_cast<MachineInstr &>(inst)));

                for (int d : du.defs)
                {
                    if (!all_spilled.count(d))
                        continue;
                    const VRegClass vc = function.GetVRegClass(d);
                    int new_vreg = -1;
                    const auto already = reloaded.find(d);
                    if (already != reloaded.end())
                    {
                        // The instruction also reads `d`: the reload above
                        // already renamed every operand (including the
                        // destination) to its temporary, so store that one.
                        // Creating another temporary here would store a
                        // register that is never written.
                        new_vreg = already->second;
                    }
                    else
                    {
                        new_vreg = function.CreateVReg(vc);
                        auto &last = new_insts.back();
                        for (auto &op : last.GetOperands())
                        {
                            if (op.GetKind() == Operand::Kind::VReg && op.GetVRegId() == d)
                                const_cast<Operand &>(op) = Operand::VReg(new_vreg, vc);
                        }
                    }
                    new_insts.push_back(
                        MachineInstr(Opcode::StoreStack,
                                     {Operand::VReg(new_vreg, vc), Operand::FrameIndex(spill_slots[d])}));
                }
            }
            block->GetInstructions() = std::move(new_insts);
        }
    }

    // Spill rounds exhausted: colour once more and let RewriteWithAllocation
    // deal with whatever still spills.
    GraphColoringResult gp_result, fp_result;
    color_function(gp_result, fp_result);

    std::set<int> all_spilled = gp_result.spilled;
    for (int v : fp_result.spilled)
        all_spilled.insert(v);

    const Assignment gp_assign = std::move(gp_result.assignment);
    const Assignment fp_assign = std::move(fp_result.assignment);
    report_callee_saved(gp_assign, 19, 28, VRegClass::Ptr);
    report_callee_saved(fp_assign, 16, 31, VRegClass::Float);
    RewriteWithAllocation(function, gp_assign, fp_assign, all_spilled);
}
} // namespace
/// Allocates registers for a single machine function.
void RunRegAlloc(MachineFunction &function)
{
    return AllocateRegistersForFunction(function);
}
/// Allocates registers for every function of `module`, skipping null entries.
void RunRegAlloc(MachineModule &module)
{
    for (auto &function : module.GetFunctions())
    {
        if (!function)
            continue;
        RunRegAlloc(*function);
    }
}
} // namespace mir