#include "mir/MIR.h"

#include <cstdint>
#include <ostream>
#include <stdexcept>
#include <string>

#include "ir/IR.h"
#include "utils/Log.h"

namespace mir {
namespace {

// Resolves a FrameIndex operand to the frame slot it names in `function`.
// Throws std::runtime_error when `operand` is not a FrameIndex operand.
const FrameSlot& GetFrameSlot(const MachineFunction& function,
                              const Operand& operand) {
  if (operand.GetKind() != Operand::Kind::FrameIndex) {
    throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
  }
  return function.GetFrameSlot(operand.GetFrameIndex());
}

// Materialises the 32-bit pattern of `imm` into `reg`: a movz with the low
// half-word, followed by a movk for the high half-word only when it is
// non-zero.
void PrintMoveImm32(std::ostream& os, PhysReg reg, int imm) {
  const std::uint32_t bits = static_cast<std::uint32_t>(imm);
  const std::uint32_t lo = bits & 0xFFFFu;
  const std::uint32_t hi = (bits >> 16) & 0xFFFFu;
  os << " movz " << PhysRegName(reg) << ", #" << lo << "\n";
  if (hi != 0) {
    os << " movk " << PhysRegName(reg) << ", #" << hi << ", lsl #16\n";
  }
}

// Emits "<mnemonic> sp, sp, <size>". Sizes in 0..4095 are emitted as an
// immediate; anything else is first loaded into x10.
void PrintStackAdjust(std::ostream& os, const char* mnemonic, int size) {
  if (size >= 0 && size <= 4095) {
    os << " " << mnemonic << " sp, sp, #" << size << "\n";
    return;
  }
  PrintMoveImm32(os, PhysReg::X10, size);
  os << " " << mnemonic << " sp, sp, x10\n";
}

// Computes x29 + offset into `dst`. Offsets within -4095..4095 use a single
// add/sub immediate; larger magnitudes are loaded into x11 first.
void PrintAddrFromX29(std::ostream& os, PhysReg dst, int offset) {
  if (offset >= -4095 && offset <= 4095) {
    if (offset >= 0) {
      os << " add " << PhysRegName(dst) << ", x29, #" << offset << "\n";
    } else {
      os << " sub " << PhysRegName(dst) << ", x29, #" << (-offset) << "\n";
    }
    return;
  }
  // Use X11 rather than X10 so this does not clash with the array-index
  // offset.
  PrintMoveImm32(os, PhysReg::X11, offset < 0 ? -offset : offset);
  if (offset >= 0) {
    os << " add " << PhysRegName(dst) << ", x29, x11\n";
  } else {
    os << " sub " << PhysRegName(dst) << ", x29, x11\n";
  }
}

// Emits a frame-pointer-relative load or store of `reg`. `mnemonic` is
// "ldur" or "stur"; its first character selects load vs. store on the
// large-offset path.
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
                      int offset) {
  // AArch64 ldur/stur only accept immediate offsets in -256..255.
  if (offset >= -256 && offset <= 255) {
    os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset
       << "]\n";
    return;
  }
  // Large offset: build the address in x11 (X10 is used for array indexing)
  // and access through it with plain ldr/str.
  const bool is_load = (mnemonic[0] == 'l');  // ldur -> ldr
  const char* base_mnemonic = is_load ? "ldr" : "str";
  PrintAddrFromX29(os, PhysReg::X11, offset);
  os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x11]\n";
}

// Maps a comparison predicate to the AArch64 condition name used after an
// integer `cmp`. Values outside the enumerators fall back to "eq".
const char* CondSuffix(ir::CmpOp cmp_op) {
  switch (cmp_op) {
    case ir::CmpOp::Eq:
      return "eq";
    case ir::CmpOp::Ne:
      return "ne";
    case ir::CmpOp::Lt:
      return "lt";
    case ir::CmpOp::Le:
      return "le";
    case ir::CmpOp::Gt:
      return "gt";
    case ir::CmpOp::Ge:
      return "ge";
  }
  return "eq";
}

// Maps a comparison predicate to the condition name that tests it after an
// `fcmp`. After a floating-point compare, "lt" and "le" also hold when the
// operands are unordered (NaN), so ordered less-than / less-or-equal must
// use "mi" / "ls" instead. The remaining predicates share the integer names.
const char* FloatCondSuffix(ir::CmpOp cmp_op) {
  switch (cmp_op) {
    case ir::CmpOp::Lt:
      return "mi";
    case ir::CmpOp::Le:
      return "ls";
    default:
      return CondSuffix(cmp_op);
  }
}

