feat(ra)初步实现功能

feature/ra
mxr 6 days ago
parent 70234dde70
commit 0d170d1af8

@ -1,9 +1,12 @@
#pragma once
#include <cstdint>
#include <initializer_list>
#include <iosfwd>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
namespace ir {
@ -148,9 +151,10 @@ enum class Opcode {
// ========== 操作数类 ==========
class Operand {
public:
enum class Kind { Reg, Imm, FrameIndex, Cond, Label };
enum class Kind { Reg, VReg, Imm, FrameIndex, Cond, Label };
static Operand Reg(PhysReg reg);
static Operand VReg(int id);
static Operand Imm(int value);
static Operand FrameIndex(int index);
static Operand Cond(CondCode cc);
@ -158,11 +162,15 @@ class Operand {
Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; }
int GetVReg() const { return imm_; }
int GetImm() const { return imm_; }
int GetFrameIndex() const { return imm_; }
CondCode GetCondCode() const { return cc_; }
const std::string& GetLabel() const { return label_; }
bool IsVReg() const { return kind_ == Kind::VReg; }
bool IsPhysReg() const { return kind_ == Kind::Reg; }
private:
Operand(Kind kind, PhysReg reg, int imm, CondCode cc, const std::string& label);
@ -180,10 +188,28 @@ class MachineInstr {
Opcode GetOpcode() const { return opcode_; }
const std::vector<Operand>& GetOperands() const { return operands_; }
std::vector<Operand>& GetOperands() { return operands_; }
// def/use 信息(用于活跃性分析)
const std::vector<int>& GetDefs() const { return defs_; }
const std::vector<int>& GetUses() const { return uses_; }
std::vector<int>& GetDefs() { return defs_; }
std::vector<int>& GetUses() { return uses_; }
void AddDef(int vreg) { defs_.push_back(vreg); }
void AddUse(int vreg) { uses_.push_back(vreg); }
// 指令分类
bool IsCall() const { return opcode_ == Opcode::Call; }
bool IsTerminator() const {
return opcode_ == Opcode::B || opcode_ == Opcode::BCond || opcode_ == Opcode::Ret;
}
bool IsMove() const { return opcode_ == Opcode::MovReg; }
private:
Opcode opcode_;
std::vector<Operand> operands_;
std::vector<int> defs_;
std::vector<int> uses_;
};
// ========== 栈槽结构 ==========
@ -211,10 +237,15 @@ class MachineBasicBlock {
const std::vector<MachineBasicBlock*>& GetSuccessors() const { return successors_; }
void AddSuccessor(MachineBasicBlock* succ) { successors_.push_back(succ); }
std::vector<MachineBasicBlock*>& GetPredecessors() { return predecessors_; }
const std::vector<MachineBasicBlock*>& GetPredecessors() const { return predecessors_; }
void AddPredecessor(MachineBasicBlock* pred) { predecessors_.push_back(pred); }
private:
std::string name_;
std::vector<MachineInstr> instructions_;
std::vector<MachineBasicBlock*> successors_;
std::vector<MachineBasicBlock*> predecessors_;
};
// ========== MIR 函数 ==========
@ -223,39 +254,69 @@ class MachineFunction {
explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; }
// 基本块管理
MachineBasicBlock& GetEntry() { return entry_; }
const MachineBasicBlock& GetEntry() const { return entry_; }
std::vector<std::unique_ptr<MachineBasicBlock>>& GetBasicBlocks() {
return basic_blocks_;
std::vector<std::unique_ptr<MachineBasicBlock>>& GetBasicBlocks() {
return basic_blocks_;
}
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBasicBlocks() const {
return basic_blocks_;
}
void AddBasicBlock(std::unique_ptr<MachineBasicBlock> bb) {
basic_blocks_.push_back(std::move(bb));
}
MachineBasicBlock* GetBlockByName(const std::string& name) {
for (auto& bb : basic_blocks_) {
if (bb->GetName() == name) return bb.get();
}
return nullptr;
}
// 栈槽管理
int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const;
std::vector<FrameSlot>& GetFrameSlots() { return frame_slots_; }
const std::vector<FrameSlot>& GetFrameSlots() const { return frame_slots_; }
// 栈帧大小
int GetFrameSize() const { return frame_size_; }
void SetFrameSize(int size) { frame_size_ = size; }
// callee-saved 寄存器管理
void MarkCalleeSaved(PhysReg reg) { used_callee_saved_regs_.insert(reg); }
const std::set<PhysReg>& GetCalleeSavedRegs() const { return used_callee_saved_regs_; }
bool IsCalleeSavedUsed(PhysReg reg) const {
return used_callee_saved_regs_.count(reg) > 0;
}
// spill 槽管理
int CreateSpillSlot(int size = 4);
bool IsSpillSlot(int index) const;
// vreg 类型管理(由 Lowering 填充RA 使用)
enum class VRegType : uint8_t { kInt32 = 0, kInt64 = 1, kFloat32 = 2 };
void SetVRegType(int vreg, VRegType type) { vreg_types_[vreg] = type; }
VRegType GetVRegType(int vreg) const {
auto it = vreg_types_.find(vreg);
return it != vreg_types_.end() ? it->second : VRegType::kInt32;
}
bool HasVRegType(int vreg) const { return vreg_types_.count(vreg) > 0; }
private:
std::string name_;
MachineBasicBlock entry_;
std::vector<std::unique_ptr<MachineBasicBlock>> basic_blocks_;
std::vector<FrameSlot> frame_slots_;
std::set<int> spill_slot_indices_;
int frame_size_ = 0;
std::set<PhysReg> used_callee_saved_regs_;
std::unordered_map<int, VRegType> vreg_types_;
};
// ========== MIR 模块 ==========
@ -324,12 +385,9 @@ class MachineModule {
};
// ========== 后端流程函数 ==========
/* std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function);
void PrintAsm(const MachineFunction& function, std::ostream& os); */
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineModule& module);
void RunMIRPasses(MachineModule& module);
void RunFrameLowering(MachineModule& module);
void PrintAsm(const MachineModule& module, std::ostream& os);

@ -46,16 +46,13 @@ int main(int argc, char** argv) {
}
if (opts.emit_asm) {
//auto machine_func = mir::LowerToMIR(*module);
auto machine_module = mir::LowerToMIR(*module);
//mir::RunRegAlloc(*machine_func);
mir::RunRegAlloc(*machine_module);
//mir::RunFrameLowering(*machine_func);
mir::RunMIRPasses(*machine_module);
mir::RunFrameLowering(*machine_module);
if (need_blank_line) {
std::cout << "\n";
}
//mir::PrintAsm(*machine_func, std::cout);
mir::PrintAsm(*machine_module, std::cout);
}
#else

@ -64,6 +64,10 @@ void PrintOperand(std::ostream& os, const Operand& op) {
case Operand::Kind::Reg:
os << PhysRegName(op.GetReg());
break;
case Operand::Kind::VReg:
throw std::runtime_error(
FormatError("asm", "寄存器分配未完成: 存在虚拟寄存器 #" +
std::to_string(op.GetVReg())));
case Operand::Kind::Imm:
os << "#" << op.GetImm();
break;
@ -88,6 +92,57 @@ static bool IsLegalAddSubImm(int64_t imm) {
return false;
}
// ---- 寄存器宽度规范化 ----
static bool IsWReg(PhysReg reg) {
return reg >= PhysReg::W0 && reg <= PhysReg::W30;
}
static bool IsXReg(PhysReg reg) {
return reg >= PhysReg::X0 && reg <= PhysReg::X30;
}
static bool IsSReg(PhysReg reg) {
return reg >= PhysReg::S0 && reg <= PhysReg::S31;
}
// Xn → Wn, Wn → Wn, Sn → Sn
static PhysReg ToW(PhysReg reg) {
if (IsXReg(reg))
return static_cast<PhysReg>(
static_cast<int>(reg) - static_cast<int>(PhysReg::X0) + static_cast<int>(PhysReg::W0));
return reg;
}
// Wn → Xn, Xn → Xn, Sn → Sn
static PhysReg ToX(PhysReg reg) {
if (IsWReg(reg))
return static_cast<PhysReg>(
static_cast<int>(reg) - static_cast<int>(PhysReg::W0) + static_cast<int>(PhysReg::X0));
return reg;
}
// 检查一组操作数是否全是同一宽度W/X/S
static bool AllSameRegWidth(const std::vector<Operand>& ops) {
int kind = -1;
for (const auto& op : ops) {
if (op.GetKind() != Operand::Kind::Reg) continue;
PhysReg r = op.GetReg();
if (IsWReg(r)) { if (kind == -1) kind = 0; else if (kind != 0) return false; }
else if (IsXReg(r)) { if (kind == -1) kind = 1; else if (kind != 1) return false; }
else if (IsSReg(r)) { if (kind == -1) kind = 2; else if (kind != 2) return false; }
}
return true;
}
// 根据目的地宽度规范化所有寄存器操作数
static void NormalizeRegOps(std::vector<Operand>& ops, PhysReg dst) {
PhysReg base = dst;
bool wantW = IsWReg(base);
bool wantX = IsXReg(base);
for (auto& op : ops) {
if (op.GetKind() != Operand::Kind::Reg) continue;
if (wantW) op = Operand::Reg(ToW(op.GetReg()));
else if (wantX) op = Operand::Reg(ToX(op.GetReg()));
}
}
// 在匿名命名空间添加辅助函数
static void PrintLoadImm64(std::ostream& os, PhysReg reg, uint64_t imm) {
// 输出 movz + movk 序列
@ -148,37 +203,48 @@ void PrintInstruction(std::ostream& os, const MachineInstr& instr,
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::MovReg:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
case Opcode::MovReg:{
PhysReg dst = ops.at(0).GetReg();
PhysReg src = ops.at(1).GetReg();
if (IsSReg(dst) || IsSReg(src)) {
// 涉及 S 寄存器的 move使用 fmov
if (!IsSReg(dst)) dst = ToW(dst); // 确保是 W 寄存器
if (!IsSReg(src)) src = ToW(src);
os << " fmov " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n";
} else {
// GPR move规范化宽度
if (IsWReg(dst) && IsXReg(src)) {
src = ToW(src);
} else if (IsXReg(dst) && IsWReg(src)) {
src = ToX(src);
}
os << " mov " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n";
}
break;
}
case Opcode::StoreStack: {
// 检查第二个操作数的类型
if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) {
// 存储到栈槽
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
} else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
// 间接存储:存储到寄存器指向的地址
// STR W9, [X8]
// 间接存储:基址必须是 X 寄存器
PhysReg base = ToX(ops.at(1).GetReg());
os << " str " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
<< PhysRegName(base) << "]\n";
} else {
throw std::runtime_error("StoreStack: 无效的操作数类型");
}
break;
}
case Opcode::LoadStack: {
// 检查第二个操作数的类型
if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) {
// 从栈槽加载
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
} else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
// 间接加载:从寄存器指向的地址加载
// LDR W9, [X8]
// 间接加载:基址必须是 X 寄存器
PhysReg base = ToX(ops.at(1).GetReg());
os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
<< PhysRegName(base) << "]\n";
} else {
throw std::runtime_error("LoadStack: 无效的操作数类型");
}
@ -204,115 +270,181 @@ void PrintInstruction(std::ostream& os, const MachineInstr& instr,
}
os << "\n";
break;
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::AddRI:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::SubRR:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SubRI:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::MulRR:
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SDivRR:
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::UDivRR:
os << " udiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FAddRR:
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FSubRR:
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FMulRR:
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FDivRR:
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::CmpRR:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::CmpRI:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
case Opcode::AddRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " add " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::AddRI: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " add " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", #"
<< nops[2].GetImm() << "\n";
break;
case Opcode::FCmpRR:
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
}
case Opcode::SubRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " sub " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
case Opcode::SIToFP:
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
}
case Opcode::SubRI: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " sub " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", #"
<< nops[2].GetImm() << "\n";
break;
case Opcode::FPToSI:
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
}
case Opcode::MulRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " mul " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::SDivRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " sdiv " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::UDivRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " udiv " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::FAddRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " fadd " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::FSubRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " fsub " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::FMulRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " fmul " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::FDivRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " fdiv " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::CmpRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " cmp " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << "\n";
break;
}
case Opcode::CmpRI: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " cmp " << PhysRegName(nops[0].GetReg()) << ", #"
<< nops[1].GetImm() << "\n";
break;
}
case Opcode::FCmpRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " fcmp " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << "\n";
break;
}
case Opcode::SIToFP: {
PhysReg dst = ops.at(0).GetReg();
PhysReg src = ops.at(1).GetReg();
if (!IsWReg(src)) src = ToW(src);
os << " scvtf " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n";
break;
}
case Opcode::FPToSI: {
PhysReg dst = ops.at(0).GetReg();
PhysReg src = ops.at(1).GetReg();
if (!IsWReg(dst)) dst = ToW(dst);
os << " fcvtzs " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n";
break;
}
case Opcode::ZExt:
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #1\n";
break;
case Opcode::AndRR:
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::OrRR:
os << " orr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::EorRR:
os << " eor " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::LslRR:
os << " lsl " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::LsrRR:
os << " lsr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::AsrRR:
os << " asr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
case Opcode::AndRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " and " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::OrRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " orr " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::EorRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " eor " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::LslRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " lsl " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::LsrRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " lsr " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::AsrRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " asr " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::B:
os << " b ";
PrintOperand(os, ops.at(0));
@ -345,8 +477,8 @@ void PrintInstruction(std::ostream& os, const MachineInstr& instr,
break;
case Opcode::LoadStackAddr: {
const FrameSlot& slot = GetFrameSlot(function, ops.at(1));
int64_t offset = slot.offset; // 负值,如 -8
PhysReg dst = ops.at(0).GetReg();
int64_t offset = slot.offset;
PhysReg dst = ToX(ops.at(0).GetReg()); // 地址必须是 X 寄存器
auto tryEmitSimple = [&]() -> bool {
if (offset >= 0 && offset <= 4095) {
@ -384,10 +516,15 @@ void PrintInstruction(std::ostream& os, const MachineInstr& instr,
<< ops.at(2).GetLabel() << "\n";
break;
}
case Opcode::Sxtw:
os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
case Opcode::Sxtw: {
PhysReg dst = ops.at(0).GetReg();
PhysReg src = ops.at(1).GetReg();
// sxtw 要求 X 目标W 源
if (!IsXReg(dst)) dst = ToX(dst);
if (!IsWReg(src)) src = ToW(src);
os << " sxtw " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n";
break;
}
default:
os << " // unknown instruction\n";
break;

@ -1,19 +1,11 @@
#include "mir/MIR.h"
#include <algorithm>
#include <stdexcept>
#include <vector>
#include "utils/Log.h"
//#define DEBUG_Frame
#ifdef DEBUG_Frame
#include <iostream>
#define DEBUG_MSG(msg) std::cerr << "[Frame Debug] " << msg << std::endl
#else
#define DEBUG_MSG(msg)
#endif
namespace mir {
namespace {
@ -21,10 +13,47 @@ int AlignTo(int value, int align) {
return ((value + align - 1) / align) * align;
}
// 收集排序后的 callee-saved 寄存器列表
// 返回:{ (physReg, frameIndex) } 对
struct CSRegSlot {
PhysReg phys_reg;
int frame_index;
int size; // 8 for x registers, 4 for s registers
};
std::vector<CSRegSlot> CollectCalleeSavedSlots(MachineFunction& function) {
std::vector<CSRegSlot> slots;
const auto& regs = function.GetCalleeSavedRegs();
// 整数 callee-saved (X19-X28 格式,每个 8 字节)
for (int i = 19; i <= 28; ++i) {
PhysReg xreg = static_cast<PhysReg>(static_cast<int>(PhysReg::X19) + (i - 19));
PhysReg wreg = static_cast<PhysReg>(static_cast<int>(PhysReg::W19) + (i - 19));
if (regs.count(wreg) || regs.count(xreg)) {
int slot = function.CreateFrameIndex(8);
slots.push_back({xreg, slot, 8});
}
}
// 浮点 callee-saved (S8-S15每个 4 字节)
for (int i = 8; i <= 15; ++i) {
PhysReg sreg = static_cast<PhysReg>(static_cast<int>(PhysReg::S8) + (i - 8));
if (regs.count(sreg)) {
int slot = function.CreateFrameIndex(8);
slots.push_back({sreg, slot, 8});
}
}
return slots;
}
} // namespace
void RunFrameLowering(MachineFunction& function) {
DEBUG_MSG("function RunFrameLowering");
// 收集 callee-saved 寄存器并分配栈槽
auto csSlots = CollectCalleeSavedSlots(function);
// 计算栈槽偏移
int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
@ -32,26 +61,34 @@ void RunFrameLowering(MachineFunction& function) {
}
function.SetFrameSize(AlignTo(cursor, 16));
// 基本块
// 插入 Prologue / Epilogue
const auto& blocks = function.GetBasicBlocks();
bool firstBlock = true;
for (const auto& bb : blocks) {
DEBUG_MSG("block");
auto& insts = bb->GetInstructions();
std::vector<MachineInstr> lowered;
// 输出基本块标签(非第一个基本块)
if (firstBlock) {
DEBUG_MSG("empalace Prologue");
lowered.emplace_back(Opcode::Prologue);
lowered.emplace_back(Opcode::Prologue);
// 在 Prologue 后保存 callee-saved 寄存器
for (const auto& cs : csSlots) {
lowered.emplace_back(Opcode::StoreStack,
std::vector<Operand>{Operand::Reg(cs.phys_reg),
Operand::FrameIndex(cs.frame_index)});
}
}
firstBlock = false;
// 输出基本块中的指令
for (const auto& inst : insts) {
DEBUG_MSG("inst");
if (inst.GetOpcode() == Opcode::Ret) {
DEBUG_MSG("empalace Epilogue");
// 在 Epilogue 前恢复 callee-saved 寄存器
for (const auto& cs : csSlots) {
lowered.emplace_back(Opcode::LoadStack,
std::vector<Operand>{Operand::Reg(cs.phys_reg),
Operand::FrameIndex(cs.frame_index)});
}
lowered.emplace_back(Opcode::Epilogue);
}
lowered.push_back(inst);
@ -60,13 +97,10 @@ void RunFrameLowering(MachineFunction& function) {
}
}
// 模块版本的栈帧布局
void RunFrameLowering(MachineModule& module) {
// 对模块中的每个函数执行栈帧布局
DEBUG_MSG("module RunFrameLowering");
for (auto& func : module.GetFunctions()) {
RunFrameLowering(*func);
}
}
} // namespace mir
} // namespace mir

File diff suppressed because it is too large Load Diff

@ -30,4 +30,14 @@ const FrameSlot& MachineFunction::GetFrameSlot(int index) const {
return frame_slots_[index];
}
int MachineFunction::CreateSpillSlot(int size) {
int index = CreateFrameIndex(size);
spill_slot_indices_.insert(index);
return index;
}
bool MachineFunction::IsSpillSlot(int index) const {
return spill_slot_indices_.count(index) > 0;
}
} // namespace mir

@ -9,6 +9,8 @@ Operand::Operand(Kind kind, PhysReg reg, int imm, CondCode cc, const std::string
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0, CondCode::EQ, ""); }
Operand Operand::VReg(int id) { return Operand(Kind::VReg, PhysReg::W0, id, CondCode::EQ, ""); }
Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value, CondCode::EQ, "");
}

@ -1,81 +1,488 @@
#include "mir/MIR.h"
#include <algorithm>
#include <map>
#include <set>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include "utils/Log.h"
namespace mir {
namespace {
bool IsAllowedReg(PhysReg reg) {
// ========== VReg 类型 ==========
enum class VRegClass { kInt32, kInt64, kFloat32 };
// ========== 活跃区间 ==========
struct LiveInterval {
int vreg;
int start;
int end;
VRegClass reg_class;
LiveInterval(int v, int s, int e, VRegClass rc)
: vreg(v), start(s), end(e), reg_class(rc) {}
};
// ========== 可分配物理寄存器池 ==========
// 整数 32位callee-saved 优先W19-W28然后 caller-savedW8-W13
// W14, W15 保留为 spill scratch
const PhysReg kGPR32Pool[] = {
PhysReg::W19, PhysReg::W20, PhysReg::W21, PhysReg::W22,
PhysReg::W23, PhysReg::W24, PhysReg::W25, PhysReg::W26,
PhysReg::W27, PhysReg::W28,
PhysReg::W8, PhysReg::W9, PhysReg::W10, PhysReg::W11,
PhysReg::W12, PhysReg::W13,
};
constexpr int kNumGPR32 = sizeof(kGPR32Pool) / sizeof(kGPR32Pool[0]);
// 整数 64位callee-saved 优先X19-X28然后 caller-savedX8-X13
// X14, X15 保留为 spill scratch
const PhysReg kGPR64Pool[] = {
PhysReg::X19, PhysReg::X20, PhysReg::X21, PhysReg::X22,
PhysReg::X23, PhysReg::X24, PhysReg::X25, PhysReg::X26,
PhysReg::X27, PhysReg::X28,
PhysReg::X8, PhysReg::X9, PhysReg::X10, PhysReg::X11,
PhysReg::X12, PhysReg::X13,
};
constexpr int kNumGPR64 = sizeof(kGPR64Pool) / sizeof(kGPR64Pool[0]);
// 浮点 32位: S8-S13 (callee-saved). S14-S15 保留为 spill scratch
// S16+ 是 caller-saved不能放入通用池外部调用会隐式 clobber
const PhysReg kFPR32Pool[] = {
PhysReg::S8, PhysReg::S9, PhysReg::S10, PhysReg::S11,
PhysReg::S12, PhysReg::S13,
};
constexpr int kNumFPR32 = sizeof(kFPR32Pool) / sizeof(kFPR32Pool[0]);
// Spill scratch registers (每个类型 2 个,避免多 spilled-vreg 冲突)
const PhysReg kSpillScratchInt32[] = { PhysReg::W15, PhysReg::W14 };
const PhysReg kSpillScratchInt64[] = { PhysReg::X15, PhysReg::X14 };
const PhysReg kSpillScratchFloat[] = { PhysReg::S15, PhysReg::S14 };
PhysReg GetSpillScratch(VRegClass rc, int idx) {
switch (rc) {
case VRegClass::kInt32: return kSpillScratchInt32[idx % 2];
case VRegClass::kInt64: return kSpillScratchInt64[idx % 2];
case VRegClass::kFloat32: return kSpillScratchFloat[idx % 2];
}
return kSpillScratchInt32[0];
}
// 判断是否为 callee-saved 寄存器
bool IsCalleeSaved(PhysReg reg) {
switch (reg) {
case PhysReg::W0:
case PhysReg::W8:
case PhysReg::W9:
case PhysReg::X29: //FP = X29 帧指针
case PhysReg::X30: //LR = X30 链接寄存器
case PhysReg::SP:
case PhysReg::W19: case PhysReg::W20: case PhysReg::W21: case PhysReg::W22:
case PhysReg::W23: case PhysReg::W24: case PhysReg::W25: case PhysReg::W26:
case PhysReg::W27: case PhysReg::W28:
case PhysReg::X19: case PhysReg::X20: case PhysReg::X21: case PhysReg::X22:
case PhysReg::X23: case PhysReg::X24: case PhysReg::X25: case PhysReg::X26:
case PhysReg::X27: case PhysReg::X28:
case PhysReg::S8: case PhysReg::S9: case PhysReg::S10: case PhysReg::S11:
case PhysReg::S12: case PhysReg::S13: case PhysReg::S14: case PhysReg::S15:
case PhysReg::S16: case PhysReg::S17: case PhysReg::S18: case PhysReg::S19:
case PhysReg::S20: case PhysReg::S21: case PhysReg::S22: case PhysReg::S23:
case PhysReg::S24: case PhysReg::S25: case PhysReg::S26: case PhysReg::S27:
case PhysReg::S28: case PhysReg::S29: case PhysReg::S30: case PhysReg::S31:
return true;
default: return false;
}
return false;
}
} // namespace
// 获取寄存器编号(用于 Wn/Xn 互斥检查)
int GetRegIndex(PhysReg reg) {
if (reg >= PhysReg::W0 && reg <= PhysReg::W30)
return static_cast<int>(reg) - static_cast<int>(PhysReg::W0);
if (reg >= PhysReg::X0 && reg <= PhysReg::X30)
return static_cast<int>(reg) - static_cast<int>(PhysReg::X0);
if (reg >= PhysReg::S0 && reg <= PhysReg::S31)
return 32 + static_cast<int>(reg) - static_cast<int>(PhysReg::S0);
return -1;
}
//void RunRegAlloc(MachineFunction& function) {
// for (const auto& inst : function.GetEntry().GetInstructions()) {
// for (const auto& operand : inst.GetOperands()) {
// if (operand.GetKind() == Operand::Kind::Reg &&
// !IsAllowedReg(operand.GetReg())) {
// throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
// }
// }
// }
//}
// 单函数版本的寄存器分配(原有逻辑)
void RunRegAlloc(MachineFunction& function) {
// 当前仅执行最小一致性检查,不实现真实寄存器分配
// Lab3 阶段保持栈槽模型,不需要真实寄存器分配
// 检查每个基本块中的指令
// ========== 推断 vreg 类型(优先使用 Lowering 存储的类型) ==========
VRegClass InferVRegClass(int vreg, MachineFunction& function) {
if (function.HasVRegType(vreg)) {
switch (function.GetVRegType(vreg)) {
case MachineFunction::VRegType::kFloat32: return VRegClass::kFloat32;
case MachineFunction::VRegType::kInt64: return VRegClass::kInt64;
case MachineFunction::VRegType::kInt32: return VRegClass::kInt32;
}
}
return VRegClass::kInt32; // 默认(不应到达,因为 Lowering 覆盖所有 vreg
}
// ========== 指令编号 ==========
void NumberInstructions(MachineFunction& function,
std::unordered_map<MachineInstr*, int>& instrToIdx,
std::vector<MachineInstr*>& idxToInstr,
std::map<int, MachineBasicBlock*>& blockBoundary) {
int idx = 0;
for (auto& bb : function.GetBasicBlocks()) {
for (auto& instr : bb->GetInstructions()) {
// 检查指令的操作数是否有效
for (const auto& operand : instr.GetOperands()) {
switch (operand.GetKind()) {
case Operand::Kind::Reg:
// 寄存器操作数:检查是否在允许的范围内
// 当前使用固定寄存器 w0, w8, w9, s0, s1 等
break;
case Operand::Kind::FrameIndex:
// 栈槽索引:检查是否有效
if (operand.GetFrameIndex() < 0 ||
operand.GetFrameIndex() >= static_cast<int>(function.GetFrameSlots().size())) {
throw std::runtime_error(
FormatError("regalloc", "无效的栈槽索引: " +
std::to_string(operand.GetFrameIndex())));
blockBoundary[idx] = bb.get();
for (auto& inst : bb->GetInstructions()) {
instrToIdx[&inst] = idx;
idxToInstr.push_back(&inst);
++idx;
}
}
}
// ========== 构建活跃区间 ==========
std::vector<LiveInterval> ComputeLiveIntervals(MachineFunction& function) {
const auto& blocks = function.GetBasicBlocks();
if (blocks.empty()) return {};
// 编号
std::unordered_map<MachineInstr*, int> instrToIdx;
std::vector<MachineInstr*> idxToInstr;
std::map<int, MachineBasicBlock*> blockBoundary;
NumberInstructions(function, instrToIdx, idxToInstr, blockBoundary);
int total = static_cast<int>(idxToInstr.size());
// 收集所有 vreg
std::set<int> allVRegs;
for (auto* inst : idxToInstr) {
for (int d : inst->GetDefs()) allVRegs.insert(d);
for (int u : inst->GetUses()) allVRegs.insert(u);
}
// 每个 vreg 的活跃位置集合
std::unordered_map<int, std::set<int>> vregPositions;
// 基本块的 use/def 集合
struct BlockInfo {
std::set<int> use;
std::set<int> def;
int startIdx;
int endIdx;
};
std::unordered_map<MachineBasicBlock*, BlockInfo> blockInfo;
for (const auto& bb : blocks) {
auto& info = blockInfo[bb.get()];
// 找到块的首尾指令序号
auto& insts = bb->GetInstructions();
if (!insts.empty()) {
info.startIdx = instrToIdx[&insts.front()];
info.endIdx = instrToIdx[&insts.back()] + 1;
} else {
info.startIdx = 0;
info.endIdx = 0;
}
for (auto& inst : insts) {
int pos = instrToIdx[&inst];
for (int def : inst.GetDefs()) {
info.def.insert(def);
vregPositions[def].insert(pos);
}
for (int use : inst.GetUses()) {
if (info.def.count(use) == 0) {
info.use.insert(use);
}
vregPositions[use].insert(pos);
}
}
}
// 数据流分析: liveIn/liveOut
std::unordered_map<MachineBasicBlock*, std::set<int>> liveIn, liveOut;
bool changed = true;
while (changed) {
changed = false;
for (auto it = blocks.rbegin(); it != blocks.rend(); ++it) {
MachineBasicBlock* bb = it->get();
auto& info = blockInfo[bb];
// liveOut = union of successors' liveIn
std::set<int> newLiveOut;
for (auto* succ : bb->GetSuccessors()) {
for (int v : liveIn[succ]) newLiveOut.insert(v);
}
if (newLiveOut != liveOut[bb]) {
liveOut[bb] = newLiveOut;
changed = true;
}
// liveIn = use (liveOut - def)
std::set<int> newLiveIn = info.use;
for (int v : liveOut[bb]) {
if (info.def.count(v) == 0) newLiveIn.insert(v);
}
if (newLiveIn != liveIn[bb]) {
liveIn[bb] = newLiveIn;
changed = true;
}
}
}
// 生成 LiveInterval
// 注意:不将 liveIn 扩展到整个基本块的每个位置,因为线性扫描只需要
// [start, end] 区间。扩展会导致过长的活跃区间,造成不必要的 spill。
std::vector<LiveInterval> intervals;
for (int vreg : allVRegs) {
auto it = vregPositions.find(vreg);
if (it == vregPositions.end() || it->second.empty()) continue;
int start = *it->second.begin();
int end = *it->second.rbegin();
VRegClass rc = InferVRegClass(vreg, function);
intervals.emplace_back(vreg, start, end, rc);
}
std::sort(intervals.begin(), intervals.end(),
[](const LiveInterval& a, const LiveInterval& b) {
return a.start < b.start;
});
return intervals;
}
// ========== 寄存器池选择 ==========
const PhysReg* GetRegPool(VRegClass rc, int& count) {
switch (rc) {
case VRegClass::kInt32: count = kNumGPR32; return kGPR32Pool;
case VRegClass::kInt64: count = kNumGPR64; return kGPR64Pool;
case VRegClass::kFloat32: count = kNumFPR32; return kFPR32Pool;
}
count = 0;
return nullptr;
}
// ========== 活跃区间比较(按 end 排序,用于 active 集合) ==========
struct ByEnd {
bool operator()(const LiveInterval* a, const LiveInterval* b) const {
if (a->end != b->end) return a->end < b->end;
return a->vreg < b->vreg; // 打破平局:相同 end 按 vreg 区分
}
};
// ========== 线性扫描寄存器分配(单函数) ==========
void RunRegAllocFunc(MachineFunction& function) {
auto intervals = ComputeLiveIntervals(function);
if (intervals.empty()) return;
// vreg → 分配的物理寄存器
std::unordered_map<int, PhysReg> vregToPhys;
// 活跃区间集合(按 end 排序)
std::set<const LiveInterval*, ByEnd> active;
// Spill 槽vreg → FrameIndex
std::unordered_map<int, int> spillSlots;
// 寄存器占用跟踪
std::set<int> occupiedRegIndices;
// 每个寄存器池的空闲/占用状态
auto allocReg = [&](const LiveInterval& interval) -> PhysReg {
int poolSize = 0;
const PhysReg* pool = GetRegPool(interval.reg_class, poolSize);
for (int i = 0; i < poolSize; ++i) {
int idx = GetRegIndex(pool[i]);
if (occupiedRegIndices.count(idx) == 0) {
occupiedRegIndices.insert(idx);
if (IsCalleeSaved(pool[i])) {
function.MarkCalleeSaved(pool[i]);
}
return pool[i];
}
}
// 无法分配
return PhysReg::W0; // will trigger spill logic
};
auto freeReg = [&](PhysReg reg) {
occupiedRegIndices.erase(GetRegIndex(reg));
};
auto isFreeReg = [&](VRegClass rc) -> bool {
int poolSize = 0;
const PhysReg* pool = GetRegPool(rc, poolSize);
for (int i = 0; i < poolSize; ++i) {
if (occupiedRegIndices.count(GetRegIndex(pool[i])) == 0)
return true;
}
return false;
};
// 线性扫描
for (auto& interval : intervals) {
// 1. Expire old intervals
std::vector<const LiveInterval*> toRemove;
for (auto* act : active) {
if (act->end < interval.start) {
toRemove.push_back(act);
auto it = vregToPhys.find(act->vreg);
if (it != vregToPhys.end()) {
freeReg(it->second);
}
}
}
for (auto* act : toRemove) {
active.erase(act);
}
// 2. 尝试分配
if (isFreeReg(interval.reg_class)) {
PhysReg reg = allocReg(interval);
vregToPhys[interval.vreg] = reg;
active.insert(&interval);
} else {
// 3. Spill: 选择 active 中最晚结束的区间
if (active.empty()) {
// 所有寄存器都被占用Wn/Xn 别名冲突等边缘情况)
// 直接 spill 当前 interval
int slotSize = (interval.reg_class == VRegClass::kInt64) ? 8 : 4;
int slot = function.CreateSpillSlot(slotSize);
spillSlots[interval.vreg] = slot;
continue;
}
const LiveInterval* spillCand = *active.rbegin(); // 最晚结束
if (spillCand->end > interval.end) {
// Spill spillCand
PhysReg reg = vregToPhys[spillCand->vreg];
vregToPhys.erase(spillCand->vreg);
freeReg(reg);
active.erase(spillCand);
// 为其分配 spill slot
int slotSize = (spillCand->reg_class == VRegClass::kInt64) ? 8 : 4;
int slot = function.CreateSpillSlot(slotSize);
spillSlots[spillCand->vreg] = slot;
// 分配当前 interval
PhysReg newReg = allocReg(interval);
vregToPhys[interval.vreg] = newReg;
active.insert(&interval);
} else {
// Spill 当前 interval
int slotSize = (interval.reg_class == VRegClass::kInt64) ? 8 : 4;
int slot = function.CreateSpillSlot(slotSize);
spillSlots[interval.vreg] = slot;
// 不分配物理寄存器
}
}
}
// ========== 重写指令VReg → PhysReg + spill/reload ==========
for (auto& bb : function.GetBasicBlocks()) {
std::vector<MachineInstr> newInsts;
auto& insts = bb->GetInstructions();
for (auto& inst : insts) {
auto& ops = inst.GetOperands();
std::vector<int>& defs = inst.GetDefs();
std::vector<int>& uses = inst.GetUses();
// 辅助:获取 spilled vreg 的类型
auto getSpillRC = [&](int vreg) -> VRegClass {
for (auto& iv : intervals) {
if (iv.vreg == vreg) return iv.reg_class;
}
return VRegClass::kInt32;
};
// 收集此指令中需要 reload 的 spilled vreg去重
std::vector<int> spilledUses;
{
std::set<int> seen;
for (int vreg : uses) {
if (spillSlots.count(vreg) && seen.insert(vreg).second) {
spilledUses.push_back(vreg);
}
}
}
// === 插入 use 前的 reload每个 spilled vreg 用不同 scratch ===
for (size_t si = 0; si < spilledUses.size(); ++si) {
int vreg = spilledUses[si];
int slot = spillSlots[vreg];
PhysReg loadReg;
auto it = vregToPhys.find(vreg);
if (it != vregToPhys.end()) {
loadReg = it->second;
} else {
loadReg = GetSpillScratch(getSpillRC(vreg), static_cast<int>(si));
}
newInsts.emplace_back(Opcode::LoadStack,
std::vector<Operand>{Operand::Reg(loadReg), Operand::FrameIndex(slot)});
}
// === 替换 VReg 操作数为 PhysReg ===
// 跟踪每条指令中 spilled vreg 的 scratch 索引
int spillUseIdx = 0;
for (auto& op : ops) {
if (op.GetKind() == Operand::Kind::VReg) {
int vreg = op.GetVReg();
auto it = vregToPhys.find(vreg);
if (it != vregToPhys.end()) {
op = Operand::Reg(it->second);
} else {
// spilled 或未分配:用 spill scratch
int idx = 0;
if (spillSlots.count(vreg)) {
for (size_t si = 0; si < spilledUses.size(); ++si) {
if (spilledUses[si] == vreg) { idx = static_cast<int>(si); break; }
}
} else {
// 防御vreg 未在 vregToPhys 或 spillSlots 中,创建临时 spill slot
VRegClass rc = getSpillRC(vreg);
int slotSize = (rc == VRegClass::kInt64) ? 8 : 4;
int slot = function.CreateSpillSlot(slotSize);
spillSlots[vreg] = slot;
}
break;
case Operand::Kind::Imm:
case Operand::Kind::Cond:
case Operand::Kind::Label:
// 立即数、条件码、标签不需要检查
break;
op = Operand::Reg(GetSpillScratch(getSpillRC(vreg), idx));
spillUseIdx++;
}
}
}
newInsts.push_back(inst);
// === 插入 def 后的 store用于 spilled vreg ===
// 收集此指令中 spilled def vreg去重
std::vector<int> spilledDefs;
{
std::set<int> seen;
for (int vreg : defs) {
if (spillSlots.count(vreg) && seen.insert(vreg).second) {
spilledDefs.push_back(vreg);
}
}
}
for (size_t si = 0; si < spilledDefs.size(); ++si) {
int vreg = spilledDefs[si];
int slot = spillSlots[vreg];
PhysReg storeReg;
auto it = vregToPhys.find(vreg);
if (it != vregToPhys.end()) {
storeReg = it->second;
} else {
storeReg = GetSpillScratch(getSpillRC(vreg), static_cast<int>(si));
}
newInsts.emplace_back(Opcode::StoreStack,
std::vector<Operand>{Operand::Reg(storeReg), Operand::FrameIndex(slot)});
}
}
insts = std::move(newInsts);
}
// 清除所有指令的 def/useRA 完成后不再需要)
for (auto& bb : function.GetBasicBlocks()) {
for (auto& inst : bb->GetInstructions()) {
inst.GetDefs().clear();
inst.GetUses().clear();
}
}
// 注意Lab3 阶段不实现真实寄存器分配
// 所有值仍然使用栈槽模型,寄存器仅作为临时计算使用
}
// 模块版本的寄存器分配
} // namespace
void RunRegAlloc(MachineModule& module) {
// 对模块中的每个函数执行寄存器分配
for (auto& func : module.GetFunctions()) {
RunRegAlloc(*func);
RunRegAllocFunc(*func);
}
}

@ -83,11 +83,35 @@ const char* PhysRegName(PhysReg reg) {
case PhysReg::S5: return "s5";
case PhysReg::S6: return "s6";
case PhysReg::S7: return "s7";
case PhysReg::S8: return "s8";
case PhysReg::S9: return "s9";
case PhysReg::S10: return "s10";
case PhysReg::S11: return "s11";
case PhysReg::S12: return "s12";
case PhysReg::S13: return "s13";
case PhysReg::S14: return "s14";
case PhysReg::S15: return "s15";
case PhysReg::S16: return "s16";
case PhysReg::S17: return "s17";
case PhysReg::S18: return "s18";
case PhysReg::S19: return "s19";
case PhysReg::S20: return "s20";
case PhysReg::S21: return "s21";
case PhysReg::S22: return "s22";
case PhysReg::S23: return "s23";
case PhysReg::S24: return "s24";
case PhysReg::S25: return "s25";
case PhysReg::S26: return "s26";
case PhysReg::S27: return "s27";
case PhysReg::S28: return "s28";
case PhysReg::S29: return "s29";
case PhysReg::S30: return "s30";
case PhysReg::S31: return "s31";
// 特殊寄存器
case PhysReg::SP: return "sp";
case PhysReg::ZR: return "xzr";
default: return "unknown";
}
throw std::runtime_error(FormatError("mir", "未知物理寄存器"));

@ -1,4 +1,16 @@
// MIR Pass 管理:
// - 组织后端 pass 的运行顺序PreRA/PostRA/PEI 等阶段)
// - 统一运行 pass 与调试输出(按需要扩展)
// - 组织后端 pass 的运行顺序
// - 统一运行 pass 与调试输出
#include "mir/MIR.h"
namespace mir {
void RunPeephole(MachineModule& module);
void RunMIRPasses(MachineModule& module) {
// PeepholeRA 后局部优化
RunPeephole(module);
}
} // namespace mir

@ -1,4 +1,197 @@
// 窥孔优化Peephole
// - 删除冗余 move、合并常见指令模式
// - 提升最终汇编质量(按实现范围裁剪)
// - 提升最终汇编质量
#include "mir/MIR.h"
#include <algorithm>
#include <cstdint>
#include <set>
#include <string>
#include <vector>
#include "utils/Log.h"
namespace mir {
namespace {
// 检查指令是否有副作用(非纯计算)
bool HasSideEffects(const MachineInstr& inst) {
switch (inst.GetOpcode()) {
case Opcode::Call:
case Opcode::Ret:
case Opcode::B:
case Opcode::BCond:
case Opcode::StoreStack:
case Opcode::StoreStackPair:
case Opcode::Prologue:
case Opcode::Epilogue:
return true;
default:
return false;
}
}
// 检查是否是纯 move 指令
bool IsPureMove(const MachineInstr& inst) {
return inst.GetOpcode() == Opcode::MovReg;
}
// 检查指令是否使用了某个物理寄存器
bool InstUsesReg(const MachineInstr& inst, PhysReg reg) {
for (const auto& op : inst.GetOperands()) {
if (op.GetKind() == Operand::Kind::Reg && op.GetReg() == reg)
return true;
}
return false;
}
// 检查指令是否定义了某个物理寄存器
bool InstDefsReg(const MachineInstr& inst, PhysReg reg) {
// 大多数指令的 dest 是第一个操作数
if (inst.GetOperands().empty()) return false;
const auto& dst = inst.GetOperands()[0];
if (dst.GetKind() == Operand::Kind::Reg && dst.GetReg() == reg)
return true;
// StoreStackPair / LoadStackPair 有特殊格式
return false;
}
// 检查是否恒等操作
bool IsIdentityOp(const MachineInstr& inst) {
if (inst.GetOperands().size() < 3) return false;
const auto& op2 = inst.GetOperands()[2];
if (op2.GetKind() != Operand::Kind::Imm) return false;
if (op2.GetImm() != 0) return false;
switch (inst.GetOpcode()) {
case Opcode::AddRI:
case Opcode::SubRI:
return true;
default:
return false;
}
}
// 检查两个指令操作相同的栈偏移
bool IsSameStackOffset(const MachineInstr& a, const MachineInstr& b) {
if (a.GetOperands().size() < 2 || b.GetOperands().size() < 2) return false;
const auto& aOff = a.GetOperands()[1];
const auto& bOff = b.GetOperands()[1];
if (aOff.GetKind() == Operand::Kind::FrameIndex &&
bOff.GetKind() == Operand::Kind::FrameIndex) {
return aOff.GetFrameIndex() == bOff.GetFrameIndex();
}
return false;
}
// 单基本块窥孔优化(一次扫描)
int PeepholeBlock(MachineBasicBlock& bb) {
auto& insts = bb.GetInstructions();
int changes = 0;
bool changed = true;
// 迭代直到收敛
while (changed) {
changed = false;
std::vector<MachineInstr> newInsts;
size_t n = insts.size();
for (size_t i = 0; i < n; ++i) {
MachineInstr& curr = insts[i];
// 跳过已标记删除的指令(通过空操作码)
if (curr.GetOpcode() == Opcode::Nop && curr.GetOperands().empty()) {
// 跳过(已经是 nop 但被标记删除)
if (curr.GetOperands().empty()) continue;
}
// --- 规则1: 恒等操作消除 add/sub ..., #0 → mov ---
if (IsIdentityOp(curr)) {
const auto& dst = curr.GetOperands()[0];
const auto& src = curr.GetOperands()[1];
if (dst.GetKind() == Operand::Kind::Reg && src.GetKind() == Operand::Kind::Reg) {
MachineInstr mov(Opcode::MovReg,
std::vector<Operand>{dst, Operand::Reg(src.GetReg())});
newInsts.push_back(mov);
changed = true;
++changes;
continue;
}
}
// --- 规则2: mov wA, wA → 删除(自赋值) ---
if (IsPureMove(curr) && curr.GetOperands().size() >= 2) {
const auto& dst = curr.GetOperands()[0];
const auto& src = curr.GetOperands()[1];
if (dst.GetKind() == Operand::Kind::Reg && src.GetKind() == Operand::Kind::Reg &&
dst.GetReg() == src.GetReg()) {
changed = true;
++changes;
continue; // 删除
}
}
// --- 规则3: 冗余 mov → 删除第一条 ---
// mov wA, wB; mov wA, wC → 删除第一条(如果中间无其他使用 wA
if (IsPureMove(curr) && i + 1 < n) {
const auto& dst0 = curr.GetOperands()[0];
MachineInstr& next = insts[i + 1];
if (IsPureMove(next) && next.GetOperands().size() >= 2) {
const auto& dst1 = next.GetOperands()[0];
if (dst0.GetKind() == Operand::Kind::Reg &&
dst1.GetKind() == Operand::Kind::Reg &&
dst0.GetReg() == dst1.GetReg()) {
// 第一条 mov 的 dest 在第一条之后、第二条之前没有被使用
// (两条相邻,中间无其他指令)
changed = true;
++changes;
continue; // 删除第一条
}
}
}
// --- 规则4: Load after Store 消除 ---
// stur wA, [x29, #n]; ldur wB, [x29, #n] → mov wB, wA
if (curr.GetOpcode() == Opcode::StoreStack && i + 1 < n) {
MachineInstr& next = insts[i + 1];
if (next.GetOpcode() == Opcode::LoadStack &&
IsSameStackOffset(curr, next)) {
const auto& storeVal = curr.GetOperands()[0];
const auto& loadDst = next.GetOperands()[0];
if (storeVal.GetKind() == Operand::Kind::Reg &&
loadDst.GetKind() == Operand::Kind::Reg) {
MachineInstr mov(Opcode::MovReg,
std::vector<Operand>{loadDst, Operand::Reg(storeVal.GetReg())});
newInsts.push_back(curr); // 保留 store
newInsts.push_back(mov); // mov 替换 load
++i; // 跳过 next
changed = true;
++changes;
continue;
}
}
}
newInsts.push_back(curr);
}
insts = std::move(newInsts);
}
return changes;
}
} // namespace
// ========== RunPeephole模块版本 ==========
void RunPeephole(MachineModule& module) {
int totalChanges = 0;
for (auto& func : module.GetFunctions()) {
for (auto& bb : func->GetBasicBlocks()) {
totalChanges += PeepholeBlock(*bb);
}
}
}
} // namespace mir

Loading…
Cancel
Save