From 0d170d1af8dec934181cd90ec5e3f314b2196405 Mon Sep 17 00:00:00 2001 From: mxr <> Date: Fri, 22 May 2026 16:27:10 +0800 Subject: [PATCH] =?UTF-8?q?feat(ra)=E5=88=9D=E6=AD=A5=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/mir/MIR.h | 78 +- src/main.cpp | 5 +- src/mir/AsmPrinter.cpp | 373 ++++-- src/mir/FrameLowering.cpp | 80 +- src/mir/Lowering.cpp | 1947 +++++++++++++++----------------- src/mir/MIRFunction.cpp | 10 + src/mir/MIRInstr.cpp | 2 + src/mir/RegAlloc.cpp | 513 ++++++++- src/mir/Register.cpp | 28 +- src/mir/passes/PassManager.cpp | 16 +- src/mir/passes/Peephole.cpp | 195 +++- 11 files changed, 1968 insertions(+), 1279 deletions(-) diff --git a/include/mir/MIR.h b/include/mir/MIR.h index 67ef476..cc8ab6d 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -1,9 +1,12 @@ #pragma once +#include #include #include #include +#include #include +#include #include namespace ir { @@ -148,9 +151,10 @@ enum class Opcode { // ========== 操作数类 ========== class Operand { public: - enum class Kind { Reg, Imm, FrameIndex, Cond, Label }; + enum class Kind { Reg, VReg, Imm, FrameIndex, Cond, Label }; static Operand Reg(PhysReg reg); + static Operand VReg(int id); static Operand Imm(int value); static Operand FrameIndex(int index); static Operand Cond(CondCode cc); @@ -158,11 +162,15 @@ class Operand { Kind GetKind() const { return kind_; } PhysReg GetReg() const { return reg_; } + int GetVReg() const { return imm_; } int GetImm() const { return imm_; } int GetFrameIndex() const { return imm_; } CondCode GetCondCode() const { return cc_; } const std::string& GetLabel() const { return label_; } + bool IsVReg() const { return kind_ == Kind::VReg; } + bool IsPhysReg() const { return kind_ == Kind::Reg; } + private: Operand(Kind kind, PhysReg reg, int imm, CondCode cc, const std::string& label); @@ -180,10 +188,28 @@ class MachineInstr { Opcode GetOpcode() const { return opcode_; } const std::vector& GetOperands() const { return operands_; } + std::vector& GetOperands() { return operands_; } + + // def/use 信息(用于活跃性分析) + const std::vector& GetDefs() const { return defs_; } + const std::vector& GetUses() const { return uses_; } + std::vector& GetDefs() { return defs_; } + std::vector& GetUses() { return uses_; } + void AddDef(int vreg) { defs_.push_back(vreg); } + void AddUse(int vreg) { uses_.push_back(vreg); } + + // 指令分类 + bool IsCall() const { return opcode_ == Opcode::Call; } + bool IsTerminator() const { + return opcode_ == Opcode::B || opcode_ == Opcode::BCond || opcode_ == Opcode::Ret; + } + bool IsMove() const { return opcode_ == Opcode::MovReg; } private: Opcode opcode_; std::vector operands_; + std::vector defs_; + std::vector uses_; }; // ========== 栈槽结构 ========== @@ -211,10 +237,15 @@ class MachineBasicBlock { const std::vector& GetSuccessors() const { return successors_; } void AddSuccessor(MachineBasicBlock* succ) { successors_.push_back(succ); } + std::vector& GetPredecessors() { return predecessors_; } + const std::vector& GetPredecessors() const { return predecessors_; } + void AddPredecessor(MachineBasicBlock* pred) { predecessors_.push_back(pred); } + private: std::string name_; std::vector instructions_; std::vector successors_; + std::vector predecessors_; }; // ========== MIR 函数 ========== @@ -223,39 +254,69 @@ class MachineFunction { explicit MachineFunction(std::string name); const std::string& GetName() const { return name_; } - + // 基本块管理 MachineBasicBlock& GetEntry() { return entry_; } const MachineBasicBlock& GetEntry() const { return entry_; } - std::vector>& GetBasicBlocks() { - return basic_blocks_; + std::vector>& GetBasicBlocks() { + return basic_blocks_; } const std::vector>& GetBasicBlocks() const { return basic_blocks_; } - + void AddBasicBlock(std::unique_ptr bb) { basic_blocks_.push_back(std::move(bb)); } + MachineBasicBlock* GetBlockByName(const std::string& name) { + for (auto& bb : basic_blocks_) { + if (bb->GetName() == name) return bb.get(); + } + return nullptr; + } + // 栈槽管理 int CreateFrameIndex(int size = 4); FrameSlot& GetFrameSlot(int index); const FrameSlot& GetFrameSlot(int index) const; std::vector& GetFrameSlots() { return frame_slots_; } const std::vector& GetFrameSlots() const { return frame_slots_; } - + // 栈帧大小 int GetFrameSize() const { return frame_size_; } void SetFrameSize(int size) { frame_size_ = size; } + // callee-saved 寄存器管理 + void MarkCalleeSaved(PhysReg reg) { used_callee_saved_regs_.insert(reg); } + const std::set& GetCalleeSavedRegs() const { return used_callee_saved_regs_; } + bool IsCalleeSavedUsed(PhysReg reg) const { + return used_callee_saved_regs_.count(reg) > 0; + } + + // spill 槽管理 + int CreateSpillSlot(int size = 4); + bool IsSpillSlot(int index) const; + + // vreg 类型管理(由 Lowering 填充,RA 使用) + enum class VRegType : uint8_t { kInt32 = 0, kInt64 = 1, kFloat32 = 2 }; + void SetVRegType(int vreg, VRegType type) { vreg_types_[vreg] = type; } + VRegType GetVRegType(int vreg) const { + auto it = vreg_types_.find(vreg); + return it != vreg_types_.end() ? it->second : VRegType::kInt32; + } + bool HasVRegType(int vreg) const { return vreg_types_.count(vreg) > 0; } + private: std::string name_; MachineBasicBlock entry_; std::vector> basic_blocks_; std::vector frame_slots_; + std::set spill_slot_indices_; int frame_size_ = 0; + std::set used_callee_saved_regs_; + std::unordered_map vreg_types_; }; // ========== MIR 模块 ========== @@ -324,12 +385,9 @@ class MachineModule { }; // ========== 后端流程函数 ========== -/* std::unique_ptr LowerToMIR(const ir::Module& module); -void RunRegAlloc(MachineFunction& function); -void RunFrameLowering(MachineFunction& function); -void PrintAsm(const MachineFunction& function, std::ostream& os); */ std::unique_ptr LowerToMIR(const ir::Module& module); void RunRegAlloc(MachineModule& module); +void RunMIRPasses(MachineModule& module); void RunFrameLowering(MachineModule& module); void PrintAsm(const MachineModule& module, std::ostream& os); diff --git a/src/main.cpp b/src/main.cpp index c54f087..730ce5d 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -46,16 +46,13 @@ int main(int argc, char** argv) { } if (opts.emit_asm) { - //auto machine_func = mir::LowerToMIR(*module); auto machine_module = mir::LowerToMIR(*module); - //mir::RunRegAlloc(*machine_func); mir::RunRegAlloc(*machine_module); - //mir::RunFrameLowering(*machine_func); + mir::RunMIRPasses(*machine_module); mir::RunFrameLowering(*machine_module); if (need_blank_line) { std::cout << "\n"; } - //mir::PrintAsm(*machine_func, std::cout); mir::PrintAsm(*machine_module, std::cout); } #else diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index b5d04f7..9d60b7a 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -64,6 +64,10 @@ void PrintOperand(std::ostream& os, const Operand& op) { case Operand::Kind::Reg: os << PhysRegName(op.GetReg()); break; + case Operand::Kind::VReg: + throw std::runtime_error( + FormatError("asm", "寄存器分配未完成: 存在虚拟寄存器 #" + + std::to_string(op.GetVReg()))); case Operand::Kind::Imm: os << "#" << op.GetImm(); break; @@ -88,6 +92,57 @@ static bool IsLegalAddSubImm(int64_t imm) { return false; } +// ---- 寄存器宽度规范化 ---- +static bool IsWReg(PhysReg reg) { + return reg >= PhysReg::W0 && reg <= PhysReg::W30; +} +static bool IsXReg(PhysReg reg) { + return reg >= PhysReg::X0 && reg <= PhysReg::X30; +} +static bool IsSReg(PhysReg reg) { + return reg >= PhysReg::S0 && reg <= PhysReg::S31; +} + +// Xn → Wn, Wn → Wn, Sn → Sn +static PhysReg ToW(PhysReg reg) { + if (IsXReg(reg)) + return static_cast( + static_cast(reg) - static_cast(PhysReg::X0) + static_cast(PhysReg::W0)); + return reg; +} +// Wn → Xn, Xn → Xn, Sn → Sn +static PhysReg ToX(PhysReg reg) { + if (IsWReg(reg)) + return static_cast( + static_cast(reg) - static_cast(PhysReg::W0) + static_cast(PhysReg::X0)); + return reg; +} + +// 检查一组操作数是否全是同一宽度(W/X/S) +static bool AllSameRegWidth(const std::vector& ops) { + int kind = -1; + for (const auto& op : ops) { + if (op.GetKind() != Operand::Kind::Reg) continue; + PhysReg r = op.GetReg(); + if (IsWReg(r)) { if (kind == -1) kind = 0; else if (kind != 0) return false; } + else if (IsXReg(r)) { if (kind == -1) kind = 1; else if (kind != 1) return false; } + else if (IsSReg(r)) { if (kind == -1) kind = 2; else if (kind != 2) return false; } + } + return true; +} + +// 根据目的地宽度规范化所有寄存器操作数 +static void NormalizeRegOps(std::vector& ops, PhysReg dst) { + PhysReg base = dst; + bool wantW = IsWReg(base); + bool wantX = IsXReg(base); + for (auto& op : ops) { + if (op.GetKind() != Operand::Kind::Reg) continue; + if (wantW) op = Operand::Reg(ToW(op.GetReg())); + else if (wantX) op = Operand::Reg(ToX(op.GetReg())); + } +} + // 在匿名命名空间添加辅助函数 static void PrintLoadImm64(std::ostream& os, PhysReg reg, uint64_t imm) { // 输出 movz + movk 序列 @@ -148,37 +203,48 @@ void PrintInstruction(std::ostream& os, const MachineInstr& instr, os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #" << ops.at(1).GetImm() << "\n"; break; - case Opcode::MovReg: - os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << "\n"; + case Opcode::MovReg:{ + PhysReg dst = ops.at(0).GetReg(); + PhysReg src = ops.at(1).GetReg(); + if (IsSReg(dst) || IsSReg(src)) { + // 涉及 S 寄存器的 move:使用 fmov + if (!IsSReg(dst)) dst = ToW(dst); // 确保是 W 寄存器 + if (!IsSReg(src)) src = ToW(src); + os << " fmov " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n"; + } else { + // GPR move:规范化宽度 + if (IsWReg(dst) && IsXReg(src)) { + src = ToW(src); + } else if (IsXReg(dst) && IsWReg(src)) { + src = ToX(src); + } + os << " mov " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n"; + } break; + } case Opcode::StoreStack: { - // 检查第二个操作数的类型 if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) { - // 存储到栈槽 const auto& slot = GetFrameSlot(function, ops.at(1)); PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset); } else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) { - // 间接存储:存储到寄存器指向的地址 - // STR W9, [X8] + // 间接存储:基址必须是 X 寄存器 + PhysReg base = ToX(ops.at(1).GetReg()); os << " str " << PhysRegName(ops.at(0).GetReg()) << ", [" - << PhysRegName(ops.at(1).GetReg()) << "]\n"; + << PhysRegName(base) << "]\n"; } else { throw std::runtime_error("StoreStack: 无效的操作数类型"); } break; } case Opcode::LoadStack: { - // 检查第二个操作数的类型 if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) { - // 从栈槽加载 const auto& slot = GetFrameSlot(function, ops.at(1)); PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset); } else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) { - // 间接加载:从寄存器指向的地址加载 - // LDR W9, [X8] + // 间接加载:基址必须是 X 寄存器 + PhysReg base = ToX(ops.at(1).GetReg()); os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", [" - << PhysRegName(ops.at(1).GetReg()) << "]\n"; + << PhysRegName(base) << "]\n"; } else { throw std::runtime_error("LoadStack: 无效的操作数类型"); } @@ -204,115 +270,181 @@ void PrintInstruction(std::ostream& os, const MachineInstr& instr, } os << "\n"; break; - case Opcode::AddRR: - os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::AddRI: - os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", #" - << ops.at(2).GetImm() << "\n"; - break; - case Opcode::SubRR: - os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::SubRI: - os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", #" - << ops.at(2).GetImm() << "\n"; - break; - case Opcode::MulRR: - os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::SDivRR: - os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::UDivRR: - os << " udiv " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::FAddRR: - os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::FSubRR: - os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::FMulRR: - os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::FDivRR: - os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::CmpRR: - os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << "\n"; - break; - case Opcode::CmpRI: - os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", #" - << ops.at(1).GetImm() << "\n"; + case Opcode::AddRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " add " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } + case Opcode::AddRI: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " add " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", #" + << nops[2].GetImm() << "\n"; break; - case Opcode::FCmpRR: - os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << "\n"; + } + case Opcode::SubRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " sub " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; break; - case Opcode::SIToFP: - os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << "\n"; + } + case Opcode::SubRI: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " sub " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", #" + << nops[2].GetImm() << "\n"; break; - case Opcode::FPToSI: - os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << "\n"; + } + case Opcode::MulRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " mul " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; break; + } + case Opcode::SDivRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " sdiv " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } + case Opcode::UDivRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " udiv " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } + case Opcode::FAddRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " fadd " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } + case Opcode::FSubRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " fsub " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } + case Opcode::FMulRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " fmul " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } + case Opcode::FDivRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " fdiv " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } + case Opcode::CmpRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " cmp " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << "\n"; + break; + } + case Opcode::CmpRI: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " cmp " << PhysRegName(nops[0].GetReg()) << ", #" + << nops[1].GetImm() << "\n"; + break; + } + case Opcode::FCmpRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " fcmp " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << "\n"; + break; + } + case Opcode::SIToFP: { + PhysReg dst = ops.at(0).GetReg(); + PhysReg src = ops.at(1).GetReg(); + if (!IsWReg(src)) src = ToW(src); + os << " scvtf " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n"; + break; + } + case Opcode::FPToSI: { + PhysReg dst = ops.at(0).GetReg(); + PhysReg src = ops.at(1).GetReg(); + if (!IsWReg(dst)) dst = ToW(dst); + os << " fcvtzs " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n"; + break; + } case Opcode::ZExt: os << " and " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << ", #1\n"; break; - case Opcode::AndRR: - os << " and " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::OrRR: - os << " orr " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::EorRR: - os << " eor " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::LslRR: - os << " lsl " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::LsrRR: - os << " lsr " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::AsrRR: - os << " asr " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; + case Opcode::AndRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " and " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } + case Opcode::OrRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " orr " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } + case Opcode::EorRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " eor " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } + case Opcode::LslRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " lsl " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } + case Opcode::LsrRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " lsr " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; break; + } + case Opcode::AsrRR: { + std::vector nops = ops; + NormalizeRegOps(nops, nops[0].GetReg()); + os << " asr " << PhysRegName(nops[0].GetReg()) << ", " + << PhysRegName(nops[1].GetReg()) << ", " + << PhysRegName(nops[2].GetReg()) << "\n"; + break; + } case Opcode::B: os << " b "; PrintOperand(os, ops.at(0)); @@ -345,8 +477,8 @@ void PrintInstruction(std::ostream& os, const MachineInstr& instr, break; case Opcode::LoadStackAddr: { const FrameSlot& slot = GetFrameSlot(function, ops.at(1)); - int64_t offset = slot.offset; // 负值,如 -8 - PhysReg dst = ops.at(0).GetReg(); + int64_t offset = slot.offset; + PhysReg dst = ToX(ops.at(0).GetReg()); // 地址必须是 X 寄存器 auto tryEmitSimple = [&]() -> bool { if (offset >= 0 && offset <= 4095) { @@ -384,10 +516,15 @@ void PrintInstruction(std::ostream& os, const MachineInstr& instr, << ops.at(2).GetLabel() << "\n"; break; } - case Opcode::Sxtw: - os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << "\n"; + case Opcode::Sxtw: { + PhysReg dst = ops.at(0).GetReg(); + PhysReg src = ops.at(1).GetReg(); + // sxtw 要求 X 目标,W 源 + if (!IsXReg(dst)) dst = ToX(dst); + if (!IsWReg(src)) src = ToW(src); + os << " sxtw " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n"; break; + } default: os << " // unknown instruction\n"; break; diff --git a/src/mir/FrameLowering.cpp b/src/mir/FrameLowering.cpp index 86cd502..40666a4 100644 --- a/src/mir/FrameLowering.cpp +++ b/src/mir/FrameLowering.cpp @@ -1,19 +1,11 @@ #include "mir/MIR.h" +#include #include #include #include "utils/Log.h" -//#define DEBUG_Frame - -#ifdef DEBUG_Frame -#include -#define DEBUG_MSG(msg) std::cerr << "[Frame Debug] " << msg << std::endl -#else -#define DEBUG_MSG(msg) -#endif - namespace mir { namespace { @@ -21,10 +13,47 @@ int AlignTo(int value, int align) { return ((value + align - 1) / align) * align; } +// 收集排序后的 callee-saved 寄存器列表 +// 返回:{ (physReg, frameIndex) } 对 +struct CSRegSlot { + PhysReg phys_reg; + int frame_index; + int size; // 8 for x registers, 4 for s registers +}; + +std::vector CollectCalleeSavedSlots(MachineFunction& function) { + std::vector slots; + const auto& regs = function.GetCalleeSavedRegs(); + + // 整数 callee-saved (X19-X28 格式,每个 8 字节) + for (int i = 19; i <= 28; ++i) { + PhysReg xreg = static_cast(static_cast(PhysReg::X19) + (i - 19)); + PhysReg wreg = static_cast(static_cast(PhysReg::W19) + (i - 19)); + if (regs.count(wreg) || regs.count(xreg)) { + int slot = function.CreateFrameIndex(8); + slots.push_back({xreg, slot, 8}); + } + } + + // 浮点 callee-saved (S8-S15,每个 4 字节) + for (int i = 8; i <= 15; ++i) { + PhysReg sreg = static_cast(static_cast(PhysReg::S8) + (i - 8)); + if (regs.count(sreg)) { + int slot = function.CreateFrameIndex(8); + slots.push_back({sreg, slot, 8}); + } + } + + return slots; +} + } // namespace void RunFrameLowering(MachineFunction& function) { - DEBUG_MSG("function RunFrameLowering"); + // 收集 callee-saved 寄存器并分配栈槽 + auto csSlots = CollectCalleeSavedSlots(function); + + // 计算栈槽偏移 int cursor = 0; for (const auto& slot : function.GetFrameSlots()) { cursor += slot.size; @@ -32,26 +61,34 @@ void RunFrameLowering(MachineFunction& function) { } function.SetFrameSize(AlignTo(cursor, 16)); - // 基本块 + // 插入 Prologue / Epilogue const auto& blocks = function.GetBasicBlocks(); bool firstBlock = true; - + for (const auto& bb : blocks) { - DEBUG_MSG("block"); auto& insts = bb->GetInstructions(); std::vector lowered; - // 输出基本块标签(非第一个基本块) + if (firstBlock) { - DEBUG_MSG("empalace Prologue"); - lowered.emplace_back(Opcode::Prologue); + lowered.emplace_back(Opcode::Prologue); + + // 在 Prologue 后保存 callee-saved 寄存器 + for (const auto& cs : csSlots) { + lowered.emplace_back(Opcode::StoreStack, + std::vector{Operand::Reg(cs.phys_reg), + Operand::FrameIndex(cs.frame_index)}); + } } firstBlock = false; - // 输出基本块中的指令 for (const auto& inst : insts) { - DEBUG_MSG("inst"); if (inst.GetOpcode() == Opcode::Ret) { - DEBUG_MSG("empalace Epilogue"); + // 在 Epilogue 前恢复 callee-saved 寄存器 + for (const auto& cs : csSlots) { + lowered.emplace_back(Opcode::LoadStack, + std::vector{Operand::Reg(cs.phys_reg), + Operand::FrameIndex(cs.frame_index)}); + } lowered.emplace_back(Opcode::Epilogue); } lowered.push_back(inst); @@ -60,13 +97,10 @@ void RunFrameLowering(MachineFunction& function) { } } -// 模块版本的栈帧布局 void RunFrameLowering(MachineModule& module) { - // 对模块中的每个函数执行栈帧布局 - DEBUG_MSG("module RunFrameLowering"); for (auto& func : module.GetFunctions()) { RunFrameLowering(*func); } } -} // namespace mir \ No newline at end of file +} // namespace mir diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 0817fbb..d63bc9c 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -1,8 +1,9 @@ #include "mir/MIR.h" +#include #include #include -#include +#include #include "ir/IR.h" #include "utils/Log.h" @@ -19,8 +20,51 @@ namespace mir { namespace { -using ValueSlotMap = std::unordered_map; +// ========== VReg 类型 ========== +enum class VRegType { kInt32, kInt64, kFloat32 }; + +// ========== VReg 上下文:管理虚拟寄存器分配和 IR Value 映射 ========== +struct VRegContext { + int next_vreg = 1; + std::unordered_map value_to_vreg; + std::unordered_map vreg_types; + + int NewVReg(VRegType type) { + int id = next_vreg++; + vreg_types[id] = type; + return id; + } + + VRegType GetType(int vreg) const { + auto it = vreg_types.find(vreg); + return it != vreg_types.end() ? it->second : VRegType::kInt32; + } + + bool IsFloat(int vreg) const { + auto it = vreg_types.find(vreg); + return it != vreg_types.end() && it->second == VRegType::kFloat32; + } + + bool IsPointer(int vreg) const { + auto it = vreg_types.find(vreg); + return it != vreg_types.end() && it->second == VRegType::kInt64; + } + void SetVReg(const ir::Value* value, int vreg) { + value_to_vreg[value] = vreg; + } + + int GetVReg(const ir::Value* value) const { + auto it = value_to_vreg.find(value); + return it != value_to_vreg.end() ? it->second : -1; + } + + bool HasVReg(const ir::Value* value) const { + return value_to_vreg.count(value) > 0; + } +}; + +// ========== 辅助:判断类型 ========== inline bool IsInt32Type(const ir::Type* type) { return type && type->IsInt32() && type->Size() == 4; } @@ -31,14 +75,160 @@ static uint32_t FloatToBits(float f) { return bits; } -// 获取类型大小(字节) int GetTypeSize(const ir::Type* type) { if (!type) return 4; size_t size = type->Size(); return size > 0 ? static_cast(size) : 4; } -// 将 IR 整数比较谓词转换为 ARMv8 条件码 +VRegType GetVRegTypeForIRType(const ir::Type* type) { + if (!type) return VRegType::kInt32; + if (type->IsFloat()) return VRegType::kFloat32; + if (type->IsPtrInt32() || type->IsPtrFloat() || type->IsPtrInt1() || type->IsArray()) + return VRegType::kInt64; + return VRegType::kInt32; +} + +// ========== 辅助:记录 MachineInstr 的 def/use ========== +void RecordDefUse(MachineInstr& instr, int def, std::initializer_list uses) { + if (def > 0) instr.AddDef(def); + for (int u : uses) { + if (u > 0) instr.AddUse(u); + } +} + +void RecordDefUseVec(MachineInstr& instr, const std::vector& defs, + const std::vector& uses) { + for (int d : defs) instr.AddDef(d); + for (int u : uses) instr.AddUse(u); +} + +// ========== 辅助:发射立即数到虚拟寄存器 ========== +void EmitMovImm(int vreg, uint32_t imm, MachineBasicBlock& block) { + auto isLegalMovImm = [](uint32_t v) -> bool { + return (v & 0xFFFF) == v || (v & 0xFFFF0000) == v; + }; + + if (isLegalMovImm(imm)) { + auto& instr = block.Append(Opcode::MovImm, + {Operand::VReg(vreg), Operand::Imm(static_cast(imm))}); + instr.AddDef(vreg); + } else { + uint16_t low = imm & 0xFFFF; + uint16_t high = (imm >> 16) & 0xFFFF; + auto& i1 = block.Append(Opcode::MovImm, + {Operand::VReg(vreg), Operand::Imm(low)}); + i1.AddDef(vreg); + if (high != 0) { + auto& i2 = block.Append(Opcode::Movk, + {Operand::VReg(vreg), Operand::Imm(high), Operand::Imm(16)}); + i2.AddDef(vreg); + i2.AddUse(vreg); + } + } +} + +void EmitMovImm64(int vreg, uint64_t imm, MachineBasicBlock& block) { + uint16_t part0 = imm & 0xFFFF; + uint16_t part1 = (imm >> 16) & 0xFFFF; + uint16_t part2 = (imm >> 32) & 0xFFFF; + uint16_t part3 = (imm >> 48) & 0xFFFF; + + auto& i1 = block.Append(Opcode::MovImm, + {Operand::VReg(vreg), Operand::Imm(part0)}); + i1.AddDef(vreg); + if (part1 != 0) { + auto& i2 = block.Append(Opcode::Movk, + {Operand::VReg(vreg), Operand::Imm(part1), Operand::Imm(16)}); + i2.AddDef(vreg); + i2.AddUse(vreg); + } + if (part2 != 0) { + auto& i3 = block.Append(Opcode::Movk, + {Operand::VReg(vreg), Operand::Imm(part2), Operand::Imm(32)}); + i3.AddDef(vreg); + i3.AddUse(vreg); + } + if (part3 != 0) { + auto& i4 = block.Append(Opcode::Movk, + {Operand::VReg(vreg), Operand::Imm(part3), Operand::Imm(48)}); + i4.AddDef(vreg); + i4.AddUse(vreg); + } +} + +// ========== 核心:将 IR Value 转换为虚拟寄存器 ========== +int EmitValueToVReg(const ir::Value* value, VRegContext& ctx, + MachineBasicBlock& block, MachineFunction& function) { + // 已经映射的值直接返回 + if (ctx.HasVReg(value)) { + return ctx.GetVReg(value); + } + + // 整数常量 + if (auto* constant = dynamic_cast(value)) { + uint32_t imm = static_cast(constant->GetValue()); + int vreg = ctx.NewVReg(VRegType::kInt32); + EmitMovImm(vreg, imm, block); + ctx.SetVReg(value, vreg); + return vreg; + } + + // 浮点常量(需要经过栈槽:整数bit→栈→浮点load) + if (auto* fconstant = dynamic_cast(value)) { + float fval = fconstant->GetValue(); + uint32_t bits = FloatToBits(fval); + int slot = function.CreateFrameIndex(4); + EmitMovImm(ctx.NewVReg(VRegType::kInt32), bits, block); + // 上面生成了一个新 vreg 来承载 bit pattern,直接用那个 vreg store 到栈 + // 但我们需要拿到它的编号...上面的 EmitMovImm 内部分配了 vreg,不方便取回。 + // 简化处理:用一个临时 vreg + int tmp = ctx.NewVReg(VRegType::kInt32); + EmitMovImm(tmp, bits, block); + auto& s = block.Append(Opcode::StoreStack, + {Operand::VReg(tmp), Operand::FrameIndex(slot)}); + s.AddUse(tmp); + int fvreg = ctx.NewVReg(VRegType::kFloat32); + auto& l = block.Append(Opcode::LoadStack, + {Operand::VReg(fvreg), Operand::FrameIndex(slot)}); + l.AddDef(fvreg); + ctx.SetVReg(value, fvreg); + return fvreg; + } + + // 零常量 / 聚合零 + if (dynamic_cast(value) || + dynamic_cast(value)) { + int vreg = ctx.NewVReg(VRegType::kInt32); + auto& instr = block.Append(Opcode::MovImm, + {Operand::VReg(vreg), Operand::Imm(0)}); + instr.AddDef(vreg); + ctx.SetVReg(value, vreg); + return vreg; + } + + // 全局变量:生成地址到 64 位 vreg + if (auto* global = dynamic_cast(value)) { + int vreg = ctx.NewVReg(VRegType::kInt64); + auto& i1 = block.Append(Opcode::Adrp, + {Operand::VReg(vreg), Operand::Label(global->GetName())}); + i1.AddDef(vreg); + auto& i2 = block.Append(Opcode::AddLabel, + {Operand::VReg(vreg), Operand::VReg(vreg), Operand::Label(global->GetName())}); + i2.AddDef(vreg); + i2.AddUse(vreg); + ctx.SetVReg(value, vreg); + return vreg; + } + + // 未找到 + std::string name = value->GetName(); + if (name.empty()) name = "(anonymous)"; + throw std::runtime_error( + FormatError("mir", "EmitValueToVReg: 找不到值对应的 vreg: " + name)); +} + +// ========== IR 比较谓词 → ARMv8 条件码 ========== CondCode IcmpToCondCode(ir::IcmpInst::Predicate pred) { switch (pred) { case ir::IcmpInst::Predicate::EQ: return CondCode::EQ; @@ -51,7 +241,6 @@ CondCode IcmpToCondCode(ir::IcmpInst::Predicate pred) { } } -// 将 IR 浮点比较谓词转换为 ARMv8 条件码 CondCode FcmpToCondCode(ir::FcmpInst::Predicate pred, bool& isOrdered) { isOrdered = true; switch (pred) { @@ -71,1305 +260,931 @@ CondCode FcmpToCondCode(ir::FcmpInst::Predicate pred, bool& isOrdered) { } } -// 获取基本块的标签名(用于汇编输出) std::string GetBlockLabel(const ir::BasicBlock* bb) { - if (!bb || !bb->GetParent()) { - return ".Lunknown"; - } - // 格式:.L函数名_基本块名 + if (!bb || !bb->GetParent()) return ".Lunknown"; std::string funcName = bb->GetParent()->GetName(); std::string blockName = bb->GetName(); - - // 如果基本块没有名字,使用地址作为标识 - if (blockName.empty()) { + if (blockName.empty()) blockName = std::to_string(reinterpret_cast(bb)); - } - return ".L" + funcName + "_" + blockName; } -// 获取数组类型的维度信息 -static const ir::ArrayType* GetArrayType(const ir::Type* type) { - if (type->IsArray()) { - return static_cast(type); - } - return nullptr; -} - -static std::vector GetArrayStrides(const ir::ArrayType* arrayType) { - std::vector strides; - const std::vector& dims = arrayType->GetDimensions(); - int stride = 4; // 元素大小(int/float 是 4 字节) - - // 从最后一维向前计算步长 - for (int i = dims.size() - 1; i >= 0; --i) { - strides.insert(strides.begin(), stride); - stride *= dims[i]; - } - return strides; -} - -// 在 Lowering.cpp 中添加辅助函数 -const ir::Value* GetOperand(const ir::Instruction& inst, size_t index) { - if (index < inst.GetNumOperands()) { - return inst.GetOperand(index); - } - return nullptr; -} - -const ir::BasicBlock* GetBasicBlockOperand(const ir::Instruction& inst, size_t index) { - const ir::Value* operand = GetOperand(inst, index); - if (operand) { - return dynamic_cast(operand); - } - return nullptr; -} - -void EmitValueToReg(const ir::Value* value, PhysReg target, - const ValueSlotMap& slots, MachineBasicBlock& block, - MachineFunction& function) { - // 处理整数常量 - if (value == nullptr) { - DEBUG_MSG( "EmitValueToReg called with null value\n"); - } - - // 辅助函数:判断32位立即数是否可用单条 movz/movn 编码 - auto isLegalMovImm = [](uint32_t imm) -> bool { - // 检查是否可以通过 movz 或 movn 表示(16位立即数,可左移0/16/32/48位) - // 对于32位寄存器,只考虑 lsl 0 或 16 - if ((imm & 0xFFFF) == imm) return true; // 低16位,lsl #0 - if ((imm & 0xFFFF0000) == imm) return true; // 高16位,lsl #16 - // 可选:检查 movn 情况(~imm 满足上述条件),为了简单可不做,直接返回 false - return false; - }; - - if (auto* constant = dynamic_cast(value)) { - uint32_t imm = static_cast(constant->GetValue()); +// ========== LowerInstruction:将单条 IR 指令转换为 MIR 指令 ========== +void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, + VRegContext& ctx, MachineBasicBlock& block, + std::unordered_map& blockMap) { + DEBUG_MSG("Processing instruction: " << inst.GetName() + << " (opcode: " << static_cast(inst.GetOpcode()) << ")"); - // 如果目标是浮点寄存器,将 imm 位模式解释为 float 并加载 - if (target >= PhysReg::S0 && target <= PhysReg::S7) { - auto it = slots.find(value); - int slot; - if (it == slots.end()) { - slot = function.CreateFrameIndex(4); - // 将 imm 写入栈槽(与 ConstantFloat 相同方式) - if (isLegalMovImm(imm)) { - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(static_cast(imm))}); - } else { - uint16_t low = imm & 0xFFFF; - uint16_t high = (imm >> 16) & 0xFFFF; - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(low)}); - if (high != 0) { - block.Append(Opcode::Movk, {Operand::Reg(PhysReg::W10), Operand::Imm(high), Operand::Imm(16)}); - } - } - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W10), Operand::FrameIndex(slot)}); - const_cast(slots).emplace(value, slot); - } else { - slot = it->second; - } - // 从栈槽加载到浮点目标寄存器 - block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(slot)}); - return; - } + switch (inst.GetOpcode()) { - if (isLegalMovImm(imm)) { - block.Append(Opcode::MovImm, - {Operand::Reg(target), Operand::Imm(static_cast(imm))}); - } else { - // 分解为 movz (低16位) + movk (高16位) - uint16_t low = imm & 0xFFFF; - uint16_t high = (imm >> 16) & 0xFFFF; - block.Append(Opcode::MovImm, - {Operand::Reg(target), Operand::Imm(low)}); // 先加载低16位 - if (high != 0) { - // 使用 Movk 指令写入高16位 - block.Append(Opcode::Movk, - {Operand::Reg(target), Operand::Imm(high), Operand::Imm(16)}); - } - } - return; - } - // 处理浮点常量 - if (auto* fconstant = dynamic_cast(value)) { - // 检查是否已经为这个常量分配了栈槽 - auto it = slots.find(value); - int slot; - if (it == slots.end()) { - DEBUG_MSG("Value not found: " << value->GetName()); - // 输出所有 slots 的键名用于调试 - for (auto& p : slots) { - DEBUG_MSG(" Slot key: " << p.first->GetName()); - } - // 分配新的栈槽 - slot = function.CreateFrameIndex(4); - // 将浮点常量存储到栈槽 - float fval = fconstant->GetValue(); - uint32_t int_val = FloatToBits(fval); - - // 同样需要对 int_val 进行大立即数分解 - if (isLegalMovImm(int_val)) { - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(static_cast(int_val))}); - } else { - uint16_t low = int_val & 0xFFFF; - uint16_t high = (int_val >> 16) & 0xFFFF; - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(low)}); - if (high != 0) { - block.Append(Opcode::Movk, {Operand::Reg(PhysReg::W10), Operand::Imm(high), Operand::Imm(16)}); - } - } - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W10), Operand::FrameIndex(slot)}); - const_cast(slots).emplace(value, slot); - } else { - slot = it->second; + // ---- Alloca:在栈上分配空间,返回指针 ---- + case ir::Opcode::Alloca: { + int arraySlot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); + int ptrVreg = ctx.NewVReg(VRegType::kInt64); + auto& instr = block.Append(Opcode::LoadStackAddr, + {Operand::VReg(ptrVreg), Operand::FrameIndex(arraySlot)}); + instr.AddDef(ptrVreg); + ctx.SetVReg(&inst, ptrVreg); + return; } - // 从栈槽加载到目标寄存器 - block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(slot)}); - return; - } + // ---- Store:将值存储到指针指向的地址 ---- + case ir::Opcode::Store: { + auto& store = static_cast(inst); + const ir::Value* ptr = store.GetPtr(); + const ir::Value* val = store.GetValue(); - // 处理零常量 - if (dynamic_cast(value) || - dynamic_cast(value)) { - // 如果目标是浮点寄存器,必须通过栈槽加载 0.0f 的位模式 - if (target >= PhysReg::S0 && target <= PhysReg::S7) { - auto it = slots.find(value); - int slot; - if (it == slots.end()) { - slot = function.CreateFrameIndex(4); - // 写入 0 - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(0)}); - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W10), Operand::FrameIndex(slot)}); - const_cast(slots).emplace(value, slot); - } else { - slot = it->second; - } - // 加载到浮点寄存器 - block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(slot)}); - return; - } - // 原有的整数/指针零常量处理保持不变 - block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Imm(0)}); - return; - } + int ptrVreg = EmitValueToVReg(ptr, ctx, block, function); + int valVreg = EmitValueToVReg(val, ctx, block, function); - // ========== 处理全局变量 ========== - if (auto* global = dynamic_cast(value)) { - // 如果目标是 32 位寄存器,升级为对应的 64 位寄存器 - PhysReg addrTarget = target; - if (target >= PhysReg::W0 && target <= PhysReg::W30) { - // 映射 Wn → Xn - addrTarget = static_cast( - static_cast(target) - static_cast(PhysReg::W0) + static_cast(PhysReg::X0)); - } - // 现在 addrTarget 一定是 64 位寄存器 - block.Append(Opcode::Adrp, {Operand::Reg(addrTarget), Operand::Label(global->GetName())}); - block.Append(Opcode::AddLabel, {Operand::Reg(addrTarget), Operand::Reg(addrTarget), Operand::Label(global->GetName())}); + auto& instr = block.Append(Opcode::StoreStack, + {Operand::VReg(valVreg), Operand::VReg(ptrVreg)}); + instr.AddUse(valVreg); + instr.AddUse(ptrVreg); return; - } - - // ========== 处理从栈槽加载的值 ========== - auto it = slots.find(value); - if (it == slots.end()) { - // 使用值的地址作为调试信息 - std::string valueName = value->GetName(); - if (valueName.empty()) { - valueName = "(anonymous at " + std::to_string(reinterpret_cast(value)) + ")"; } - DEBUG_MSG("Value not found: " << valueName); - // 输出所有 slots 的键名用于调试 - for (auto& p : slots) { - std::string slotName = p.first->GetName(); - if (slotName.empty()) { - slotName = "(anonymous at " + std::to_string(reinterpret_cast(p.first)) + ")"; - } - DEBUG_MSG(" Slot key: " << slotName); - } - throw std::runtime_error( - FormatError("mir", "找不到值对应的栈槽: " + valueName)); - } - - PhysReg actualTarget = target; - const ir::Type* ty = value->GetType().get(); - bool isPointer = ty->IsPtrInt32() || ty->IsPtrFloat() || ty->IsPtrInt1() - || ty->IsArray(); // 数组类型在地址上下文中视为指针 - - // 若非指针类型且目标是 64 位寄存器,降级为对应的 32 位寄存器(自动零扩展) - if (!isPointer && target >= PhysReg::X0 && target <= PhysReg::X30) { - actualTarget = static_cast( - static_cast(target) - static_cast(PhysReg::X0) + static_cast(PhysReg::W0)); -} - -block.Append(Opcode::LoadStack, - {Operand::Reg(actualTarget), Operand::FrameIndex(it->second)}); -} -void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, - ValueSlotMap& slots, MachineBasicBlock& block, - std::unordered_map& blockMap) { - //auto& block = function.GetEntry(); - DEBUG_MSG("Processing instruction: " << inst.GetName() - << " (opcode: " << static_cast(inst.GetOpcode()) << ")"); - switch (inst.GetOpcode()) { - case ir::Opcode::Alloca: { - // alloca 返回一个指针,我们需要为该指针分配一个栈槽(存放地址值) - int ptrSlot = function.CreateFrameIndex(8); // 指针占8字节 - // 为数组数据分配实际栈空间 - int arraySlot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - // 将数组基地址(sp + arraySlot_offset)加载到 x8 - block.Append(Opcode::LoadStackAddr, - {Operand::Reg(PhysReg::X8), Operand::FrameIndex(arraySlot)}); - // 将地址存储到指针槽 - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::X8), Operand::FrameIndex(ptrSlot)}); - - // 将 alloca 指令映射到指针槽,后续使用 alloca 值即获得地址 - slots.emplace(&inst, ptrSlot); - return; - } - case ir::Opcode::Store: { - auto& store = static_cast(inst); - const ir::Value* ptr = store.GetPtr(); - const ir::Value* val = store.GetValue(); - - // 处理全局变量作为存储目标 - if (auto* global = dynamic_cast(ptr)) { - // 生成全局变量地址到 x8 - block.Append(Opcode::Adrp, {Operand::Reg(PhysReg::X8), Operand::Label(global->GetName())}); - block.Append(Opcode::AddLabel, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Label(global->GetName())}); - // 加载要存储的值到 w9 - EmitValueToReg(val, PhysReg::W9, slots, block, function); - // 间接存储:str w9, [x8] - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::X8)}); - return; - } - - auto dstIt = slots.find(ptr); - if (dstIt == slots.end()) { - // 指针不在 slots 中(例如直接来自函数参数的指针) - // 计算指针地址到 x8 - EmitValueToReg(ptr, PhysReg::X8, slots, block, function); - // 加载要存储的值到 w9 - EmitValueToReg(val, PhysReg::W9, slots, block, function); - // 使用 StoreStack 的寄存器间接形式(str w9, [x8]) - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::X8)}); - } else { - // 指针在 slots 中(alloca 或 gep 的结果),从指针槽加载地址到 x8 - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::X8), Operand::FrameIndex(dstIt->second)}); - // 加载要存储的值到 w9 - EmitValueToReg(val, PhysReg::W9, slots, block, function); - // 间接存储 - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::X8)}); - } - return; - } + // ---- Load:从指针指向的地址加载值 ---- case ir::Opcode::Load: { - auto& load = static_cast(inst); - const ir::Value* ptr = load.GetPtr(); - int dstSlot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - // 处理全局变量作为加载源 - if (auto* global = dynamic_cast(ptr)) { - // 生成全局变量地址到 x8 - block.Append(Opcode::Adrp, {Operand::Reg(PhysReg::X8), Operand::Label(global->GetName())}); - block.Append(Opcode::AddLabel, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Label(global->GetName())}); - // 间接加载到 w9 - block.Append(Opcode::LoadStack, {Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::X8)}); - // 存储结果到栈槽 - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W9), Operand::FrameIndex(dstSlot)}); - slots.emplace(&inst, dstSlot); - return; - } - - auto srcIt = slots.find(ptr); - if (srcIt == slots.end()) { - // 指针不在 slots 中,计算地址到 x8 - EmitValueToReg(ptr, PhysReg::X8, slots, block, function); - // 间接加载到 w9 - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::X8)}); - } else { - // 从指针槽加载地址到 x8 - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::X8), Operand::FrameIndex(srcIt->second)}); - // 间接加载到 w9 - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::X8)}); - } - // 将加载的值存入结果槽 - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W9), Operand::FrameIndex(dstSlot)}); - slots.emplace(&inst, dstSlot); - return; + auto& load = static_cast(inst); + const ir::Value* ptr = load.GetPtr(); + + int ptrVreg = EmitValueToVReg(ptr, ctx, block, function); + VRegType resultType = GetVRegTypeForIRType(inst.GetType().get()); + int dstVreg = ctx.NewVReg(resultType); + + auto& instr = block.Append(Opcode::LoadStack, + {Operand::VReg(dstVreg), Operand::VReg(ptrVreg)}); + instr.AddDef(dstVreg); + instr.AddUse(ptrVreg); + ctx.SetVReg(&inst, dstVreg); + return; } + + // ---- Add ---- case ir::Opcode::Add: { auto& bin = static_cast(inst); const ir::Type* lhsTy = bin.GetLhs()->GetType().get(); const ir::Type* rhsTy = bin.GetRhs()->GetType().get(); const ir::Type* resultTy = inst.GetType().get(); - // 指针判断:指令结果类型是指针,或者任一操作数是指针(指针算术) bool isPointer = (lhsTy->IsPtrInt32() || lhsTy->IsPtrFloat() || lhsTy->IsPtrInt1() || lhsTy->IsArray()) || - (rhsTy->IsPtrInt32() || rhsTy->IsPtrFloat() || rhsTy->IsPtrInt1() || rhsTy->IsArray()) || - resultTy->IsPtrInt32() || resultTy->IsPtrFloat() || resultTy->IsPtrInt1() || resultTy->IsArray(); - - // 判断是否为纯 32 位有符号整数加法(需要提升为 64 位以避免溢出) + (rhsTy->IsPtrInt32() || rhsTy->IsPtrFloat() || rhsTy->IsPtrInt1() || rhsTy->IsArray()) || + resultTy->IsPtrInt32() || resultTy->IsPtrFloat() || resultTy->IsPtrInt1() || resultTy->IsArray(); bool isI32 = !isPointer && IsInt32Type(lhsTy) && IsInt32Type(rhsTy) && IsInt32Type(resultTy); - int slotSize = isPointer ? 8 : 4; - int dst_slot = function.CreateFrameIndex(slotSize); - - PhysReg lhsReg, rhsReg, dstReg; if (isI32) { - // 使用 64 位寄存器,先符号扩展 - lhsReg = PhysReg::X8; - rhsReg = PhysReg::X9; - dstReg = PhysReg::X8; - - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function); - block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::W8)}); - - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function); - block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::W9)}); - } else { - lhsReg = isPointer ? PhysReg::X8 : PhysReg::W8; - rhsReg = isPointer ? PhysReg::X9 : PhysReg::W9; - dstReg = isPointer ? PhysReg::X8 : PhysReg::W8; - EmitValueToReg(bin.GetLhs(), lhsReg, slots, block, function); - EmitValueToReg(bin.GetRhs(), rhsReg, slots, block, function); - } - - block.Append(Opcode::AddRR, {Operand::Reg(dstReg), Operand::Reg(lhsReg), Operand::Reg(rhsReg)}); - - if (isI32) { - // 结果在 X8 中,只需存储低 32 位(W8) - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + // i32 加法:用 64 位运算避免溢出 + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vxl = ctx.NewVReg(VRegType::kInt64); + int vxr = ctx.NewVReg(VRegType::kInt64); + int vxd = ctx.NewVReg(VRegType::kInt64); + + auto& s1 = block.Append(Opcode::Sxtw, {Operand::VReg(vxl), Operand::VReg(vl)}); + s1.AddDef(vxl); s1.AddUse(vl); + auto& s2 = block.Append(Opcode::Sxtw, {Operand::VReg(vxr), Operand::VReg(vr)}); + s2.AddDef(vxr); s2.AddUse(vr); + auto& a = block.Append(Opcode::AddRR, {Operand::VReg(vxd), Operand::VReg(vxl), Operand::VReg(vxr)}); + a.AddDef(vxd); a.AddUse(vxl); a.AddUse(vxr); + + // 结果用 32 位 vreg(低 32 位正确) + int vd = ctx.NewVReg(VRegType::kInt32); + auto& mv = block.Append(Opcode::MovReg, {Operand::VReg(vd), Operand::VReg(vxd)}); + mv.AddDef(vd); mv.AddUse(vxd); + ctx.SetVReg(&inst, vd); } else { - block.Append(Opcode::StoreStack, {Operand::Reg(dstReg), Operand::FrameIndex(dst_slot)}); + VRegType ty = isPointer ? VRegType::kInt64 : VRegType::kInt32; + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vd = ctx.NewVReg(ty); + auto& a = block.Append(Opcode::AddRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vr)}); + a.AddDef(vd); a.AddUse(vl); a.AddUse(vr); + ctx.SetVReg(&inst, vd); } - slots.emplace(&inst, dst_slot); return; } + // ---- Sub ---- case ir::Opcode::Sub: { auto& bin = static_cast(inst); - // 如果两个操作数都是浮点类型,则使用浮点减法 + // 浮点减法 if (bin.GetLhs()->GetType()->IsFloat() && bin.GetRhs()->GetType()->IsFloat()) { - int dst_slot = function.CreateFrameIndex(4); - EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block, function); - EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block, function); - block.Append(Opcode::FSubRR, {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::S1)}); - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vd = ctx.NewVReg(VRegType::kFloat32); + auto& a = block.Append(Opcode::FSubRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vr)}); + a.AddDef(vd); a.AddUse(vl); a.AddUse(vr); + ctx.SetVReg(&inst, vd); + return; } const ir::Type* lhsTy = bin.GetLhs()->GetType().get(); const ir::Type* rhsTy = bin.GetRhs()->GetType().get(); const ir::Type* resultTy = inst.GetType().get(); - bool isPointer = (lhsTy->IsPtrInt32() || lhsTy->IsPtrFloat() || lhsTy->IsPtrInt1() || lhsTy->IsArray()) || - (rhsTy->IsPtrInt32() || rhsTy->IsPtrFloat() || rhsTy->IsPtrInt1() || rhsTy->IsArray()) || - resultTy->IsPtrInt32() || resultTy->IsPtrFloat() || resultTy->IsPtrInt1() || resultTy->IsArray(); - + (rhsTy->IsPtrInt32() || rhsTy->IsPtrFloat() || rhsTy->IsPtrInt1() || rhsTy->IsArray()) || + resultTy->IsPtrInt32() || resultTy->IsPtrFloat() || resultTy->IsPtrInt1() || resultTy->IsArray(); bool isI32 = !isPointer && IsInt32Type(lhsTy) && IsInt32Type(rhsTy) && IsInt32Type(resultTy); - int slotSize = isPointer ? 8 : 4; - int dst_slot = function.CreateFrameIndex(slotSize); - - PhysReg lhsReg, rhsReg, dstReg; - if (isI32) { - lhsReg = PhysReg::X8; - rhsReg = PhysReg::X9; - dstReg = PhysReg::X8; - - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function); - block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::W8)}); - - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function); - block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::W9)}); - } else { - lhsReg = isPointer ? PhysReg::X8 : PhysReg::W8; - rhsReg = isPointer ? PhysReg::X9 : PhysReg::W9; - dstReg = isPointer ? PhysReg::X8 : PhysReg::W8; - EmitValueToReg(bin.GetLhs(), lhsReg, slots, block, function); - EmitValueToReg(bin.GetRhs(), rhsReg, slots, block, function); - } - - block.Append(Opcode::SubRR, {Operand::Reg(dstReg), Operand::Reg(lhsReg), Operand::Reg(rhsReg)}); - if (isI32) { - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vxl = ctx.NewVReg(VRegType::kInt64); + int vxr = ctx.NewVReg(VRegType::kInt64); + int vxd = ctx.NewVReg(VRegType::kInt64); + auto& s1 = block.Append(Opcode::Sxtw, {Operand::VReg(vxl), Operand::VReg(vl)}); + s1.AddDef(vxl); s1.AddUse(vl); + auto& s2 = block.Append(Opcode::Sxtw, {Operand::VReg(vxr), Operand::VReg(vr)}); + s2.AddDef(vxr); s2.AddUse(vr); + auto& a = block.Append(Opcode::SubRR, {Operand::VReg(vxd), Operand::VReg(vxl), Operand::VReg(vxr)}); + a.AddDef(vxd); a.AddUse(vxl); a.AddUse(vxr); + int vd = ctx.NewVReg(VRegType::kInt32); + auto& mv = block.Append(Opcode::MovReg, {Operand::VReg(vd), Operand::VReg(vxd)}); + mv.AddDef(vd); mv.AddUse(vxd); + ctx.SetVReg(&inst, vd); } else { - block.Append(Opcode::StoreStack, {Operand::Reg(dstReg), Operand::FrameIndex(dst_slot)}); + VRegType ty = isPointer ? VRegType::kInt64 : VRegType::kInt32; + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vd = ctx.NewVReg(ty); + auto& a = block.Append(Opcode::SubRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vr)}); + a.AddDef(vd); a.AddUse(vl); a.AddUse(vr); + ctx.SetVReg(&inst, vd); } - slots.emplace(&inst, dst_slot); return; } + // ---- Mul ---- case ir::Opcode::Mul: { auto& bin = static_cast(inst); const ir::Type* lhsTy = bin.GetLhs()->GetType().get(); const ir::Type* rhsTy = bin.GetRhs()->GetType().get(); const ir::Type* resultTy = inst.GetType().get(); - - // 乘法一般不涉及指针,但保留判断 bool isPointer = (lhsTy->IsPtrInt32() || lhsTy->IsPtrFloat() || lhsTy->IsPtrInt1()) || - (rhsTy->IsPtrInt32() || rhsTy->IsPtrFloat() || rhsTy->IsPtrInt1()) || - resultTy->IsPtrInt32() || resultTy->IsPtrFloat() || resultTy->IsPtrInt1(); - + (rhsTy->IsPtrInt32() || rhsTy->IsPtrFloat() || rhsTy->IsPtrInt1()) || + resultTy->IsPtrInt32() || resultTy->IsPtrFloat() || resultTy->IsPtrInt1(); bool isI32 = !isPointer && IsInt32Type(lhsTy) && IsInt32Type(rhsTy) && IsInt32Type(resultTy); - int slotSize = isPointer ? 8 : 4; - int dst_slot = function.CreateFrameIndex(slotSize); - - PhysReg lhsReg, rhsReg, dstReg; - if (isI32) { - lhsReg = PhysReg::X8; - rhsReg = PhysReg::X9; - dstReg = PhysReg::X8; - - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function); - block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::W8)}); - - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function); - block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::W9)}); - } else { - lhsReg = isPointer ? PhysReg::X8 : PhysReg::W8; - rhsReg = isPointer ? PhysReg::X9 : PhysReg::W9; - dstReg = isPointer ? PhysReg::X8 : PhysReg::W8; - EmitValueToReg(bin.GetLhs(), lhsReg, slots, block, function); - EmitValueToReg(bin.GetRhs(), rhsReg, slots, block, function); - } - - block.Append(Opcode::MulRR, {Operand::Reg(dstReg), Operand::Reg(lhsReg), Operand::Reg(rhsReg)}); - if (isI32) { - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vxl = ctx.NewVReg(VRegType::kInt64); + int vxr = ctx.NewVReg(VRegType::kInt64); + int vxd = ctx.NewVReg(VRegType::kInt64); + auto& s1 = block.Append(Opcode::Sxtw, {Operand::VReg(vxl), Operand::VReg(vl)}); + s1.AddDef(vxl); s1.AddUse(vl); + auto& s2 = block.Append(Opcode::Sxtw, {Operand::VReg(vxr), Operand::VReg(vr)}); + s2.AddDef(vxr); s2.AddUse(vr); + auto& a = block.Append(Opcode::MulRR, {Operand::VReg(vxd), Operand::VReg(vxl), Operand::VReg(vxr)}); + a.AddDef(vxd); a.AddUse(vxl); a.AddUse(vxr); + int vd = ctx.NewVReg(VRegType::kInt32); + auto& mv = block.Append(Opcode::MovReg, {Operand::VReg(vd), Operand::VReg(vxd)}); + mv.AddDef(vd); mv.AddUse(vxd); + ctx.SetVReg(&inst, vd); } else { - block.Append(Opcode::StoreStack, {Operand::Reg(dstReg), Operand::FrameIndex(dst_slot)}); + VRegType ty = isPointer ? VRegType::kInt64 : VRegType::kInt32; + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vd = ctx.NewVReg(ty); + auto& a = block.Append(Opcode::MulRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vr)}); + a.AddDef(vd); a.AddUse(vl); a.AddUse(vr); + ctx.SetVReg(&inst, vd); } - slots.emplace(&inst, dst_slot); return; } + // ---- Div ---- case ir::Opcode::Div: { auto& bin = static_cast(inst); const ir::Type* lhsTy = bin.GetLhs()->GetType().get(); const ir::Type* rhsTy = bin.GetRhs()->GetType().get(); const ir::Type* resultTy = inst.GetType().get(); - bool isPointer = (lhsTy->IsPtrInt32() || lhsTy->IsPtrFloat() || lhsTy->IsPtrInt1()) || - (rhsTy->IsPtrInt32() || rhsTy->IsPtrFloat() || rhsTy->IsPtrInt1()) || - resultTy->IsPtrInt32() || resultTy->IsPtrFloat() || resultTy->IsPtrInt1(); - + (rhsTy->IsPtrInt32() || rhsTy->IsPtrFloat() || rhsTy->IsPtrInt1()) || + resultTy->IsPtrInt32() || resultTy->IsPtrFloat() || resultTy->IsPtrInt1(); bool isI32 = !isPointer && IsInt32Type(lhsTy) && IsInt32Type(rhsTy) && IsInt32Type(resultTy); - int slotSize = isPointer ? 8 : 4; - int dst_slot = function.CreateFrameIndex(slotSize); - - PhysReg lhsReg, rhsReg, dstReg; if (isI32) { - lhsReg = PhysReg::X8; - rhsReg = PhysReg::X9; - dstReg = PhysReg::X8; - - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function); - block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::W8)}); - - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function); - block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::W9)}); + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vxl = ctx.NewVReg(VRegType::kInt64); + int vxr = ctx.NewVReg(VRegType::kInt64); + int vxd = ctx.NewVReg(VRegType::kInt64); + auto& s1 = block.Append(Opcode::Sxtw, {Operand::VReg(vxl), Operand::VReg(vl)}); + s1.AddDef(vxl); s1.AddUse(vl); + auto& s2 = block.Append(Opcode::Sxtw, {Operand::VReg(vxr), Operand::VReg(vr)}); + s2.AddDef(vxr); s2.AddUse(vr); + auto& a = block.Append(Opcode::SDivRR, {Operand::VReg(vxd), Operand::VReg(vxl), Operand::VReg(vxr)}); + a.AddDef(vxd); a.AddUse(vxl); a.AddUse(vxr); + int vd = ctx.NewVReg(VRegType::kInt32); + auto& mv = block.Append(Opcode::MovReg, {Operand::VReg(vd), Operand::VReg(vxd)}); + mv.AddDef(vd); mv.AddUse(vxd); + ctx.SetVReg(&inst, vd); } else { - lhsReg = isPointer ? PhysReg::X8 : PhysReg::W8; - rhsReg = isPointer ? PhysReg::X9 : PhysReg::W9; - dstReg = isPointer ? PhysReg::X8 : PhysReg::W8; - EmitValueToReg(bin.GetLhs(), lhsReg, slots, block, function); - EmitValueToReg(bin.GetRhs(), rhsReg, slots, block, function); + VRegType ty = isPointer ? VRegType::kInt64 : VRegType::kInt32; + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vd = ctx.NewVReg(ty); + auto& a = block.Append(Opcode::SDivRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vr)}); + a.AddDef(vd); a.AddUse(vl); a.AddUse(vr); + ctx.SetVReg(&inst, vd); } - - block.Append(Opcode::SDivRR, {Operand::Reg(dstReg), Operand::Reg(lhsReg), Operand::Reg(rhsReg)}); - - if (isI32) { - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - } else { - block.Append(Opcode::StoreStack, {Operand::Reg(dstReg), Operand::FrameIndex(dst_slot)}); - } - slots.emplace(&inst, dst_slot); return; } + // ---- Mod: a % b = a - (a / b) * b ---- case ir::Opcode::Mod: { - // Mod 指令:a % b = a - (a / b) * b - // 我们直接复用提升策略:使用 64 位运算 const ir::Value* lhs = inst.GetOperand(0); const ir::Value* rhs = inst.GetOperand(1); const ir::Type* lhsTy = lhs->GetType().get(); const ir::Type* rhsTy = rhs->GetType().get(); const ir::Type* resultTy = inst.GetType().get(); - bool isI32 = IsInt32Type(lhsTy) && IsInt32Type(rhsTy) && IsInt32Type(resultTy); - int dst_slot = function.CreateFrameIndex(4); // 结果总是 32 位 if (isI32) { - // 加载并扩展 lhs - EmitValueToReg(lhs, PhysReg::W8, slots, block, function); - block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::W8)}); - // 加载并扩展 rhs - EmitValueToReg(rhs, PhysReg::W9, slots, block, function); - block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::W9)}); - - // X10 = X8 / X9 - block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X9)}); - // X10 = X10 * X9 - block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X9)}); - // X8 = X8 - X10 - block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X10)}); - // 存储低 32 位 - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + int vl = EmitValueToVReg(lhs, ctx, block, function); + int vr = EmitValueToVReg(rhs, ctx, block, function); + int vxl = ctx.NewVReg(VRegType::kInt64); + int vxr = ctx.NewVReg(VRegType::kInt64); + auto& s1 = block.Append(Opcode::Sxtw, {Operand::VReg(vxl), Operand::VReg(vl)}); + s1.AddDef(vxl); s1.AddUse(vl); + auto& s2 = block.Append(Opcode::Sxtw, {Operand::VReg(vxr), Operand::VReg(vr)}); + s2.AddDef(vxr); s2.AddUse(vr); + int vxdiv = ctx.NewVReg(VRegType::kInt64); + int vxtmp = ctx.NewVReg(VRegType::kInt64); + int vxd = ctx.NewVReg(VRegType::kInt64); + auto& d = block.Append(Opcode::SDivRR, {Operand::VReg(vxdiv), Operand::VReg(vxl), Operand::VReg(vxr)}); + d.AddDef(vxdiv); d.AddUse(vxl); d.AddUse(vxr); + auto& m = block.Append(Opcode::MulRR, {Operand::VReg(vxtmp), Operand::VReg(vxdiv), Operand::VReg(vxr)}); + m.AddDef(vxtmp); m.AddUse(vxdiv); m.AddUse(vxr); + auto& sub = block.Append(Opcode::SubRR, {Operand::VReg(vxd), Operand::VReg(vxl), Operand::VReg(vxtmp)}); + sub.AddDef(vxd); sub.AddUse(vxl); sub.AddUse(vxtmp); + int vd = ctx.NewVReg(VRegType::kInt32); + auto& mv = block.Append(Opcode::MovReg, {Operand::VReg(vd), Operand::VReg(vxd)}); + mv.AddDef(vd); mv.AddUse(vxd); + ctx.SetVReg(&inst, vd); } else { - // 原有逻辑(假设不会用于指针或 64 位整数) - EmitValueToReg(lhs, PhysReg::W8, slots, block, function); - EmitValueToReg(rhs, PhysReg::W9, slots, block, function); - block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W10)}); - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + int vl = EmitValueToVReg(lhs, ctx, block, function); + int vr = EmitValueToVReg(rhs, ctx, block, function); + int vdiv = ctx.NewVReg(VRegType::kInt32); + int vtmp = ctx.NewVReg(VRegType::kInt32); + int vd = ctx.NewVReg(VRegType::kInt32); + auto& d = block.Append(Opcode::SDivRR, {Operand::VReg(vdiv), Operand::VReg(vl), Operand::VReg(vr)}); + d.AddDef(vdiv); d.AddUse(vl); d.AddUse(vr); + auto& m = block.Append(Opcode::MulRR, {Operand::VReg(vtmp), Operand::VReg(vdiv), Operand::VReg(vr)}); + m.AddDef(vtmp); m.AddUse(vdiv); m.AddUse(vr); + auto& s = block.Append(Opcode::SubRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vtmp)}); + s.AddDef(vd); s.AddUse(vl); s.AddUse(vtmp); + ctx.SetVReg(&inst, vd); } - slots.emplace(&inst, dst_slot); return; } + + // ---- Ret ---- case ir::Opcode::Ret: { auto& ret = static_cast(inst); const ir::Value* retVal = ret.GetValue(); if (retVal != nullptr) { - const ir::Type* retType = retVal->GetType().get(); - PhysReg retReg = PhysReg::W0; // 默认整数返回值 - if (retType->IsFloat()) { - retReg = PhysReg::S0; - } else if (retType->IsPtrInt32() || retType->IsPtrFloat() || retType->IsPtrInt1()) { - retReg = PhysReg::X0; - } else { - retReg = PhysReg::W0; - } - EmitValueToReg(retVal, retReg, slots, block, function); + const ir::Type* retType = retVal->GetType().get(); + int valVreg = EmitValueToVReg(retVal, ctx, block, function); + PhysReg retReg = PhysReg::W0; + if (retType->IsFloat()) retReg = PhysReg::S0; + else if (retType->IsPtrInt32() || retType->IsPtrFloat() || retType->IsPtrInt1()) + retReg = PhysReg::X0; + auto& mv = block.Append(Opcode::MovReg, + {Operand::Reg(retReg), Operand::VReg(valVreg)}); + mv.AddUse(valVreg); } - block.Append(Opcode::Ret); + auto& r = block.Append(Opcode::Ret); return; } + + // ---- FAdd ---- case ir::Opcode::FAdd: { auto& bin = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽 - // 浮点值加载到 S0, S1(使用浮点寄存器) - EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block, function); - EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block, function); - block.Append(Opcode::FAddRR, {Operand::Reg(PhysReg::S0), - Operand::Reg(PhysReg::S0), - Operand::Reg(PhysReg::S1)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vd = ctx.NewVReg(VRegType::kFloat32); + auto& a = block.Append(Opcode::FAddRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vr)}); + a.AddDef(vd); a.AddUse(vl); a.AddUse(vr); + ctx.SetVReg(&inst, vd); return; } + case ir::Opcode::FSub: { auto& bin = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽 - // 浮点值加载到 S0, S1(使用浮点寄存器) - EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block, function); - EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block, function); - block.Append(Opcode::FSubRR, {Operand::Reg(PhysReg::S0), - Operand::Reg(PhysReg::S0), - Operand::Reg(PhysReg::S1)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vd = ctx.NewVReg(VRegType::kFloat32); + auto& a = block.Append(Opcode::FSubRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vr)}); + a.AddDef(vd); a.AddUse(vl); a.AddUse(vr); + ctx.SetVReg(&inst, vd); return; } + case ir::Opcode::FMul: { auto& bin = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽 - // 浮点值加载到 S0, S1(使用浮点寄存器) - EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block, function); - EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block, function); - block.Append(Opcode::FMulRR, {Operand::Reg(PhysReg::S0), - Operand::Reg(PhysReg::S0), - Operand::Reg(PhysReg::S1)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vd = ctx.NewVReg(VRegType::kFloat32); + auto& a = block.Append(Opcode::FMulRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vr)}); + a.AddDef(vd); a.AddUse(vl); a.AddUse(vr); + ctx.SetVReg(&inst, vd); return; } + case ir::Opcode::FDiv: { auto& bin = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽 - // 浮点值加载到 S0, S1(使用浮点寄存器) - EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block, function); - EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block, function); - block.Append(Opcode::FDivRR, {Operand::Reg(PhysReg::S0), - Operand::Reg(PhysReg::S0), - Operand::Reg(PhysReg::S1)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); + int vl = EmitValueToVReg(bin.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(bin.GetRhs(), ctx, block, function); + int vd = ctx.NewVReg(VRegType::kFloat32); + auto& a = block.Append(Opcode::FDivRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vr)}); + a.AddDef(vd); a.AddUse(vl); a.AddUse(vr); + ctx.SetVReg(&inst, vd); return; } - // ========== 整数比较指令(修正版)========== + + // ---- Icmp(整数比较) ---- case ir::Opcode::Icmp: { auto& icmp = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - EmitValueToReg(icmp.GetLhs(), PhysReg::W8, slots, block, function); - EmitValueToReg(icmp.GetRhs(), PhysReg::W9, slots, block, function); - - block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + int vl = EmitValueToVReg(icmp.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(icmp.GetRhs(), ctx, block, function); CondCode cc = IcmpToCondCode(icmp.GetPredicate()); - - // 使用 CSET 模式 - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W8), Operand::Imm(1)}); - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(0)}); - - std::string true_label = ".L_cset_true_" + std::to_string(reinterpret_cast(&icmp)); - std::string end_label = ".L_cset_end_" + std::to_string(reinterpret_cast(&icmp)); - - block.Append(Opcode::BCond, {Operand::Cond(cc), Operand::Label(true_label)}); - block.Append(Opcode::MovReg, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::B, {Operand::Label(end_label)}); - block.Append(Opcode::Label, {Operand::Label(true_label)}); - block.Append(Opcode::Label, {Operand::Label(end_label)}); - - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; - } - // ========== 浮点比较指令 ========== - case ir::Opcode::FCmp: { - auto& fcmp = static_cast(inst); - int dst_slot = function.CreateFrameIndex(4); // 结果是 i1(4字节) - // 1. 加载浮点操作数并比较 - EmitValueToReg(fcmp.GetLhs(), PhysReg::S0, slots, block, function); - EmitValueToReg(fcmp.GetRhs(), PhysReg::S1, slots, block, function); - block.Append(Opcode::FCmpRR, {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::S1)}); + auto& cmp = block.Append(Opcode::CmpRR, {Operand::VReg(vl), Operand::VReg(vr)}); + cmp.AddUse(vl); cmp.AddUse(vr); - // 2. 获取有序/无序标志和条件码 - bool isOrdered; - CondCode cc = FcmpToCondCode(fcmp.GetPredicate(), isOrdered); + int vd = ctx.NewVReg(VRegType::kInt32); + auto& set1 = block.Append(Opcode::MovImm, {Operand::VReg(vd), Operand::Imm(1)}); + set1.AddDef(vd); - // 3. 结果预设为 0 - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W8), Operand::Imm(0)}); + std::string trueLabel = ".L_cset_true_" + std::to_string(reinterpret_cast(&icmp)); + std::string endLabel = ".L_cset_end_" + std::to_string(reinterpret_cast(&icmp)); - // 4. 生成标签 - std::string set1_label = ".L_fcset_true_" + std::to_string(reinterpret_cast(&fcmp)); - std::string end_label = ".L_fcset_end_" + std::to_string(reinterpret_cast(&fcmp)); + auto& bc = block.Append(Opcode::BCond, {Operand::Cond(cc), Operand::Label(trueLabel)}); + auto& mv0 = block.Append(Opcode::MovImm, {Operand::VReg(vd), Operand::Imm(0)}); + mv0.AddDef(vd); + auto& b = block.Append(Opcode::B, {Operand::Label(endLabel)}); + auto& lt = block.Append(Opcode::Label, {Operand::Label(trueLabel)}); + auto& le = block.Append(Opcode::Label, {Operand::Label(endLabel)}); - // 5. 处理 NaN 情况 - if (isOrdered) { - // 有序比较:如果无序 (V=1) → 直接结束(结果为 0) - block.Append(Opcode::BCond, {Operand::Cond(CondCode::VS), Operand::Label(end_label)}); - } else { - // 无序比较:如果无序 (V=1) → 结果置 1 - block.Append(Opcode::BCond, {Operand::Cond(CondCode::VS), Operand::Label(set1_label)}); - } - // 6. 正常条件跳转 - block.Append(Opcode::BCond, {Operand::Cond(cc), Operand::Label(set1_label)}); + ctx.SetVReg(&inst, vd); + return; + } - // 7. 无条件到结束 - block.Append(Opcode::B, {Operand::Label(end_label)}); + // ---- FCmp(浮点比较) ---- + case ir::Opcode::FCmp: { + auto& fcmp = static_cast(inst); + int vl = EmitValueToVReg(fcmp.GetLhs(), ctx, block, function); + int vr = EmitValueToVReg(fcmp.GetRhs(), ctx, block, function); + bool isOrdered; + CondCode cc = FcmpToCondCode(fcmp.GetPredicate(), isOrdered); - // 8. 置 1 标签 - block.Append(Opcode::Label, {Operand::Label(set1_label)}); - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W8), Operand::Imm(1)}); + auto& cmp = block.Append(Opcode::FCmpRR, {Operand::VReg(vl), Operand::VReg(vr)}); + cmp.AddUse(vl); cmp.AddUse(vr); - // 9. 结束标签并存储结果 - block.Append(Opcode::Label, {Operand::Label(end_label)}); - block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; - } - // ========== 跳转指令(使用标签操作数)========== - case ir::Opcode::Br: { - DEBUG_MSG("Processing Br"); - auto& br = static_cast(inst); - - if (br.IsConditional()) { - // 条件跳转 - EmitValueToReg(br.GetCondition(), PhysReg::W8, slots, block, function); - block.Append(Opcode::CmpRI, {Operand::Reg(PhysReg::W8), Operand::Imm(0)}); - - std::string trueLabel = GetBlockLabel(br.GetTrueTarget()); - std::string falseLabel = GetBlockLabel(br.GetFalseTarget()); - - block.Append(Opcode::BCond, {Operand::Cond(CondCode::NE), Operand::Label(trueLabel)}); - block.Append(Opcode::B, {Operand::Label(falseLabel)}); - } else { - // 无条件跳转 - std::string targetLabel = GetBlockLabel(br.GetTarget()); - block.Append(Opcode::B, {Operand::Label(targetLabel)}); - } - return; + int vd = ctx.NewVReg(VRegType::kInt32); + auto& zero = block.Append(Opcode::MovImm, {Operand::VReg(vd), Operand::Imm(0)}); + zero.AddDef(vd); + + std::string set1Label = ".L_fcset_true_" + std::to_string(reinterpret_cast(&fcmp)); + std::string endLabel = ".L_fcset_end_" + std::to_string(reinterpret_cast(&fcmp)); + + if (isOrdered) { + auto& bc_vs = block.Append(Opcode::BCond, {Operand::Cond(CondCode::VS), Operand::Label(endLabel)}); + } else { + auto& bc_vs = block.Append(Opcode::BCond, {Operand::Cond(CondCode::VS), Operand::Label(set1Label)}); + } + auto& bc = block.Append(Opcode::BCond, {Operand::Cond(cc), Operand::Label(set1Label)}); + auto& b = block.Append(Opcode::B, {Operand::Label(endLabel)}); + auto& lt = block.Append(Opcode::Label, {Operand::Label(set1Label)}); + auto& s1 = block.Append(Opcode::MovImm, {Operand::VReg(vd), Operand::Imm(1)}); + s1.AddDef(vd); + auto& le = block.Append(Opcode::Label, {Operand::Label(endLabel)}); + + ctx.SetVReg(&inst, vd); + return; } + + // ---- Br / CondBr ---- + case ir::Opcode::Br: case ir::Opcode::CondBr: { - DEBUG_MSG("Processing CondBr"); - auto& br = static_cast(inst); - - // 条件跳转处理 - EmitValueToReg(br.GetCondition(), PhysReg::W8, slots, block, function); - block.Append(Opcode::CmpRI, {Operand::Reg(PhysReg::W8), Operand::Imm(0)}); - + auto& br = static_cast(inst); + if (br.IsConditional()) { + int condVreg = EmitValueToVReg(br.GetCondition(), ctx, block, function); + auto& cmp = block.Append(Opcode::CmpRI, {Operand::VReg(condVreg), Operand::Imm(0)}); + cmp.AddUse(condVreg); + std::string trueLabel = GetBlockLabel(br.GetTrueTarget()); std::string falseLabel = GetBlockLabel(br.GetFalseTarget()); - - block.Append(Opcode::BCond, {Operand::Cond(CondCode::NE), Operand::Label(trueLabel)}); - block.Append(Opcode::B, {Operand::Label(falseLabel)}); - return; + + auto& bc = block.Append(Opcode::BCond, {Operand::Cond(CondCode::NE), Operand::Label(trueLabel)}); + auto& b = block.Append(Opcode::B, {Operand::Label(falseLabel)}); + } else { + std::string targetLabel = GetBlockLabel(br.GetTarget()); + auto& b = block.Append(Opcode::B, {Operand::Label(targetLabel)}); + } + return; } - // ========== 函数调用 ========== + + // ---- Call ---- case ir::Opcode::Call: { auto& call = static_cast(inst); const ir::Function* callee = call.GetCallee(); const std::string& calleeName = callee->GetName(); - - // 分配结果栈槽(如果有返回值) - int dst_slot = -1; - if (!inst.GetType()->IsVoid()) { - dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - } - - // 按照 ARM64 调用约定传递参数 + + // 传递参数(vreg → 物理寄存器) const auto& args = call.GetArgs(); size_t intArgCount = 0; size_t fpArgCount = 0; - for (size_t i = 0; i < args.size(); ++i) { const auto* arg = args[i]; const ir::Type* argType = arg->GetType().get(); - + int argVreg = EmitValueToVReg(arg, ctx, block, function); + if (argType->IsFloat()) { - // 浮点参数 PhysReg reg = static_cast(static_cast(PhysReg::S0) + fpArgCount); - EmitValueToReg(arg, reg, slots, block, function); + auto& mv = block.Append(Opcode::MovReg, {Operand::Reg(reg), Operand::VReg(argVreg)}); + mv.AddUse(argVreg); fpArgCount++; } else if (argType->IsPtrInt32() || argType->IsPtrFloat() || argType->IsPtrInt1()) { - // 指针参数 → X 寄存器(占用一个整数参数槽) PhysReg reg = static_cast(static_cast(PhysReg::X0) + intArgCount); - EmitValueToReg(arg, reg, slots, block, function); + auto& mv = block.Append(Opcode::MovReg, {Operand::Reg(reg), Operand::VReg(argVreg)}); + mv.AddUse(argVreg); intArgCount++; } else { - // 普通整数 → W 寄存器 PhysReg reg = static_cast(static_cast(PhysReg::W0) + intArgCount); - EmitValueToReg(arg, reg, slots, block, function); + auto& mv = block.Append(Opcode::MovReg, {Operand::Reg(reg), Operand::VReg(argVreg)}); + mv.AddUse(argVreg); intArgCount++; } } - - // 生成调用指令 - //block.Append(Opcode::Call, {Operand::Imm(0)}); // 实际需要传递函数名 - block.Append(Opcode::Call, {Operand::Label(calleeName)}); - // 保存返回值 - if (dst_slot != -1) { + + auto& callInst = block.Append(Opcode::Call, {Operand::Label(calleeName)}); + + // 返回值:物理寄存器 → vreg + if (!inst.GetType()->IsVoid()) { const ir::Type* retType = inst.GetType().get(); + VRegType vtype = GetVRegTypeForIRType(retType); + int dstVreg = ctx.NewVReg(vtype); PhysReg srcReg = PhysReg::W0; - if (retType->IsFloat()) { - srcReg = PhysReg::S0; - } else if (retType->IsPtrInt32() || retType->IsPtrFloat() || retType->IsPtrInt1()) { + if (retType->IsFloat()) srcReg = PhysReg::S0; + else if (retType->IsPtrInt32() || retType->IsPtrFloat() || retType->IsPtrInt1()) srcReg = PhysReg::X0; - } else { - srcReg = PhysReg::W0; - } - block.Append(Opcode::StoreStack, - {Operand::Reg(srcReg), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); + auto& mv = block.Append(Opcode::MovReg, {Operand::VReg(dstVreg), Operand::Reg(srcReg)}); + mv.AddDef(dstVreg); + ctx.SetVReg(&inst, dstVreg); } return; } - // ========== 类型转换指令 ========== + + // ---- ZExt (i1 → i32) ---- case ir::Opcode::ZExt: { auto& zext = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - // 加载源值到 w8 - EmitValueToReg(zext.GetValue(), PhysReg::W8, slots, block, function); - - // 零扩展:i1 -> i32,直接存储即可(因为 i1 已经是 0 或 1) - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); + int src = EmitValueToVReg(zext.GetValue(), ctx, block, function); + int dst = ctx.NewVReg(VRegType::kInt32); + auto& mv = block.Append(Opcode::MovReg, {Operand::VReg(dst), Operand::VReg(src)}); + mv.AddDef(dst); mv.AddUse(src); + ctx.SetVReg(&inst, dst); return; } + + // ---- SIToFP ---- case ir::Opcode::SIToFP: { auto& sitofp = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - // 加载整数到 w8 - EmitValueToReg(sitofp.GetValue(), PhysReg::W8, slots, block, function); - - // 整数转浮点:SCVTF s0, w8 - block.Append(Opcode::SIToFP, {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::W8)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); + int src = EmitValueToVReg(sitofp.GetValue(), ctx, block, function); + int dst = ctx.NewVReg(VRegType::kFloat32); + auto& cvt = block.Append(Opcode::SIToFP, {Operand::VReg(dst), Operand::VReg(src)}); + cvt.AddDef(dst); cvt.AddUse(src); + ctx.SetVReg(&inst, dst); return; } + + // ---- FPToSI ---- case ir::Opcode::FPToSI: { auto& fptosi = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - // 加载浮点数到 s0 - EmitValueToReg(fptosi.GetValue(), PhysReg::S0, slots, block, function); - - // 浮点转整数:FCVTZS w8, s0 - block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S0)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); + int src = EmitValueToVReg(fptosi.GetValue(), ctx, block, function); + int dst = ctx.NewVReg(VRegType::kInt32); + auto& cvt = block.Append(Opcode::FPToSI, {Operand::VReg(dst), Operand::VReg(src)}); + cvt.AddDef(dst); cvt.AddUse(src); + ctx.SetVReg(&inst, dst); return; } - case ir::Opcode::GEP: { - auto& gep = static_cast(inst); - - DEBUG_MSG("Processing GEP instruction: " << inst.GetName()); - - // GEP 返回指针类型,在 ARM64 上指针是 8 字节 - int dst_slot = function.CreateFrameIndex(8); - - // 获取基地址(数组的起始地址) - ir::Value* base = gep.GetBase(); - const auto& indices = gep.GetIndices(); - - std::string baseName = base->GetName().empty() ? "unnamed" : base->GetName(); - DEBUG_MSG("Base value: " << baseName); - DEBUG_MSG("Number of indices: " << indices.size()); - - // 打印索引值 - for (size_t idx_i = 0; idx_i < indices.size(); ++idx_i) { - if (auto* const_int = dynamic_cast(indices[idx_i])) { - DEBUG_MSG(" Index[" << idx_i << "] = " << const_int->GetValue() << " (constant)"); - } else { - DEBUG_MSG(" Index[" << idx_i << "] = variable"); - } - } - - // 加载基地址到 x8 - EmitValueToReg(base, PhysReg::X8, slots, block, function); - - if (indices.empty()) { - DEBUG_MSG("No indices, storing base address directly"); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::X8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; - } - - // 获取基地址类型 - const ir::Type* baseType = base->GetType().get(); - DEBUG_MSG("Base type kind: " << static_cast(baseType->GetKind())); - - // 关键修改:对于数组指针类型,第一个索引是多余的,应该跳过 - size_t start_index = 0; - if (baseType->IsPtrInt32() || baseType->IsPtrFloat() || baseType->IsPtrInt1()) { - // 对于指针类型,第一个索引是偏移量,不能跳过 - DEBUG_MSG("Base is pointer type, using all indices for pointer arithmetic"); - start_index = 0; - } else if (baseType->IsArray()) { - // 对于数组类型(非指针),第一个索引是多余的 - // 因为 base 已经是数组本身,不需要再解引用 - DEBUG_MSG("Base is array type, skipping first index (array decay)"); - start_index = 1; - } - - // 如果基地址是数组类型,需要处理多维数组 - if (baseType->IsArray()) { - DEBUG_MSG("Base is array type, processing multi-dimensional array"); - const ir::ArrayType* arrayType = static_cast(baseType); - const std::vector& dims = arrayType->GetDimensions(); - - DEBUG_MSG("Array dimensions: "); - for (size_t i = 0; i < dims.size(); ++i) { - DEBUG_MSG(" dim[" << i << "] = " << dims[i]); - } - - // 正确计算每个维度的步长 - std::vector strides(dims.size()); - int element_size = 4; // 元素大小(int/float 是 4 字节) - for (int i = dims.size() - 1; i >= 0; --i) { - if (i == static_cast(dims.size()) - 1) { - strides[i] = element_size; - DEBUG_MSG("strides[" << i << "] = " << strides[i] << " (element size)"); - } else { - strides[i] = strides[i + 1] * dims[i + 1]; - DEBUG_MSG("strides[" << i << "] = " << strides[i+1] << " * " << dims[i+1] - << " = " << strides[i]); - } - } - - // 计算总偏移,跳过第一个索引 - size_t numIndices = indices.size(); - size_t effective_indices = numIndices - start_index; - if (effective_indices > dims.size()) { - DEBUG_MSG("Warning: effective indices (" << effective_indices << ") > dims.size() (" - << dims.size() << "), truncating"); - effective_indices = dims.size(); - } - - DEBUG_MSG("Using " << effective_indices << " effective indices (starting from " << start_index << ")"); - - // 加载当前地址到 x9 作为偏移量累加器 - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X9), Operand::Imm(0)}); - - // 用于调试的静态偏移计算 - int debug_offset = 0; - - for (size_t i = 0; i < effective_indices; ++i) { - size_t idx_pos = start_index + i; - int index_value = 0; - if (auto* const_int = dynamic_cast(indices[idx_pos])) { - index_value = const_int->GetValue(); - DEBUG_MSG("Index[" << idx_pos << "] = " << index_value << " (constant)"); - debug_offset += index_value * strides[i]; - DEBUG_MSG(" Contribution = " << index_value << " * " << strides[i] - << " = " << (index_value * strides[i])); - DEBUG_MSG(" Running offset = " << debug_offset); - } else { - DEBUG_MSG("Index[" << idx_pos << "] = variable"); - } - - // 加载当前索引到 x10 - EmitValueToReg(indices[idx_pos], PhysReg::X10, slots, block, function); - - // 乘以步长 - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X11), Operand::Imm(strides[i])}); - block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X10), - Operand::Reg(PhysReg::X10), - Operand::Reg(PhysReg::X11)}); - - // 累加到偏移量 - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X10)}); - } - - DEBUG_MSG("Total computed offset = " << debug_offset); - - // 最终地址 = base + offset - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X8), - Operand::Reg(PhysReg::X8), - Operand::Reg(PhysReg::X9)}); - - // 存储计算出的地址 - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::X8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - DEBUG_MSG("Array GEP completed, result stored in slot " << dst_slot); - return; - } - - // 其他情况的处理... - DEBUG_MSG("Base is other type, using simple handling"); - if (indices.size() >= 1) { - EmitValueToReg(indices[0], PhysReg::X9, slots, block, function); - - // 乘以元素大小(默认 4 字节) - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X10), Operand::Imm(4)}); - block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X10)}); - - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X8), - Operand::Reg(PhysReg::X8), - Operand::Reg(PhysReg::X9)}); - } - - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::X8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - DEBUG_MSG("Simple GEP completed"); - return; - } - // 处理 Trunc 指令 + + // ---- Trunc ---- case ir::Opcode::Trunc: { - auto& inst_ref = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - // 假设 Trunc 指令有 GetValue() 方法 - // 如果没有,需要通过操作数列表获取 - const ir::Value* src_val = nullptr; - if (inst.GetNumOperands() > 0) { - src_val = inst.GetOperand(0); - } - if (src_val) { - EmitValueToReg(src_val, PhysReg::W8, slots, block, function); - } - - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; + const ir::Value* srcVal = inst.GetNumOperands() > 0 ? inst.GetOperand(0) : nullptr; + if (srcVal) { + int src = EmitValueToVReg(srcVal, ctx, block, function); + int dst = ctx.NewVReg(VRegType::kInt32); + auto& mv = block.Append(Opcode::MovReg, {Operand::VReg(dst), Operand::VReg(src)}); + mv.AddDef(dst); mv.AddUse(src); + ctx.SetVReg(&inst, dst); + } + return; } - // 处理 And 指令 + + // ---- And ---- case ir::Opcode::And: { - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - const ir::Value* lhs = nullptr; - const ir::Value* rhs = nullptr; - if (inst.GetNumOperands() >= 2) { - lhs = inst.GetOperand(0); - rhs = inst.GetOperand(1); - } - - if (lhs && rhs) { - EmitValueToReg(lhs, PhysReg::W8, slots, block, function); - EmitValueToReg(rhs, PhysReg::W9, slots, block, function); - block.Append(Opcode::AndRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - } - - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; + const ir::Value* lhs = inst.GetNumOperands() >= 2 ? inst.GetOperand(0) : nullptr; + const ir::Value* rhs = inst.GetNumOperands() >= 2 ? inst.GetOperand(1) : nullptr; + if (lhs && rhs) { + int vl = EmitValueToVReg(lhs, ctx, block, function); + int vr = EmitValueToVReg(rhs, ctx, block, function); + int vd = ctx.NewVReg(VRegType::kInt32); + auto& a = block.Append(Opcode::AndRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vr)}); + a.AddDef(vd); a.AddUse(vl); a.AddUse(vr); + ctx.SetVReg(&inst, vd); + } + return; } - // 处理 Or 指令 + + // ---- Or ---- case ir::Opcode::Or: { - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - const ir::Value* lhs = nullptr; - const ir::Value* rhs = nullptr; - if (inst.GetNumOperands() >= 2) { - lhs = inst.GetOperand(0); - rhs = inst.GetOperand(1); - } - - if (lhs && rhs) { - EmitValueToReg(lhs, PhysReg::W8, slots, block, function); - EmitValueToReg(rhs, PhysReg::W9, slots, block, function); - block.Append(Opcode::OrRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - } - - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; + const ir::Value* lhs = inst.GetNumOperands() >= 2 ? inst.GetOperand(0) : nullptr; + const ir::Value* rhs = inst.GetNumOperands() >= 2 ? inst.GetOperand(1) : nullptr; + if (lhs && rhs) { + int vl = EmitValueToVReg(lhs, ctx, block, function); + int vr = EmitValueToVReg(rhs, ctx, block, function); + int vd = ctx.NewVReg(VRegType::kInt32); + auto& a = block.Append(Opcode::OrRR, {Operand::VReg(vd), Operand::VReg(vl), Operand::VReg(vr)}); + a.AddDef(vd); a.AddUse(vl); a.AddUse(vr); + ctx.SetVReg(&inst, vd); + } + return; } - // 处理 Not 指令 + + // ---- Not ---- case ir::Opcode::Not: { - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - const ir::Value* src_val = nullptr; - if (inst.GetNumOperands() > 0) { - src_val = inst.GetOperand(0); - } - - if (src_val) { - EmitValueToReg(src_val, PhysReg::W8, slots, block, function); - // NOT = XOR with -1 - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(-1)}); - block.Append(Opcode::EorRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - } - - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; + const ir::Value* srcVal = inst.GetNumOperands() > 0 ? inst.GetOperand(0) : nullptr; + if (srcVal) { + int src = EmitValueToVReg(srcVal, ctx, block, function); + int m1 = ctx.NewVReg(VRegType::kInt32); + auto& mv1 = block.Append(Opcode::MovImm, {Operand::VReg(m1), Operand::Imm(-1)}); + mv1.AddDef(m1); + int vd = ctx.NewVReg(VRegType::kInt32); + auto& eor = block.Append(Opcode::EorRR, {Operand::VReg(vd), Operand::VReg(src), Operand::VReg(m1)}); + eor.AddDef(vd); eor.AddUse(src); eor.AddUse(m1); + ctx.SetVReg(&inst, vd); + } + return; } - // 处理 FPExt(浮点扩展) + + // ---- FPExt ---- case ir::Opcode::FPExt: { - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - const ir::Value* src_val = nullptr; - if (inst.GetNumOperands() > 0) { - src_val = inst.GetOperand(0); - } - - if (src_val) { - EmitValueToReg(src_val, PhysReg::S0, slots, block, function); - } - - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; + const ir::Value* srcVal = inst.GetNumOperands() > 0 ? inst.GetOperand(0) : nullptr; + if (srcVal) { + int src = EmitValueToVReg(srcVal, ctx, block, function); + int dst = ctx.NewVReg(VRegType::kFloat32); + auto& mv = block.Append(Opcode::MovReg, {Operand::VReg(dst), Operand::VReg(src)}); + mv.AddDef(dst); mv.AddUse(src); + ctx.SetVReg(&inst, dst); + } + return; } - // 处理 FPTrunc(浮点截断) + + // ---- FPTrunc ---- case ir::Opcode::FPTrunc: { - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - - const ir::Value* src_val = nullptr; - if (inst.GetNumOperands() > 0) { - src_val = inst.GetOperand(0); + const ir::Value* srcVal = inst.GetNumOperands() > 0 ? inst.GetOperand(0) : nullptr; + if (srcVal) { + int src = EmitValueToVReg(srcVal, ctx, block, function); + int dst = ctx.NewVReg(VRegType::kFloat32); + auto& mv = block.Append(Opcode::MovReg, {Operand::VReg(dst), Operand::VReg(src)}); + mv.AddDef(dst); mv.AddUse(src); + ctx.SetVReg(&inst, dst); + } + return; + } + + // ---- GEP ---- + case ir::Opcode::GEP: { + auto& gep = static_cast(inst); + ir::Value* base = gep.GetBase(); + const auto& indices = gep.GetIndices(); + + int baseVreg = EmitValueToVReg(base, ctx, block, function); + + if (indices.empty()) { + ctx.SetVReg(&inst, baseVreg); + return; + } + + const ir::Type* baseType = base->GetType().get(); + size_t startIndex = 0; + if (baseType->IsArray()) { + startIndex = 1; // 跳过数组解码的第一个索引 + } + + if (baseType->IsArray() && indices.size() > startIndex) { + const ir::ArrayType* arrayType = static_cast(baseType); + const std::vector& dims = arrayType->GetDimensions(); + + std::vector strides(dims.size()); + int elementSize = 4; + for (int i = static_cast(dims.size()) - 1; i >= 0; --i) { + if (i == static_cast(dims.size()) - 1) + strides[i] = elementSize; + else + strides[i] = strides[i + 1] * dims[i + 1]; + } + + size_t numIndices = indices.size(); + size_t effectiveIndices = numIndices - startIndex; + if (effectiveIndices > dims.size()) effectiveIndices = dims.size(); + + // 偏移累加器初始化为 0 + int offsetVreg = ctx.NewVReg(VRegType::kInt64); + auto& zero = block.Append(Opcode::MovImm, {Operand::VReg(offsetVreg), Operand::Imm(0)}); + zero.AddDef(offsetVreg); + + for (size_t i = 0; i < effectiveIndices; ++i) { + size_t idxPos = startIndex + i; + int idxVreg = EmitValueToVReg(indices[idxPos], ctx, block, function); + // 确保索引是 Int64(地址计算必须是 64 位) + if (ctx.GetType(idxVreg) == VRegType::kInt32) { + int idxVreg64 = ctx.NewVReg(VRegType::kInt64); + auto& sxt = block.Append(Opcode::Sxtw, + {Operand::VReg(idxVreg64), Operand::VReg(idxVreg)}); + sxt.AddDef(idxVreg64); sxt.AddUse(idxVreg); + idxVreg = idxVreg64; + } + int strideVreg = ctx.NewVReg(VRegType::kInt64); + auto& stride = block.Append(Opcode::MovImm, + {Operand::VReg(strideVreg), Operand::Imm(strides[i])}); + stride.AddDef(strideVreg); + int prodVreg = ctx.NewVReg(VRegType::kInt64); + auto& mul = block.Append(Opcode::MulRR, + {Operand::VReg(prodVreg), Operand::VReg(idxVreg), Operand::VReg(strideVreg)}); + mul.AddDef(prodVreg); mul.AddUse(idxVreg); mul.AddUse(strideVreg); + int newOff = ctx.NewVReg(VRegType::kInt64); + auto& add = block.Append(Opcode::AddRR, + {Operand::VReg(newOff), Operand::VReg(offsetVreg), Operand::VReg(prodVreg)}); + add.AddDef(newOff); add.AddUse(offsetVreg); add.AddUse(prodVreg); + offsetVreg = newOff; } - - if (src_val) { - EmitValueToReg(src_val, PhysReg::S0, slots, block, function); + + int resultVreg = ctx.NewVReg(VRegType::kInt64); + auto& finalAdd = block.Append(Opcode::AddRR, + {Operand::VReg(resultVreg), Operand::VReg(baseVreg), Operand::VReg(offsetVreg)}); + finalAdd.AddDef(resultVreg); finalAdd.AddUse(baseVreg); finalAdd.AddUse(offsetVreg); + ctx.SetVReg(&inst, resultVreg); + } else if (indices.size() >= 1) { + // 简单指针运算 + int idxVreg = EmitValueToVReg(indices[0], ctx, block, function); + if (ctx.GetType(idxVreg) == VRegType::kInt32) { + int idxVreg64 = ctx.NewVReg(VRegType::kInt64); + auto& sxt = block.Append(Opcode::Sxtw, + {Operand::VReg(idxVreg64), Operand::VReg(idxVreg)}); + sxt.AddDef(idxVreg64); sxt.AddUse(idxVreg); + idxVreg = idxVreg64; } - - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; + int strideVreg = ctx.NewVReg(VRegType::kInt64); + int elemSize = 4; + if (baseType->IsArray()) { + const ir::ArrayType* arrType = static_cast(baseType); + elemSize = GetTypeSize(arrType->GetElementType().get()); + } + auto& stride = block.Append(Opcode::MovImm, {Operand::VReg(strideVreg), Operand::Imm(elemSize)}); + stride.AddDef(strideVreg); + int prodVreg = ctx.NewVReg(VRegType::kInt64); + auto& mul = block.Append(Opcode::MulRR, + {Operand::VReg(prodVreg), Operand::VReg(idxVreg), Operand::VReg(strideVreg)}); + mul.AddDef(prodVreg); mul.AddUse(idxVreg); mul.AddUse(strideVreg); + int resultVreg = ctx.NewVReg(VRegType::kInt64); + auto& add = block.Append(Opcode::AddRR, + {Operand::VReg(resultVreg), Operand::VReg(baseVreg), Operand::VReg(prodVreg)}); + add.AddDef(resultVreg); add.AddUse(baseVreg); add.AddUse(prodVreg); + ctx.SetVReg(&inst, resultVreg); + } + return; } + default: - DEBUG_MSG("Unhandled opcode: " << static_cast(inst.GetOpcode()) - << " for instruction: " << inst.GetName()); - throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令,opcode: " - + std::to_string(static_cast(inst.GetOpcode())))); - //throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令")); - //throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令,opcode: " - // + std::to_string(static_cast(inst.GetOpcode())))); + throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令,opcode: " + + std::to_string(static_cast(inst.GetOpcode())))); } } -} // namespace - -// 辅助函数,将单个 IR 函数转换为 MachineFunction +// ========== LowerFunction:将 IR 函数转换为 MachineFunction ========== std::unique_ptr LowerFunction(const ir::Function& func) { - auto machine_func = std::make_unique(func.GetName()); - ValueSlotMap slots; + auto machineFunc = std::make_unique(func.GetName()); + VRegContext ctx; - // 存储参数信息,稍后处理 + // 为函数参数分配 vreg struct ParamInfo { - const ir::Value* arg; - int slot; - bool isFloat; - bool isPointer; + const ir::Value* arg; + int vreg; + bool isFloat; + bool isPointer; }; std::vector paramInfos; - - // 为函数参数分配栈槽 + for (const auto& arg : func.GetArguments()) { - int slot = machine_func->CreateFrameIndex(GetTypeSize(arg->GetType().get())); - slots.emplace(arg.get(), slot); bool isFloat = arg->GetType()->IsFloat(); - bool isPointer = arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat() || arg->GetType()->IsPtrInt1(); - paramInfos.push_back({arg.get(), slot, isFloat, isPointer}); + bool isPointer = arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat() + || arg->GetType()->IsPtrInt1(); + VRegType ty = isFloat ? VRegType::kFloat32 + : (isPointer ? VRegType::kInt64 : VRegType::kInt32); + int vreg = ctx.NewVReg(ty); + ctx.SetVReg(arg.get(), vreg); + paramInfos.push_back({arg.get(), vreg, isFloat, isPointer}); } - - // IR 基本块到 MIR 基本块的映射 + + // IR 基本块 → MIR 基本块 映射 std::unordered_map blockMap; - - // 第一遍:为每个 IR 基本块创建 MIR 基本块 - std::string func_name = func.GetName(); + std::string funcName = func.GetName(); + for (const auto& bb : func.GetBlocks()) { - // 格式: .L函数名_基本块名 - auto mirBB = std::make_unique(".L" + func_name + "_" + bb->GetName()); + auto mirBB = std::make_unique(".L" + funcName + "_" + bb->GetName()); blockMap[bb.get()] = mirBB.get(); - machine_func->AddBasicBlock(std::move(mirBB)); + machineFunc->AddBasicBlock(std::move(mirBB)); } - - // 在入口基本块的开头添加参数加载指令 + + // 在入口基本块开头:从物理参数寄存器加载参数到 vreg if (!func.GetBlocks().empty()) { MachineBasicBlock* entryBB = blockMap[func.GetEntry()]; if (entryBB) { - size_t intArgIdx = 0; - size_t fpArgIdx = 0; - - for (const auto& param : paramInfos) { - if (param.isFloat) { - if (fpArgIdx < 8) { - PhysReg reg = static_cast(static_cast(PhysReg::S0) + fpArgIdx); - entryBB->Append(Opcode::StoreStack, - {Operand::Reg(reg), Operand::FrameIndex(param.slot)}); - } - fpArgIdx++; - } else if (param.isPointer) { - if (intArgIdx < 8) { - PhysReg reg = static_cast(static_cast(PhysReg::X0) + intArgIdx); - entryBB->Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(param.slot)}); - } - intArgIdx++; - } else { - if (intArgIdx < 8) { - PhysReg reg = static_cast(static_cast(PhysReg::W0) + intArgIdx); - entryBB->Append(Opcode::StoreStack, - {Operand::Reg(reg), Operand::FrameIndex(param.slot)}); - } - intArgIdx++; - } + size_t intArgIdx = 0; + size_t fpArgIdx = 0; + for (const auto& param : paramInfos) { + if (param.isFloat) { + if (fpArgIdx < 8) { + PhysReg reg = static_cast( + static_cast(PhysReg::S0) + fpArgIdx); + auto& mv = entryBB->Append(Opcode::MovReg, + {Operand::VReg(param.vreg), Operand::Reg(reg)}); + mv.AddDef(param.vreg); + } + fpArgIdx++; + } else if (param.isPointer) { + if (intArgIdx < 8) { + PhysReg reg = static_cast( + static_cast(PhysReg::X0) + intArgIdx); + auto& mv = entryBB->Append(Opcode::MovReg, + {Operand::VReg(param.vreg), Operand::Reg(reg)}); + mv.AddDef(param.vreg); + } + intArgIdx++; + } else { + if (intArgIdx < 8) { + PhysReg reg = static_cast( + static_cast(PhysReg::W0) + intArgIdx); + auto& mv = entryBB->Append(Opcode::MovReg, + {Operand::VReg(param.vreg), Operand::Reg(reg)}); + mv.AddDef(param.vreg); + } + intArgIdx++; } + } } } - - // 第二遍:遍历每个基本块,转换指令 + + // 转换每个基本块的指令 for (const auto& bb : func.GetBlocks()) { MachineBasicBlock* mirBB = blockMap[bb.get()]; - if (!mirBB) { + if (!mirBB) throw std::runtime_error(FormatError("mir", "找不到基本块对应的 MIR 基本块")); - } - + for (const auto& inst : bb->GetInstructions()) { - LowerInstruction(*inst, *machine_func, slots, *mirBB, blockMap); + LowerInstruction(*inst, *machineFunc, ctx, *mirBB, blockMap); + } + } + + // 将 vreg 类型信息存入 MachineFunction(RA 阶段使用) + for (const auto& [vreg, type] : ctx.vreg_types) { + switch (type) { + case VRegType::kInt32: + machineFunc->SetVRegType(vreg, MachineFunction::VRegType::kInt32); + break; + case VRegType::kInt64: + machineFunc->SetVRegType(vreg, MachineFunction::VRegType::kInt64); + break; + case VRegType::kFloat32: + machineFunc->SetVRegType(vreg, MachineFunction::VRegType::kFloat32); + break; + } + } + + // 构建 CFG 边(解析 B/BCond 中的 Label 目标) + for (const auto& bb : func.GetBlocks()) { + MachineBasicBlock* mirBB = blockMap[bb.get()]; + auto& insts = mirBB->GetInstructions(); + if (insts.empty()) continue; + + const auto& last = insts.back(); + if (last.GetOpcode() == Opcode::B) { + for (const auto& op : last.GetOperands()) { + if (op.GetKind() == Operand::Kind::Label) { + MachineBasicBlock* target = machineFunc->GetBlockByName(op.GetLabel()); + if (target) { + mirBB->AddSuccessor(target); + target->AddPredecessor(mirBB); + } + } + } + } else if (last.GetOpcode() == Opcode::BCond) { + // BCond 之后如果有 B,则有两个后继 + // 遍历该块倒数第二条开始找目标 + for (const auto& op : last.GetOperands()) { + if (op.GetKind() == Operand::Kind::Label) { + MachineBasicBlock* target = machineFunc->GetBlockByName(op.GetLabel()); + if (target) { + mirBB->AddSuccessor(target); + target->AddPredecessor(mirBB); + } + } + } + // 查找倒数第二条 B 指令的目标 + if (insts.size() >= 2) { + const auto& prev = insts[insts.size() - 2]; + if (prev.GetOpcode() == Opcode::B) { + for (const auto& op : prev.GetOperands()) { + if (op.GetKind() == Operand::Kind::Label) { + MachineBasicBlock* target = machineFunc->GetBlockByName(op.GetLabel()); + if (target) { + mirBB->AddSuccessor(target); + target->AddPredecessor(mirBB); + } + } + } + } + } + } else if (last.GetOpcode() == Opcode::Ret) { + // Ret 无后继 + } else { + // 非终结指令:fall-through 到下一个基本块(如果有) + // 查找基本块列表中的下一个 + bool found = false; + for (size_t i = 0; i + 1 < func.GetBlocks().size(); ++i) { + if (func.GetBlocks()[i].get() == bb.get()) { + const auto* nextBB = func.GetBlocks()[i + 1].get(); + MachineBasicBlock* nextMIR = blockMap[nextBB]; + if (nextMIR) { + mirBB->AddSuccessor(nextMIR); + nextMIR->AddPredecessor(mirBB); + } + found = true; + break; + } + } } } - return machine_func; + + return machineFunc; } +} // namespace + +// ========== LowerToMIR:入口函数 ========== std::unique_ptr LowerToMIR(const ir::Module& module) { DefaultContext(); - - auto machine_module = std::make_unique(); + auto machineModule = std::make_unique(); // 收集全局变量信息 for (const auto& global : module.GetGlobals()) { int size = GetTypeSize(global->GetType().get()); int alignment = global->GetType()->Alignment(); - bool is_zero_init = !global->HasInitializer(); - bool has_init_data = false; - uint64_t init_data = 0; + bool isZeroInit = !global->HasInitializer(); + bool hasInitData = false; + uint64_t initData = 0; - if (!is_zero_init) { + if (!isZeroInit) { const auto& init = global->GetInitializer(); - // 简单处理:只支持单个元素的标量初始化(float 或 int) if (init.size() == 1) { if (auto* cf = dynamic_cast(init[0])) { float fval = cf->GetValue(); uint32_t bits; memcpy(&bits, &fval, sizeof(bits)); - init_data = bits; - has_init_data = true; + initData = bits; + hasInitData = true; } else if (auto* ci = dynamic_cast(init[0])) { - init_data = static_cast(ci->GetValue()); - has_init_data = true; + initData = static_cast(ci->GetValue()); + hasInitData = true; } } } - machine_module->AddGlobal(global->GetName(), size, alignment, - is_zero_init, has_init_data, init_data); + machineModule->AddGlobal(global->GetName(), size, alignment, + isZeroInit, hasInitData, initData); } - std::vector globals; - // 处理全局变量 - for (const auto& global : module.GetGlobals()) { - // 为全局变量在数据段分配空间 - // 这里需要扩展 MachineModule 来支持全局变量 - DEBUG_MSG("Global variable: " << global->GetName()); - globals.push_back(global.get()); - } - - // 遍历模块中的所有函数 + // 转换所有函数 for (const auto& func : module.GetFunctions()) { try { - auto machine_func = LowerFunction(*func); - machine_module->AddFunction(std::move(machine_func)); + auto machineFunc = LowerFunction(*func); + machineModule->AddFunction(std::move(machineFunc)); } catch (const std::runtime_error& e) { - // 记录错误但继续处理其他函数 - throw std::runtime_error(FormatError("mir", "转换函数失败: " + func->GetName() + " - " + e.what())); + throw std::runtime_error( + FormatError("mir", "转换函数失败: " + func->GetName() + " - " + e.what())); } } - - if (machine_module->GetFunctions().empty()) { + + if (machineModule->GetFunctions().empty()) { throw std::runtime_error(FormatError("mir", "模块中没有成功转换的函数")); } - - return machine_module; + + return machineModule; } -} // namespace mir \ No newline at end of file +} // namespace mir diff --git a/src/mir/MIRFunction.cpp b/src/mir/MIRFunction.cpp index 334f8cc..1d9aa94 100644 --- a/src/mir/MIRFunction.cpp +++ b/src/mir/MIRFunction.cpp @@ -30,4 +30,14 @@ const FrameSlot& MachineFunction::GetFrameSlot(int index) const { return frame_slots_[index]; } +int MachineFunction::CreateSpillSlot(int size) { + int index = CreateFrameIndex(size); + spill_slot_indices_.insert(index); + return index; +} + +bool MachineFunction::IsSpillSlot(int index) const { + return spill_slot_indices_.count(index) > 0; +} + } // namespace mir diff --git a/src/mir/MIRInstr.cpp b/src/mir/MIRInstr.cpp index 1959b5c..9f2cf91 100644 --- a/src/mir/MIRInstr.cpp +++ b/src/mir/MIRInstr.cpp @@ -9,6 +9,8 @@ Operand::Operand(Kind kind, PhysReg reg, int imm, CondCode cc, const std::string Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0, CondCode::EQ, ""); } +Operand Operand::VReg(int id) { return Operand(Kind::VReg, PhysReg::W0, id, CondCode::EQ, ""); } + Operand Operand::Imm(int value) { return Operand(Kind::Imm, PhysReg::W0, value, CondCode::EQ, ""); } diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index 19f6f51..a4b842d 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -1,81 +1,488 @@ #include "mir/MIR.h" +#include +#include +#include #include +#include +#include #include "utils/Log.h" namespace mir { namespace { -bool IsAllowedReg(PhysReg reg) { +// ========== VReg 类型 ========== +enum class VRegClass { kInt32, kInt64, kFloat32 }; + +// ========== 活跃区间 ========== +struct LiveInterval { + int vreg; + int start; + int end; + VRegClass reg_class; + + LiveInterval(int v, int s, int e, VRegClass rc) + : vreg(v), start(s), end(e), reg_class(rc) {} +}; + +// ========== 可分配物理寄存器池 ========== +// 整数 32位:callee-saved 优先(W19-W28),然后 caller-saved(W8-W13) +// W14, W15 保留为 spill scratch +const PhysReg kGPR32Pool[] = { + PhysReg::W19, PhysReg::W20, PhysReg::W21, PhysReg::W22, + PhysReg::W23, PhysReg::W24, PhysReg::W25, PhysReg::W26, + PhysReg::W27, PhysReg::W28, + PhysReg::W8, PhysReg::W9, PhysReg::W10, PhysReg::W11, + PhysReg::W12, PhysReg::W13, +}; +constexpr int kNumGPR32 = sizeof(kGPR32Pool) / sizeof(kGPR32Pool[0]); + +// 整数 64位:callee-saved 优先(X19-X28),然后 caller-saved(X8-X13) +// X14, X15 保留为 spill scratch +const PhysReg kGPR64Pool[] = { + PhysReg::X19, PhysReg::X20, PhysReg::X21, PhysReg::X22, + PhysReg::X23, PhysReg::X24, PhysReg::X25, PhysReg::X26, + PhysReg::X27, PhysReg::X28, + PhysReg::X8, PhysReg::X9, PhysReg::X10, PhysReg::X11, + PhysReg::X12, PhysReg::X13, +}; +constexpr int kNumGPR64 = sizeof(kGPR64Pool) / sizeof(kGPR64Pool[0]); + +// 浮点 32位: S8-S13 (callee-saved). S14-S15 保留为 spill scratch +// S16+ 是 caller-saved,不能放入通用池(外部调用会隐式 clobber) +const PhysReg kFPR32Pool[] = { + PhysReg::S8, PhysReg::S9, PhysReg::S10, PhysReg::S11, + PhysReg::S12, PhysReg::S13, +}; +constexpr int kNumFPR32 = sizeof(kFPR32Pool) / sizeof(kFPR32Pool[0]); + +// Spill scratch registers (每个类型 2 个,避免多 spilled-vreg 冲突) +const PhysReg kSpillScratchInt32[] = { PhysReg::W15, PhysReg::W14 }; +const PhysReg kSpillScratchInt64[] = { PhysReg::X15, PhysReg::X14 }; +const PhysReg kSpillScratchFloat[] = { PhysReg::S15, PhysReg::S14 }; + +PhysReg GetSpillScratch(VRegClass rc, int idx) { + switch (rc) { + case VRegClass::kInt32: return kSpillScratchInt32[idx % 2]; + case VRegClass::kInt64: return kSpillScratchInt64[idx % 2]; + case VRegClass::kFloat32: return kSpillScratchFloat[idx % 2]; + } + return kSpillScratchInt32[0]; +} + +// 判断是否为 callee-saved 寄存器 +bool IsCalleeSaved(PhysReg reg) { switch (reg) { - case PhysReg::W0: - case PhysReg::W8: - case PhysReg::W9: - case PhysReg::X29: //FP = X29 帧指针 - case PhysReg::X30: //LR = X30 链接寄存器 - case PhysReg::SP: + case PhysReg::W19: case PhysReg::W20: case PhysReg::W21: case PhysReg::W22: + case PhysReg::W23: case PhysReg::W24: case PhysReg::W25: case PhysReg::W26: + case PhysReg::W27: case PhysReg::W28: + case PhysReg::X19: case PhysReg::X20: case PhysReg::X21: case PhysReg::X22: + case PhysReg::X23: case PhysReg::X24: case PhysReg::X25: case PhysReg::X26: + case PhysReg::X27: case PhysReg::X28: + case PhysReg::S8: case PhysReg::S9: case PhysReg::S10: case PhysReg::S11: + case PhysReg::S12: case PhysReg::S13: case PhysReg::S14: case PhysReg::S15: + case PhysReg::S16: case PhysReg::S17: case PhysReg::S18: case PhysReg::S19: + case PhysReg::S20: case PhysReg::S21: case PhysReg::S22: case PhysReg::S23: + case PhysReg::S24: case PhysReg::S25: case PhysReg::S26: case PhysReg::S27: + case PhysReg::S28: case PhysReg::S29: case PhysReg::S30: case PhysReg::S31: return true; + default: return false; } - return false; } -} // namespace +// 获取寄存器编号(用于 Wn/Xn 互斥检查) +int GetRegIndex(PhysReg reg) { + if (reg >= PhysReg::W0 && reg <= PhysReg::W30) + return static_cast(reg) - static_cast(PhysReg::W0); + if (reg >= PhysReg::X0 && reg <= PhysReg::X30) + return static_cast(reg) - static_cast(PhysReg::X0); + if (reg >= PhysReg::S0 && reg <= PhysReg::S31) + return 32 + static_cast(reg) - static_cast(PhysReg::S0); + return -1; +} -//void RunRegAlloc(MachineFunction& function) { -// for (const auto& inst : function.GetEntry().GetInstructions()) { -// for (const auto& operand : inst.GetOperands()) { -// if (operand.GetKind() == Operand::Kind::Reg && -// !IsAllowedReg(operand.GetReg())) { -// throw std::runtime_error(FormatError("mir", "寄存器分配失败")); -// } -// } -// } -//} - -// 单函数版本的寄存器分配(原有逻辑) -void RunRegAlloc(MachineFunction& function) { - // 当前仅执行最小一致性检查,不实现真实寄存器分配 - // Lab3 阶段保持栈槽模型,不需要真实寄存器分配 - - // 检查每个基本块中的指令 +// ========== 推断 vreg 类型(优先使用 Lowering 存储的类型) ========== +VRegClass InferVRegClass(int vreg, MachineFunction& function) { + if (function.HasVRegType(vreg)) { + switch (function.GetVRegType(vreg)) { + case MachineFunction::VRegType::kFloat32: return VRegClass::kFloat32; + case MachineFunction::VRegType::kInt64: return VRegClass::kInt64; + case MachineFunction::VRegType::kInt32: return VRegClass::kInt32; + } + } + return VRegClass::kInt32; // 默认(不应到达,因为 Lowering 覆盖所有 vreg) +} + +// ========== 指令编号 ========== +void NumberInstructions(MachineFunction& function, + std::unordered_map& instrToIdx, + std::vector& idxToInstr, + std::map& blockBoundary) { + int idx = 0; for (auto& bb : function.GetBasicBlocks()) { - for (auto& instr : bb->GetInstructions()) { - // 检查指令的操作数是否有效 - for (const auto& operand : instr.GetOperands()) { - switch (operand.GetKind()) { - case Operand::Kind::Reg: - // 寄存器操作数:检查是否在允许的范围内 - // 当前使用固定寄存器 w0, w8, w9, s0, s1 等 - break; - case Operand::Kind::FrameIndex: - // 栈槽索引:检查是否有效 - if (operand.GetFrameIndex() < 0 || - operand.GetFrameIndex() >= static_cast(function.GetFrameSlots().size())) { - throw std::runtime_error( - FormatError("regalloc", "无效的栈槽索引: " + - std::to_string(operand.GetFrameIndex()))); + blockBoundary[idx] = bb.get(); + for (auto& inst : bb->GetInstructions()) { + instrToIdx[&inst] = idx; + idxToInstr.push_back(&inst); + ++idx; + } + } +} + +// ========== 构建活跃区间 ========== +std::vector ComputeLiveIntervals(MachineFunction& function) { + const auto& blocks = function.GetBasicBlocks(); + if (blocks.empty()) return {}; + + // 编号 + std::unordered_map instrToIdx; + std::vector idxToInstr; + std::map blockBoundary; + NumberInstructions(function, instrToIdx, idxToInstr, blockBoundary); + int total = static_cast(idxToInstr.size()); + + // 收集所有 vreg + std::set allVRegs; + for (auto* inst : idxToInstr) { + for (int d : inst->GetDefs()) allVRegs.insert(d); + for (int u : inst->GetUses()) allVRegs.insert(u); + } + + // 每个 vreg 的活跃位置集合 + std::unordered_map> vregPositions; + + // 基本块的 use/def 集合 + struct BlockInfo { + std::set use; + std::set def; + int startIdx; + int endIdx; + }; + std::unordered_map blockInfo; + + for (const auto& bb : blocks) { + auto& info = blockInfo[bb.get()]; + // 找到块的首尾指令序号 + auto& insts = bb->GetInstructions(); + if (!insts.empty()) { + info.startIdx = instrToIdx[&insts.front()]; + info.endIdx = instrToIdx[&insts.back()] + 1; + } else { + info.startIdx = 0; + info.endIdx = 0; + } + + for (auto& inst : insts) { + int pos = instrToIdx[&inst]; + for (int def : inst.GetDefs()) { + info.def.insert(def); + vregPositions[def].insert(pos); + } + for (int use : inst.GetUses()) { + if (info.def.count(use) == 0) { + info.use.insert(use); + } + vregPositions[use].insert(pos); + } + } + } + + // 数据流分析: liveIn/liveOut + std::unordered_map> liveIn, liveOut; + bool changed = true; + while (changed) { + changed = false; + for (auto it = blocks.rbegin(); it != blocks.rend(); ++it) { + MachineBasicBlock* bb = it->get(); + auto& info = blockInfo[bb]; + + // liveOut = union of successors' liveIn + std::set newLiveOut; + for (auto* succ : bb->GetSuccessors()) { + for (int v : liveIn[succ]) newLiveOut.insert(v); + } + if (newLiveOut != liveOut[bb]) { + liveOut[bb] = newLiveOut; + changed = true; + } + + // liveIn = use ∪ (liveOut - def) + std::set newLiveIn = info.use; + for (int v : liveOut[bb]) { + if (info.def.count(v) == 0) newLiveIn.insert(v); + } + if (newLiveIn != liveIn[bb]) { + liveIn[bb] = newLiveIn; + changed = true; + } + } + } + + // 生成 LiveInterval + // 注意:不将 liveIn 扩展到整个基本块的每个位置,因为线性扫描只需要 + // [start, end] 区间。扩展会导致过长的活跃区间,造成不必要的 spill。 + std::vector intervals; + for (int vreg : allVRegs) { + auto it = vregPositions.find(vreg); + if (it == vregPositions.end() || it->second.empty()) continue; + int start = *it->second.begin(); + int end = *it->second.rbegin(); + VRegClass rc = InferVRegClass(vreg, function); + intervals.emplace_back(vreg, start, end, rc); + } + + std::sort(intervals.begin(), intervals.end(), + [](const LiveInterval& a, const LiveInterval& b) { + return a.start < b.start; + }); + return intervals; +} + +// ========== 寄存器池选择 ========== +const PhysReg* GetRegPool(VRegClass rc, int& count) { + switch (rc) { + case VRegClass::kInt32: count = kNumGPR32; return kGPR32Pool; + case VRegClass::kInt64: count = kNumGPR64; return kGPR64Pool; + case VRegClass::kFloat32: count = kNumFPR32; return kFPR32Pool; + } + count = 0; + return nullptr; +} + +// ========== 活跃区间比较(按 end 排序,用于 active 集合) ========== +struct ByEnd { + bool operator()(const LiveInterval* a, const LiveInterval* b) const { + if (a->end != b->end) return a->end < b->end; + return a->vreg < b->vreg; // 打破平局:相同 end 按 vreg 区分 + } +}; + +// ========== 线性扫描寄存器分配(单函数) ========== +void RunRegAllocFunc(MachineFunction& function) { + auto intervals = ComputeLiveIntervals(function); + if (intervals.empty()) return; + + // vreg → 分配的物理寄存器 + std::unordered_map vregToPhys; + + // 活跃区间集合(按 end 排序) + std::set active; + + // Spill 槽:vreg → FrameIndex + std::unordered_map spillSlots; + + // 寄存器占用跟踪 + std::set occupiedRegIndices; + + // 每个寄存器池的空闲/占用状态 + auto allocReg = [&](const LiveInterval& interval) -> PhysReg { + int poolSize = 0; + const PhysReg* pool = GetRegPool(interval.reg_class, poolSize); + for (int i = 0; i < poolSize; ++i) { + int idx = GetRegIndex(pool[i]); + if (occupiedRegIndices.count(idx) == 0) { + occupiedRegIndices.insert(idx); + if (IsCalleeSaved(pool[i])) { + function.MarkCalleeSaved(pool[i]); + } + return pool[i]; + } + } + // 无法分配 + return PhysReg::W0; // will trigger spill logic + }; + + auto freeReg = [&](PhysReg reg) { + occupiedRegIndices.erase(GetRegIndex(reg)); + }; + + auto isFreeReg = [&](VRegClass rc) -> bool { + int poolSize = 0; + const PhysReg* pool = GetRegPool(rc, poolSize); + for (int i = 0; i < poolSize; ++i) { + if (occupiedRegIndices.count(GetRegIndex(pool[i])) == 0) + return true; + } + return false; + }; + + // 线性扫描 + for (auto& interval : intervals) { + // 1. Expire old intervals + std::vector toRemove; + for (auto* act : active) { + if (act->end < interval.start) { + toRemove.push_back(act); + auto it = vregToPhys.find(act->vreg); + if (it != vregToPhys.end()) { + freeReg(it->second); + } + } + } + for (auto* act : toRemove) { + active.erase(act); + } + + // 2. 尝试分配 + if (isFreeReg(interval.reg_class)) { + PhysReg reg = allocReg(interval); + vregToPhys[interval.vreg] = reg; + active.insert(&interval); + } else { + // 3. Spill: 选择 active 中最晚结束的区间 + if (active.empty()) { + // 所有寄存器都被占用(Wn/Xn 别名冲突等边缘情况) + // 直接 spill 当前 interval + int slotSize = (interval.reg_class == VRegClass::kInt64) ? 8 : 4; + int slot = function.CreateSpillSlot(slotSize); + spillSlots[interval.vreg] = slot; + continue; + } + const LiveInterval* spillCand = *active.rbegin(); // 最晚结束 + if (spillCand->end > interval.end) { + // Spill spillCand + PhysReg reg = vregToPhys[spillCand->vreg]; + vregToPhys.erase(spillCand->vreg); + freeReg(reg); + active.erase(spillCand); + + // 为其分配 spill slot + int slotSize = (spillCand->reg_class == VRegClass::kInt64) ? 8 : 4; + int slot = function.CreateSpillSlot(slotSize); + spillSlots[spillCand->vreg] = slot; + + // 分配当前 interval + PhysReg newReg = allocReg(interval); + vregToPhys[interval.vreg] = newReg; + active.insert(&interval); + } else { + // Spill 当前 interval + int slotSize = (interval.reg_class == VRegClass::kInt64) ? 8 : 4; + int slot = function.CreateSpillSlot(slotSize); + spillSlots[interval.vreg] = slot; + // 不分配物理寄存器 + } + } + } + + // ========== 重写指令:VReg → PhysReg + spill/reload ========== + for (auto& bb : function.GetBasicBlocks()) { + std::vector newInsts; + auto& insts = bb->GetInstructions(); + + for (auto& inst : insts) { + auto& ops = inst.GetOperands(); + std::vector& defs = inst.GetDefs(); + std::vector& uses = inst.GetUses(); + + // 辅助:获取 spilled vreg 的类型 + auto getSpillRC = [&](int vreg) -> VRegClass { + for (auto& iv : intervals) { + if (iv.vreg == vreg) return iv.reg_class; + } + return VRegClass::kInt32; + }; + + // 收集此指令中需要 reload 的 spilled vreg(去重) + std::vector spilledUses; + { + std::set seen; + for (int vreg : uses) { + if (spillSlots.count(vreg) && seen.insert(vreg).second) { + spilledUses.push_back(vreg); + } + } + } + + // === 插入 use 前的 reload(每个 spilled vreg 用不同 scratch) === + for (size_t si = 0; si < spilledUses.size(); ++si) { + int vreg = spilledUses[si]; + int slot = spillSlots[vreg]; + PhysReg loadReg; + auto it = vregToPhys.find(vreg); + if (it != vregToPhys.end()) { + loadReg = it->second; + } else { + loadReg = GetSpillScratch(getSpillRC(vreg), static_cast(si)); + } + newInsts.emplace_back(Opcode::LoadStack, + std::vector{Operand::Reg(loadReg), Operand::FrameIndex(slot)}); + } + + // === 替换 VReg 操作数为 PhysReg === + // 跟踪每条指令中 spilled vreg 的 scratch 索引 + int spillUseIdx = 0; + for (auto& op : ops) { + if (op.GetKind() == Operand::Kind::VReg) { + int vreg = op.GetVReg(); + auto it = vregToPhys.find(vreg); + if (it != vregToPhys.end()) { + op = Operand::Reg(it->second); + } else { + // spilled 或未分配:用 spill scratch + int idx = 0; + if (spillSlots.count(vreg)) { + for (size_t si = 0; si < spilledUses.size(); ++si) { + if (spilledUses[si] == vreg) { idx = static_cast(si); break; } + } + } else { + // 防御:vreg 未在 vregToPhys 或 spillSlots 中,创建临时 spill slot + VRegClass rc = getSpillRC(vreg); + int slotSize = (rc == VRegClass::kInt64) ? 8 : 4; + int slot = function.CreateSpillSlot(slotSize); + spillSlots[vreg] = slot; } - break; - case Operand::Kind::Imm: - case Operand::Kind::Cond: - case Operand::Kind::Label: - // 立即数、条件码、标签不需要检查 - break; + op = Operand::Reg(GetSpillScratch(getSpillRC(vreg), idx)); + spillUseIdx++; + } } } + + newInsts.push_back(inst); + + // === 插入 def 后的 store(用于 spilled vreg) === + // 收集此指令中 spilled def vreg(去重) + std::vector spilledDefs; + { + std::set seen; + for (int vreg : defs) { + if (spillSlots.count(vreg) && seen.insert(vreg).second) { + spilledDefs.push_back(vreg); + } + } + } + for (size_t si = 0; si < spilledDefs.size(); ++si) { + int vreg = spilledDefs[si]; + int slot = spillSlots[vreg]; + PhysReg storeReg; + auto it = vregToPhys.find(vreg); + if (it != vregToPhys.end()) { + storeReg = it->second; + } else { + storeReg = GetSpillScratch(getSpillRC(vreg), static_cast(si)); + } + newInsts.emplace_back(Opcode::StoreStack, + std::vector{Operand::Reg(storeReg), Operand::FrameIndex(slot)}); + } + } + insts = std::move(newInsts); + } + + // 清除所有指令的 def/use(RA 完成后不再需要) + for (auto& bb : function.GetBasicBlocks()) { + for (auto& inst : bb->GetInstructions()) { + inst.GetDefs().clear(); + inst.GetUses().clear(); } } - - // 注意:Lab3 阶段不实现真实寄存器分配 - // 所有值仍然使用栈槽模型,寄存器仅作为临时计算使用 } -// 模块版本的寄存器分配 +} // namespace + void RunRegAlloc(MachineModule& module) { - // 对模块中的每个函数执行寄存器分配 for (auto& func : module.GetFunctions()) { - RunRegAlloc(*func); + RunRegAllocFunc(*func); } } diff --git a/src/mir/Register.cpp b/src/mir/Register.cpp index 772322e..2605a2a 100644 --- a/src/mir/Register.cpp +++ b/src/mir/Register.cpp @@ -83,11 +83,35 @@ const char* PhysRegName(PhysReg reg) { case PhysReg::S5: return "s5"; case PhysReg::S6: return "s6"; case PhysReg::S7: return "s7"; - + case PhysReg::S8: return "s8"; + case PhysReg::S9: return "s9"; + case PhysReg::S10: return "s10"; + case PhysReg::S11: return "s11"; + case PhysReg::S12: return "s12"; + case PhysReg::S13: return "s13"; + case PhysReg::S14: return "s14"; + case PhysReg::S15: return "s15"; + case PhysReg::S16: return "s16"; + case PhysReg::S17: return "s17"; + case PhysReg::S18: return "s18"; + case PhysReg::S19: return "s19"; + case PhysReg::S20: return "s20"; + case PhysReg::S21: return "s21"; + case PhysReg::S22: return "s22"; + case PhysReg::S23: return "s23"; + case PhysReg::S24: return "s24"; + case PhysReg::S25: return "s25"; + case PhysReg::S26: return "s26"; + case PhysReg::S27: return "s27"; + case PhysReg::S28: return "s28"; + case PhysReg::S29: return "s29"; + case PhysReg::S30: return "s30"; + case PhysReg::S31: return "s31"; + // 特殊寄存器 case PhysReg::SP: return "sp"; case PhysReg::ZR: return "xzr"; - + default: return "unknown"; } throw std::runtime_error(FormatError("mir", "未知物理寄存器")); diff --git a/src/mir/passes/PassManager.cpp b/src/mir/passes/PassManager.cpp index c510460..b99836e 100644 --- a/src/mir/passes/PassManager.cpp +++ b/src/mir/passes/PassManager.cpp @@ -1,4 +1,16 @@ // MIR Pass 管理: -// - 组织后端 pass 的运行顺序(PreRA/PostRA/PEI 等阶段) -// - 统一运行 pass 与调试输出(按需要扩展) +// - 组织后端 pass 的运行顺序 +// - 统一运行 pass 与调试输出 +#include "mir/MIR.h" + +namespace mir { + +void RunPeephole(MachineModule& module); + +void RunMIRPasses(MachineModule& module) { + // Peephole:RA 后局部优化 + RunPeephole(module); +} + +} // namespace mir diff --git a/src/mir/passes/Peephole.cpp b/src/mir/passes/Peephole.cpp index c6d9ab7..feb961e 100644 --- a/src/mir/passes/Peephole.cpp +++ b/src/mir/passes/Peephole.cpp @@ -1,4 +1,197 @@ // 窥孔优化(Peephole): // - 删除冗余 move、合并常见指令模式 -// - 提升最终汇编质量(按实现范围裁剪) +// - 提升最终汇编质量 +#include "mir/MIR.h" + +#include +#include +#include +#include +#include + +#include "utils/Log.h" + +namespace mir { +namespace { + +// 检查指令是否有副作用(非纯计算) +bool HasSideEffects(const MachineInstr& inst) { + switch (inst.GetOpcode()) { + case Opcode::Call: + case Opcode::Ret: + case Opcode::B: + case Opcode::BCond: + case Opcode::StoreStack: + case Opcode::StoreStackPair: + case Opcode::Prologue: + case Opcode::Epilogue: + return true; + default: + return false; + } +} + +// 检查是否是纯 move 指令 +bool IsPureMove(const MachineInstr& inst) { + return inst.GetOpcode() == Opcode::MovReg; +} + +// 检查指令是否使用了某个物理寄存器 +bool InstUsesReg(const MachineInstr& inst, PhysReg reg) { + for (const auto& op : inst.GetOperands()) { + if (op.GetKind() == Operand::Kind::Reg && op.GetReg() == reg) + return true; + } + return false; +} + +// 检查指令是否定义了某个物理寄存器 +bool InstDefsReg(const MachineInstr& inst, PhysReg reg) { + // 大多数指令的 dest 是第一个操作数 + if (inst.GetOperands().empty()) return false; + const auto& dst = inst.GetOperands()[0]; + if (dst.GetKind() == Operand::Kind::Reg && dst.GetReg() == reg) + return true; + // StoreStackPair / LoadStackPair 有特殊格式 + return false; +} + +// 检查是否恒等操作 +bool IsIdentityOp(const MachineInstr& inst) { + if (inst.GetOperands().size() < 3) return false; + const auto& op2 = inst.GetOperands()[2]; + if (op2.GetKind() != Operand::Kind::Imm) return false; + if (op2.GetImm() != 0) return false; + + switch (inst.GetOpcode()) { + case Opcode::AddRI: + case Opcode::SubRI: + return true; + default: + return false; + } +} + +// 检查两个指令操作相同的栈偏移 +bool IsSameStackOffset(const MachineInstr& a, const MachineInstr& b) { + if (a.GetOperands().size() < 2 || b.GetOperands().size() < 2) return false; + const auto& aOff = a.GetOperands()[1]; + const auto& bOff = b.GetOperands()[1]; + if (aOff.GetKind() == Operand::Kind::FrameIndex && + bOff.GetKind() == Operand::Kind::FrameIndex) { + return aOff.GetFrameIndex() == bOff.GetFrameIndex(); + } + return false; +} + +// 单基本块窥孔优化(一次扫描) +int PeepholeBlock(MachineBasicBlock& bb) { + auto& insts = bb.GetInstructions(); + int changes = 0; + bool changed = true; + + // 迭代直到收敛 + while (changed) { + changed = false; + std::vector newInsts; + size_t n = insts.size(); + + for (size_t i = 0; i < n; ++i) { + MachineInstr& curr = insts[i]; + + // 跳过已标记删除的指令(通过空操作码) + if (curr.GetOpcode() == Opcode::Nop && curr.GetOperands().empty()) { + // 跳过(已经是 nop 但被标记删除) + if (curr.GetOperands().empty()) continue; + } + + // --- 规则1: 恒等操作消除 add/sub ..., #0 → mov --- + if (IsIdentityOp(curr)) { + const auto& dst = curr.GetOperands()[0]; + const auto& src = curr.GetOperands()[1]; + if (dst.GetKind() == Operand::Kind::Reg && src.GetKind() == Operand::Kind::Reg) { + MachineInstr mov(Opcode::MovReg, + std::vector{dst, Operand::Reg(src.GetReg())}); + newInsts.push_back(mov); + changed = true; + ++changes; + continue; + } + } + + // --- 规则2: mov wA, wA → 删除(自赋值) --- + if (IsPureMove(curr) && curr.GetOperands().size() >= 2) { + const auto& dst = curr.GetOperands()[0]; + const auto& src = curr.GetOperands()[1]; + if (dst.GetKind() == Operand::Kind::Reg && src.GetKind() == Operand::Kind::Reg && + dst.GetReg() == src.GetReg()) { + changed = true; + ++changes; + continue; // 删除 + } + } + + // --- 规则3: 冗余 mov → 删除第一条 --- + // mov wA, wB; mov wA, wC → 删除第一条(如果中间无其他使用 wA) + if (IsPureMove(curr) && i + 1 < n) { + const auto& dst0 = curr.GetOperands()[0]; + MachineInstr& next = insts[i + 1]; + if (IsPureMove(next) && next.GetOperands().size() >= 2) { + const auto& dst1 = next.GetOperands()[0]; + if (dst0.GetKind() == Operand::Kind::Reg && + dst1.GetKind() == Operand::Kind::Reg && + dst0.GetReg() == dst1.GetReg()) { + // 第一条 mov 的 dest 在第一条之后、第二条之前没有被使用 + // (两条相邻,中间无其他指令) + changed = true; + ++changes; + continue; // 删除第一条 + } + } + } + + // --- 规则4: Load after Store 消除 --- + // stur wA, [x29, #n]; ldur wB, [x29, #n] → mov wB, wA + if (curr.GetOpcode() == Opcode::StoreStack && i + 1 < n) { + MachineInstr& next = insts[i + 1]; + if (next.GetOpcode() == Opcode::LoadStack && + IsSameStackOffset(curr, next)) { + const auto& storeVal = curr.GetOperands()[0]; + const auto& loadDst = next.GetOperands()[0]; + if (storeVal.GetKind() == Operand::Kind::Reg && + loadDst.GetKind() == Operand::Kind::Reg) { + MachineInstr mov(Opcode::MovReg, + std::vector{loadDst, Operand::Reg(storeVal.GetReg())}); + newInsts.push_back(curr); // 保留 store + newInsts.push_back(mov); // mov 替换 load + ++i; // 跳过 next + changed = true; + ++changes; + continue; + } + } + } + + newInsts.push_back(curr); + } + + insts = std::move(newInsts); + } + + return changes; +} + +} // namespace + +// ========== RunPeephole(模块版本) ========== +void RunPeephole(MachineModule& module) { + int totalChanges = 0; + for (auto& func : module.GetFunctions()) { + for (auto& bb : func->GetBasicBlocks()) { + totalChanges += PeepholeBlock(*bb); + } + } +} + +} // namespace mir