// Emits "<mnemonic> <dst>, <src>".
void PrintRegReg(std::ostream& os, const char* mnemonic, PhysReg dst,
                 PhysReg src) {
  os << " " << mnemonic << " " << PhysRegName(dst) << ", " << PhysRegName(src)
     << "\n";
}

// Emits "<mnemonic> <dst>, <lhs>, <rhs>".
void PrintRegRegReg(std::ostream& os, const char* mnemonic, PhysReg dst,
                    PhysReg lhs, PhysReg rhs) {
  os << " " << mnemonic << " " << PhysRegName(dst) << ", " << PhysRegName(lhs)
     << ", " << PhysRegName(rhs) << "\n";
}

// Emits "<mnemonic> <dst>, <src>, #<imm>". The immediate is streamed with
// its own type so its textual form matches direct insertion into `os`.
template <typename Imm>
void PrintRegRegImm(std::ostream& os, const char* mnemonic, PhysReg dst,
                    PhysReg src, const Imm& imm) {
  os << " " << mnemonic << " " << PhysRegName(dst) << ", " << PhysRegName(src)
     << ", #" << imm << "\n";
}

// Emits "<mnemonic> <reg>, [<base>]".
void PrintRegIndirect(std::ostream& os, const char* mnemonic, PhysReg reg,
                      PhysReg base) {
  os << " " << mnemonic << " " << PhysRegName(reg) << ", ["
     << PhysRegName(base) << "]\n";
}

// Loads the address of `symbol` into x9 (adrp + :lo12: add), then emits
// "<mnemonic> <reg>, [x9]".
void PrintGlobalAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
                       const std::string& symbol) {
  os << " adrp x9, " << symbol << "\n";
  os << " add x9, x9, :lo12:" << symbol << "\n";
  os << " " << mnemonic << " " << PhysRegName(reg) << ", [x9]\n";
}

// Emits "<mnemonic> <reg>, .<label>" (cbz / cbnz).
void PrintCompareBranch(std::ostream& os, const char* mnemonic, PhysReg reg,
                        const std::string& label) {
  os << " " << mnemonic << " " << PhysRegName(reg) << ", ." << label << "\n";
}

// Emits the .data section with one definition per global variable. Nothing
// is written when the module has no globals.
void PrintGlobalVars(const MachineModule& module, std::ostream& os) {
  if (module.GetGlobalVars().empty()) {
    return;
  }
  os << ".data\n";
  for (const auto& [name, init_val, count, is_float, init_elems] :
       module.GetGlobalVars()) {
    (void)is_float;
    os << ".global " << name << "\n";
    os << ".type " << name << ", %object\n";
    os << name << ":\n";
    if (count == 1) {
      // Scalar global.
      os << " .word " << init_val << "\n";
      continue;
    }
    // Array global: emit the explicit initialiser elements first (at most
    // `count` of them), then zero-fill whatever remains.
    int emitted = 0;
    for (int elem : init_elems) {
      if (emitted >= count) {
        break;
      }
      os << " .word " << elem << "\n";
      ++emitted;
    }
    if (emitted == 0) {
      os << " .zero " << (count * 4) << "\n";
    } else if (emitted < count) {
      os << " .zero " << ((count - emitted) * 4) << "\n";
    }
  }
  os << "\n";
}

}  // namespace

