feat(mir): add virtual register support

dev
Codex 2 days ago
parent 6d0d3e42fa
commit 13ed95131a

@ -38,6 +38,12 @@ enum class PhysReg {
W13,
W14,
W15,
W19,
W20,
W21,
W22,
W23,
W24,
X0,
X1,
X2,
@ -54,6 +60,12 @@ enum class PhysReg {
X13,
X14,
X15,
X19,
X20,
X21,
X22,
X23,
X24,
S0,
S1,
S2,
@ -80,6 +92,8 @@ PhysReg WRegFromIndex(int index);
PhysReg XRegFromIndex(int index);
PhysReg SRegFromIndex(int index);
enum class RegClass { GPR32, GPR64, FPR32 };
enum class CondCode { EQ, NE, LT, LE, GT, GE };
const char* CondCodeName(CondCode cc);
@ -123,9 +137,10 @@ enum class Opcode {
class Operand {
public:
enum class Kind { Reg, Imm, FrameIndex, GlobalSymbol, Block };
enum class Kind { Reg, VReg, Imm, FrameIndex, GlobalSymbol, Block };
static Operand Reg(PhysReg reg);
static Operand VReg(int id);
static Operand Imm(int value);
static Operand FrameIndex(int index);
static Operand GlobalSymbol(std::string symbol);
@ -133,6 +148,7 @@ class Operand {
Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; }
int GetVReg() const { return imm_; }
int GetImm() const { return imm_; }
int GetFrameIndex() const { return imm_; }
const std::string& GetSymbol() const { return symbol_; }
@ -152,6 +168,7 @@ class MachineInstr {
MachineInstr(Opcode opcode, std::vector<Operand> operands = {});
Opcode GetOpcode() const { return opcode_; }
std::vector<Operand>& GetOperands() { return operands_; }
const std::vector<Operand>& GetOperands() const { return operands_; }
private:
@ -170,6 +187,14 @@ struct FrameSlot {
FrameSlotKind kind = FrameSlotKind::Temp;
};
struct VirtualRegInfo {
int id = -1;
RegClass reg_class = RegClass::GPR32;
int home_slot = -1;
bool spilled = false;
PhysReg assigned_reg = PhysReg::W0;
};
class MachineBasicBlock {
public:
explicit MachineBasicBlock(std::string name);
@ -237,12 +262,23 @@ class MachineFunction {
void SetStackArgSize(int size) { stack_arg_size_ = size; }
bool IsLeaf() const { return is_leaf_; }
void SetLeaf(bool is_leaf) { is_leaf_ = is_leaf; }
int CreateVirtualReg(RegClass reg_class, int home_slot = -1);
VirtualRegInfo& GetVirtualReg(int id);
const VirtualRegInfo& GetVirtualReg(int id) const;
std::vector<VirtualRegInfo>& GetVirtualRegs() { return virtual_regs_; }
const std::vector<VirtualRegInfo>& GetVirtualRegs() const { return virtual_regs_; }
void AddUsedCalleeSavedReg(PhysReg reg);
const std::vector<PhysReg>& GetUsedCalleeSavedRegs() const {
return used_callee_saved_regs_;
}
private:
std::string name_;
bool is_declaration_ = false;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
std::vector<FrameSlot> frame_slots_;
std::vector<VirtualRegInfo> virtual_regs_;
std::vector<PhysReg> used_callee_saved_regs_;
int frame_size_ = 0;
int stack_arg_size_ = 0;
bool is_leaf_ = true;

@ -44,6 +44,12 @@ int MachineFunction::CreateFrameIndex(int size, int align, FrameSlotKind kind) {
return index;
}
int MachineFunction::CreateVirtualReg(RegClass reg_class, int home_slot) {
int id = static_cast<int>(virtual_regs_.size());
virtual_regs_.push_back(VirtualRegInfo{id, reg_class, home_slot, false, PhysReg::W0});
return id;
}
FrameSlot& MachineFunction::GetFrameSlot(int index) {
if (index < 0 || index >= static_cast<int>(frame_slots_.size())) {
throw std::runtime_error(FormatError("mir", "非法 FrameIndex"));
@ -58,6 +64,29 @@ const FrameSlot& MachineFunction::GetFrameSlot(int index) const {
return frame_slots_[index];
}
VirtualRegInfo& MachineFunction::GetVirtualReg(int id) {
if (id < 0 || id >= static_cast<int>(virtual_regs_.size())) {
throw std::runtime_error(FormatError("mir", "非法虚拟寄存器"));
}
return virtual_regs_[id];
}
const VirtualRegInfo& MachineFunction::GetVirtualReg(int id) const {
if (id < 0 || id >= static_cast<int>(virtual_regs_.size())) {
throw std::runtime_error(FormatError("mir", "非法虚拟寄存器"));
}
return virtual_regs_[id];
}
void MachineFunction::AddUsedCalleeSavedReg(PhysReg reg) {
for (PhysReg existing : used_callee_saved_regs_) {
if (existing == reg) {
return;
}
}
used_callee_saved_regs_.push_back(reg);
}
MachineGlobal& MachineModule::AddGlobal(MachineGlobal global) {
globals_.push_back(std::move(global));
return globals_.back();

@ -9,6 +9,8 @@ Operand::Operand(Kind kind, PhysReg reg, int imm, std::string symbol)
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
Operand Operand::VReg(int id) { return Operand(Kind::VReg, PhysReg::W0, id); }
Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value);
}

@ -40,6 +40,18 @@ const char* PhysRegName(PhysReg reg) {
return "w14";
case PhysReg::W15:
return "w15";
case PhysReg::W19:
return "w19";
case PhysReg::W20:
return "w20";
case PhysReg::W21:
return "w21";
case PhysReg::W22:
return "w22";
case PhysReg::W23:
return "w23";
case PhysReg::W24:
return "w24";
case PhysReg::X0:
return "x0";
case PhysReg::X1:
@ -72,6 +84,18 @@ const char* PhysRegName(PhysReg reg) {
return "x14";
case PhysReg::X15:
return "x15";
case PhysReg::X19:
return "x19";
case PhysReg::X20:
return "x20";
case PhysReg::X21:
return "x21";
case PhysReg::X22:
return "x22";
case PhysReg::X23:
return "x23";
case PhysReg::X24:
return "x24";
case PhysReg::S0:
return "s0";
case PhysReg::S1:
@ -122,6 +146,12 @@ bool IsIntReg(PhysReg reg) {
case PhysReg::W13:
case PhysReg::W14:
case PhysReg::W15:
case PhysReg::W19:
case PhysReg::W20:
case PhysReg::W21:
case PhysReg::W22:
case PhysReg::W23:
case PhysReg::W24:
case PhysReg::X0:
case PhysReg::X1:
case PhysReg::X2:
@ -138,6 +168,12 @@ bool IsIntReg(PhysReg reg) {
case PhysReg::X13:
case PhysReg::X14:
case PhysReg::X15:
case PhysReg::X19:
case PhysReg::X20:
case PhysReg::X21:
case PhysReg::X22:
case PhysReg::X23:
case PhysReg::X24:
case PhysReg::X29:
case PhysReg::X30:
case PhysReg::SP:
@ -183,6 +219,12 @@ bool Is64BitReg(PhysReg reg) {
case PhysReg::X13:
case PhysReg::X14:
case PhysReg::X15:
case PhysReg::X19:
case PhysReg::X20:
case PhysReg::X21:
case PhysReg::X22:
case PhysReg::X23:
case PhysReg::X24:
case PhysReg::X29:
case PhysReg::X30:
case PhysReg::SP:
@ -226,6 +268,18 @@ PhysReg WRegFromIndex(int index) {
return PhysReg::W14;
case 15:
return PhysReg::W15;
case 19:
return PhysReg::W19;
case 20:
return PhysReg::W20;
case 21:
return PhysReg::W21;
case 22:
return PhysReg::W22;
case 23:
return PhysReg::W23;
case 24:
return PhysReg::W24;
}
throw std::runtime_error(FormatError("mir", "不支持的 W 寄存器编号"));
}
@ -264,6 +318,18 @@ PhysReg XRegFromIndex(int index) {
return PhysReg::X14;
case 15:
return PhysReg::X15;
case 19:
return PhysReg::X19;
case 20:
return PhysReg::X20;
case 21:
return PhysReg::X21;
case 22:
return PhysReg::X22;
case 23:
return PhysReg::X23;
case 24:
return PhysReg::X24;
}
throw std::runtime_error(FormatError("mir", "不支持的 X 寄存器编号"));
}

Loading…
Cancel
Save