// Writes `module` to `os` as AArch64 assembly text: an optional .data
// section for the globals, then a .text section with every function.
// Throws std::runtime_error for a ModRR instruction or for a stack opcode
// whose slot operand is not a FrameIndex.
void PrintAsm(const MachineModule& module, std::ostream& os) {
  // Global variable definitions.
  PrintGlobalVars(module, os);

  os << ".text\n";
  for (const auto& func_ptr : module.GetFunctions()) {
    const auto& function = *func_ptr;
    os << ".global " << function.GetName() << "\n";
    os << ".type " << function.GetName() << ", %function\n";
    os << function.GetName() << ":\n";

    // Walk every basic block.
    for (const auto& bb_ptr : function.GetBlocks()) {
      const auto& bb = *bb_ptr;
      // Print the block label; the entry block needs none because the
      // function name already labels it.
      if (bb.GetName() != "entry") {
        os << "." << bb.GetName() << ":\n";
      }

      for (const auto& inst : bb.GetInstructions()) {
        const auto& ops = inst.GetOperands();
        switch (inst.GetOpcode()) {
          case Opcode::Prologue:
            os << " stp x29, x30, [sp, #-16]!\n";
            os << " mov x29, sp\n";
            if (function.GetFrameSize() > 0) {
              PrintStackAdjust(os, "sub", function.GetFrameSize());
            }
            break;
          case Opcode::Epilogue:
            if (function.GetFrameSize() > 0) {
              PrintStackAdjust(os, "add", function.GetFrameSize());
            }
            os << " ldp x29, x30, [sp], #16\n";
            break;
          case Opcode::MovImm:
            PrintMoveImm32(os, ops.at(0).GetReg(), ops.at(1).GetImm());
            break;
          case Opcode::MovReg:
            PrintRegReg(os, "mov", ops.at(0).GetReg(), ops.at(1).GetReg());
            break;
          case Opcode::FMovImm:
            // Generic float immediate: load the bit pattern into w10, then
            // move it bitwise into the destination register.
            PrintMoveImm32(os, PhysReg::W10, ops.at(1).GetImm());
            os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", w10\n";
            break;
          case Opcode::FMovReg:
            PrintRegReg(os, "fmov", ops.at(0).GetReg(), ops.at(1).GetReg());
            break;
          case Opcode::LoadStack: {
            const auto& slot = GetFrameSlot(function, ops.at(1));
            PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
            break;
          }
          case Opcode::StoreStack: {
            const auto& slot = GetFrameSlot(function, ops.at(1));
            PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
            break;
          }
          case Opcode::LoadStackOffset: {
            // ops: reg, frame_index, imm_offset
            const auto& slot = GetFrameSlot(function, ops.at(1));
            const int final_offset = slot.offset + ops.at(2).GetImm();
            PrintStackAccess(os, "ldur", ops.at(0).GetReg(), final_offset);
            break;
          }
          case Opcode::StoreStackOffset: {
            // ops: reg, frame_index, imm_offset
            const auto& slot = GetFrameSlot(function, ops.at(1));
            const int final_offset = slot.offset + ops.at(2).GetImm();
            PrintStackAccess(os, "stur", ops.at(0).GetReg(), final_offset);
            break;
          }
          case Opcode::LoadStackAddr: {
            // ops: xN, frame_index  ->  add xN, x29, #offset
            const auto& slot = GetFrameSlot(function, ops.at(1));
            PrintAddrFromX29(os, ops.at(0).GetReg(), slot.offset);
            break;
          }
          case Opcode::LoadIndirect:
            // ops: wN, xM  ->  ldr wN, [xM]
            PrintRegIndirect(os, "ldr", ops.at(0).GetReg(),
                             ops.at(1).GetReg());
            break;
          case Opcode::StoreIndirect:
            // ops: wN, xM  ->  str wN, [xM]
            PrintRegIndirect(os, "str", ops.at(0).GetReg(),
                             ops.at(1).GetReg());
            break;
          case Opcode::LoadGlobal:
            // adrp x9, sym ; add x9, x9, :lo12:sym ; ldr wN, [x9]
            PrintGlobalAccess(os, "ldr", ops.at(0).GetReg(),
                              ops.at(1).GetSymbol());
            break;
          case Opcode::StoreGlobal:
            // adrp x9, sym ; add x9, x9, :lo12:sym ; str wN, [x9]
            PrintGlobalAccess(os, "str", ops.at(0).GetReg(),
                              ops.at(1).GetSymbol());
            break;
          case Opcode::LoadGlobalAddr: {
            // adrp xN, sym ; add xN, xN, :lo12:sym
            const std::string& name = ops.at(1).GetSymbol();
            os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", " << name
               << "\n";
            os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
               << PhysRegName(ops.at(0).GetReg()) << ", :lo12:" << name
               << "\n";
            break;
          }
          case Opcode::AddRI:
            PrintRegRegImm(os, "add", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetImm());
            break;
          case Opcode::SubRI:
            PrintRegRegImm(os, "sub", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetImm());
            break;
          case Opcode::AddRR:
            PrintRegRegReg(os, "add", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetReg());
            break;
          case Opcode::SubRR:
            PrintRegRegReg(os, "sub", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetReg());
            break;
          case Opcode::MulRR:
            PrintRegRegReg(os, "mul", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetReg());
            break;
          case Opcode::DivRR:
            PrintRegRegReg(os, "sdiv", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetReg());
            break;
          case Opcode::FAddRR:
            PrintRegRegReg(os, "fadd", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetReg());
            break;
          case Opcode::FSubRR:
            PrintRegRegReg(os, "fsub", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetReg());
            break;
          case Opcode::FMulRR:
            PrintRegRegReg(os, "fmul", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetReg());
            break;
          case Opcode::FDivRR:
            PrintRegRegReg(os, "fdiv", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetReg());
            break;
          case Opcode::FSqrtRR:
            PrintRegReg(os, "fsqrt", ops.at(0).GetReg(), ops.at(1).GetReg());
            break;
          case Opcode::SIToFP:
            PrintRegReg(os, "scvtf", ops.at(0).GetReg(), ops.at(1).GetReg());
            break;
          case Opcode::FPToSI:
            PrintRegReg(os, "fcvtzs", ops.at(0).GetReg(), ops.at(1).GetReg());
            break;
          case Opcode::ModRR:
            // Should never reach the printer: Mod is expanded into
            // div + mul + sub during lowering.
            throw std::runtime_error(FormatError("mir", "ModRR 不应被打印"));
          case Opcode::LsrRI:
            PrintRegRegImm(os, "lsr", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetImm());
            break;
          case Opcode::LslRI:
            PrintRegRegImm(os, "lsl", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetImm());
            break;
          case Opcode::LslRR:
            PrintRegRegReg(os, "lsl", ops.at(0).GetReg(), ops.at(1).GetReg(),
                           ops.at(2).GetReg());
            break;
          case Opcode::CmpOnlyRR:
            PrintRegReg(os, "cmp", ops.at(0).GetReg(), ops.at(1).GetReg());
            break;
          case Opcode::FCmpOnlyRR:
            PrintRegReg(os, "fcmp", ops.at(0).GetReg(), ops.at(1).GetReg());
            break;
          case Opcode::CmpRR: {
            // ops: dst, lhs, rhs, cmpop(imm)
            const auto cmp_op = static_cast<ir::CmpOp>(ops.at(3).GetImm());
            PrintRegReg(os, "cmp", ops.at(1).GetReg(), ops.at(2).GetReg());
            os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
               << CondSuffix(cmp_op) << "\n";
            break;
          }
          case Opcode::FCmpRR: {
            // ops: dst(wN), lhs(sN), rhs(sN), cmpop(imm)
            // Uses the float-aware condition names so that Lt / Le yield 0
            // when either operand is NaN.
            const auto cmp_op = static_cast<ir::CmpOp>(ops.at(3).GetImm());
            PrintRegReg(os, "fcmp", ops.at(1).GetReg(), ops.at(2).GetReg());
            os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
               << FloatCondSuffix(cmp_op) << "\n";
            break;
          }
          case Opcode::Bl:
            os << " bl " << ops.at(0).GetSymbol() << "\n";
            break;
          case Opcode::B:
            os << " b ." << ops.at(0).GetSymbol() << "\n";
            break;
          case Opcode::Cbnz:
            PrintCompareBranch(os, "cbnz", ops.at(0).GetReg(),
                               ops.at(1).GetSymbol());
            break;
          case Opcode::Cbz:
            PrintCompareBranch(os, "cbz", ops.at(0).GetReg(),
                               ops.at(1).GetSymbol());
            break;
          case Opcode::Bcond:
            // ops: symbol, cmpop(imm)
            // NOTE(review): this always uses the integer condition names.
            // When the flags come from FCmpOnlyRR, lt / le are also taken
            // for unordered (NaN) operands; whether that is intended depends
            // on how lowering chooses the predicate - confirm.
            os << " b."
               << CondSuffix(static_cast<ir::CmpOp>(ops.at(1).GetImm()))
               << " ." << ops.at(0).GetSymbol() << "\n";
            break;
          case Opcode::Ret:
            os << " ret\n";
            break;
        }
      }
    }

    os << ".size " << function.GetName() << ", .-" << function.GetName()
       << "\n\n";
  }
}

}  // namespace mir