#include "mir/MIR.h"

#include <cstddef>
#include <ostream>
#include <stdexcept>
#include <string>

#include "utils/Log.h"

namespace mir {
namespace {

// Strips every leading '@' or '%' sigil from `name` and returns the rest.
// A name made only of sigils (or an empty name) yields an empty string.
std::string NormalizeAsmSymbol(std::string name) {
  // find_first_not_of returns npos when nothing but sigils is present, and
  // erase(0, npos) then clears the whole string.
  name.erase(0, name.find_first_not_of("@%"));
  return name;
}

// Maps an opcode to the mnemonic used by the generic fallback printer in
// PrintInstr. Opcodes that PrintInstr expands by hand (and unknown opcodes)
// map to the empty string.
const char *OpcodeToAsm(Opcode opcode) {
  switch (opcode) {
    case Opcode::Prologue:
    case Opcode::Epilogue:
    case Opcode::LoadStackAddr:
    case Opcode::LoadGlobal:
    case Opcode::StoreGlobal:
    case Opcode::LoadGlobalAddr:
    case Opcode::LoadMem:
    case Opcode::StoreMem:
    case Opcode::Uxtw:
    case Opcode::Sxtw:
      return "";
    case Opcode::MovImm:
      return "mov";
    case Opcode::LoadStack:
      return "ldur";
    case Opcode::StoreStack:
      return "stur";
    case Opcode::AddRR:
      return "add";
    case Opcode::SubRR:
      return "sub";
    case Opcode::MulRR:
      return "mul";
    case Opcode::DivRR:
      return "sdiv";
    case Opcode::FAddRR:
      return "fadd";
    case Opcode::FSubRR:
      return "fsub";
    case Opcode::FMulRR:
      return "fmul";
    case Opcode::FDivRR:
      return "fdiv";
    case Opcode::ModRR:
      return "msub";
    case Opcode::AndRR:
      return "and";
    case Opcode::OrRR:
      return "orr";
    case Opcode::XorRR:
      return "eor";
    case Opcode::ShlRR:
      return "lsl";
    case Opcode::ShrRR:
      return "lsr";
    case Opcode::CmpRR:
      return "cmp";
    case Opcode::CmpImm:
      return "cmp";
    case Opcode::FCmpRR:
      return "fcmp";
    case Opcode::CSet:
      return "cset";
    case Opcode::Csneg:
      return "csneg";
    case Opcode::Scvtf:
      return "scvtf";
    case Opcode::FCvtzs:
      return "fcvtzs";
    case Opcode::FMovWS:
      return "fmov";
    case Opcode::Br:
      return "b";
    case Opcode::CondBr:
      return "b";
    case Opcode::Call:
      return "bl";
    case Opcode::Ret:
      return "ret";
    case Opcode::LoadAddr:
      return "adrp";
    case Opcode::MovReg:
      return "mov";
    default:
      return "";
  }
}

// Returns the condition-code suffix for `cond`, or "" for an unknown value.
const char *CondCodeToAsm(CondCode cond) {
  switch (cond) {
    case CondCode::EQ:
      return "eq";
    case CondCode::NE:
      return "ne";
    case CondCode::LT:
      return "lt";
    case CondCode::LE:
      return "le";
    case CondCode::GT:
      return "gt";
    case CondCode::GE:
      return "ge";
    default:
      return "";
  }
}

// True when `reg` lies in the X0..X30 enumerator range.
bool IsXReg(PhysReg reg) {
  return reg >= PhysReg::X0 && reg <= PhysReg::X30;
}

// True when `reg` lies in the W0..W30 enumerator range.
bool IsWReg(PhysReg reg) {
  return reg >= PhysReg::W0 && reg <= PhysReg::W30;
}

// True when `reg` lies in the S0..S31 enumerator range.
bool IsSReg(PhysReg reg) {
  return reg >= PhysReg::S0 && reg <= PhysReg::S31;
}

// Picks a W register to hold the quotient while expanding ModRR into
// sdiv + msub. The first candidate that is not one of the three instruction
// operands wins.
PhysReg PickModQuotientReg(const Operand &dst, const Operand &lhs,
                           const Operand &rhs) {
  const PhysReg candidates[] = {PhysReg::W14, PhysReg::W12, PhysReg::W15,
                                PhysReg::W11, PhysReg::W10, PhysReg::W9,
                                PhysReg::W8};
  const Operand *const used[] = {&dst, &lhs, &rhs};
  for (const PhysReg candidate : candidates) {
    bool taken = false;
    for (const Operand *operand : used) {
      // Every candidate is a W register, so an equality match already implies
      // the operand is a W register.
      if (operand->GetKind() == Operand::Kind::Reg &&
          operand->GetReg() == candidate) {
        taken = true;
        break;
      }
    }
    if (!taken) {
      return candidate;
    }
  }
  // Defensive only: three operands cannot occupy all seven candidates.
  return PhysReg::W14;
}

// Builds the assembly label of basic block `label_id` inside `function`:
// ".L.<function name without sigils>.<label_id>".
std::string BlockLabel(const MachineFunction &function, int label_id) {
  return ".L." + NormalizeAsmSymbol(function.GetName()) + "." +
         std::to_string(label_id);
}

// Writes the label of basic block `label_id` of `function` to `os`.
void PrintBlockLabelRef(const MachineFunction &function, int label_id,
                        std::ostream &os) {
  os << BlockLabel(function, label_id);
}

// Writes the textual form of a single operand to `os`.
void PrintOperand(const Operand &operand, std::ostream &os) {
  switch (operand.GetKind()) {
    case Operand::Kind::Reg:
      os << PhysRegName(operand.GetReg());
      break;
    case Operand::Kind::VReg:
      os << "%vreg" << operand.GetVRegId();
      break;
    case Operand::Kind::Imm:
      os << "#" << operand.GetImm();
      break;
    case Operand::Kind::FrameIndex:
      // NOTE(review): a frame index prints as nothing here; callers are
      // expected to resolve it to an offset first (see ResolveFrameOffset).
      os << "";
      break;
    case Operand::Kind::Label:
      os << ".L" << operand.GetLabel();
      break;
    case Operand::Kind::Symbol:
      os << NormalizeAsmSymbol(operand.GetSymbol());
      break;
  }
}

// Materialises the constant `value` in `target`.
//
// Values in [-32768, 65535] are emitted as a single `mov`. Anything else is
// built 16 bits at a time: `movz` for the first non-zero halfword, `movk` for
// each further non-zero halfword. Zero halfwords are skipped because `movz`
// already clears them.
void EmitLargeImmediate(PhysReg target, int value, std::ostream &os) {
  if (value >= -32768 && value <= 65535) {
    os << " mov " << PhysRegName(target) << ", #" << value << "\n";
    return;
  }

  // For an X register the value is sign-extended to 64 bits so the result
  // agrees with what the single-`mov` path above produces for negative
  // numbers; other registers only receive the low 32 bits.
  const bool is_64bit = IsXReg(target);
  const unsigned long long bits =
      is_64bit
          ? static_cast<unsigned long long>(static_cast<long long>(value))
          : static_cast<unsigned long long>(static_cast<unsigned int>(value));
  const int width = is_64bit ? 64 : 32;

  // `value` is outside the fast-path range, so at least one halfword is
  // non-zero and at least one instruction is always emitted below.
  const char *mnemonic = " movz ";
  for (int shift = 0; shift < width; shift += 16) {
    const unsigned int part =
        static_cast<unsigned int>((bits >> shift) & 0xFFFFu);
    if (part == 0) {
      continue;
    }
    os << mnemonic << PhysRegName(target) << ", #" << part;
    if (shift > 0) {
      os << ", lsl #" << shift;
    }
    os << "\n";
    mnemonic = " movk ";
  }
}

// Emits `<op> sp, sp, #imm` instructions that together adjust sp by `amount`
// bytes, at most 4095 bytes per instruction. Emits nothing when amount <= 0.
void EmitStackAdjust(const char *op, int amount, std::ostream &os) {
  while (amount > 0) {
    const int chunk = amount > 4095 ? 4095 : amount;
    os << " " << op << " sp, sp, #" << chunk << "\n";
    amount -= chunk;
  }
}

PhysReg PrinterScratchXReg() {
  // Expanding pseudo-instructions during assembly printing needs a scratch
  // register that is reserved globally.
  // x14/x15 must not be reused here, otherwise the array/global base
  // addresses that lowering has already placed in them would be clobbered.
  // Lowering currently does not use x13, so x13 is permanently reserved for
  // the printing stage.
  return PhysReg::X13;
}

// Emits code that computes `base_reg + offset` into `target_xreg`: a `mov`
// followed by as many `add`/`sub` immediates (at most 4095 each) as needed.
void EmitAddressFromBase(PhysReg target_xreg, PhysReg base_reg, int offset,
                         std::ostream &os) {
  os << " mov " << PhysRegName(target_xreg) << ", " << PhysRegName(base_reg)
     << "\n";
  while (offset > 0) {
    const int chunk = offset > 4095 ? 4095 : offset;
    os << " add " << PhysRegName(target_xreg) << ", "
       << PhysRegName(target_xreg) << ", #" << chunk << "\n";
    offset -= chunk;
  }
  while (offset < 0) {
    const int chunk = (-offset) > 4095 ? 4095 : (-offset);
    os << " sub " << PhysRegName(target_xreg) << ", "
       << PhysRegName(target_xreg) << ", #" << chunk << "\n";
    offset += chunk;
  }
}

// Emits a load (Opcode::LoadStack) or store (any other opcode) of `reg` at
// `[x29 + offset]`. Offsets in [-256, 255] use ldur/stur directly; larger
// offsets first compute the address into the printer scratch register.
void PrintStackAccess(Opcode opcode, const Operand &reg, int offset,
                      std::ostream &os) {
  const bool is_load = (opcode == Opcode::LoadStack);
  if (offset >= -256 && offset <= 255) {
    os << " " << (is_load ? "ldur" : "stur") << " ";
    PrintOperand(reg, os);
    os << ", [x29, #" << offset << "]\n";
    return;
  }

  const PhysReg scratch_xreg = PrinterScratchXReg();
  EmitAddressFromBase(scratch_xreg, PhysReg::X29, offset, os);
  os << " " << (is_load ? "ldr" : "str") << " ";
  PrintOperand(reg, os);
  os << ", [" << PhysRegName(scratch_xreg) << "]\n";
}

// Emits the adrp + add :lo12: pair that loads the address of `symbol` into
// `dst`.
void PrintGlobalAddr(const Operand &dst, const std::string &symbol,
                     std::ostream &os) {
  const std::string asm_symbol = NormalizeAsmSymbol(symbol);
  os << " adrp ";
  PrintOperand(dst, os);
  os << ", " << asm_symbol << "\n";
  os << " add ";
  PrintOperand(dst, os);
  os << ", ";
  PrintOperand(dst, os);
  os << ", :lo12:" << asm_symbol << "\n";
}

// Emits a load (Opcode::LoadGlobal) or store (any other opcode) of `reg`
// from/to the global `symbol`, using the printer scratch register for the
// page address.
void PrintGlobalAccess(Opcode opcode, const Operand &reg,
                       const std::string &symbol, std::ostream &os) {
  const std::string asm_symbol = NormalizeAsmSymbol(symbol);
  const PhysReg scratch_xreg = PrinterScratchXReg();
  os << " adrp " << PhysRegName(scratch_xreg) << ", " << asm_symbol << "\n";
  os << " " << (opcode == Opcode::LoadGlobal ? "ldr " : "str ");
  PrintOperand(reg, os);
  os << ", [" << PhysRegName(scratch_xreg) << ", #:lo12:" << asm_symbol
     << "]\n";
}

// Emits `ldr`/`str data_reg, [addr_reg]`; Opcode::LoadMem selects the load.
void PrintMemAccess(Opcode opcode, const Operand &data_reg,
                    const Operand &addr_reg, std::ostream &os) {
  os << " " << (opcode == Opcode::LoadMem ? "ldr " : "str ");
  PrintOperand(data_reg, os);
  os << ", [";
  PrintOperand(addr_reg, os);
  os << "]\n";
}

// Returns the frame offset of the slot referenced by `operand`.
// Throws std::runtime_error when `operand` is not a FrameIndex.
int ResolveFrameOffset(const MachineFunction &function,
                       const Operand &operand) {
  if (operand.GetKind() != Operand::Kind::FrameIndex) {
    throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
  }
  return function.GetFrameSlot(operand.GetFrameIndex()).offset;
}

// True when `operand` is a frame index whose slot is flagged is_stack_arg but
// not is_callee_stack_arg; such slots are addressed relative to sp.
bool IsStackArgSlot(const MachineFunction &function, const Operand &operand) {
  if (operand.GetKind() != Operand::Kind::FrameIndex) {
    return false;
  }
  const auto &slot = function.GetFrameSlot(operand.GetFrameIndex());
  return slot.is_stack_arg && !slot.is_callee_stack_arg;
}

// Emits a load (Opcode::LoadStack) or store (any other opcode) of `reg` at
// `[sp + offset]`. Non-negative offsets up to 16380 (W/S registers) or 32760
// (other registers) are used as an immediate; anything else goes through the
// printer scratch register.
void PrintStackArgAccess(Opcode opcode, const Operand &reg, int offset,
                         std::ostream &os) {
  const char *mem_op = (opcode == Opcode::LoadStack) ? "ldr" : "str";
  const bool is_32bit = IsWReg(reg.GetReg()) || IsSReg(reg.GetReg());
  const int max_imm = is_32bit ? 16380 : 32760;
  if (offset >= 0 && offset <= max_imm) {
    os << " " << mem_op << " ";
    PrintOperand(reg, os);
    os << ", [sp, #" << offset << "]\n";
    return;
  }

  const PhysReg scratch_xreg = PrinterScratchXReg();
  EmitAddressFromBase(scratch_xreg, PhysReg::SP, offset, os);
  os << " " << mem_op << " ";
  PrintOperand(reg, os);
  os << ", [" << PhysRegName(scratch_xreg) << "]\n";
}

// Emits one `<op> reg, [sp, #offset]` per callee-saved register, packing the
// slots upwards from sp: 8 bytes per X register, 4 bytes per S register.
// Registers of any other class are skipped. `op` is "str" in the prologue and
// "ldr" in the epilogue, so both walk the same layout.
template <typename RegList>
void EmitCalleeSavedAccess(const char *op, const RegList &regs,
                           std::ostream &os) {
  int slot_offset = 0;
  for (const auto reg : regs) {
    int slot_size = 0;
    if (IsXReg(reg)) {
      slot_size = 8;
    } else if (IsSReg(reg)) {
      slot_size = 4;
    } else {
      continue;
    }
    os << " " << op << " " << PhysRegName(reg) << ", [sp, #" << slot_offset
       << "]\n";
    slot_offset += slot_size;
  }
}

// Writes " <mnemonic> op0, op1, ..." for the first `count` operands, followed
// by a newline. The caller guarantees that `count` operands exist.
template <typename OperandList>
void PrintSimpleInstr(const char *mnemonic, const OperandList &operands,
                      std::size_t count, std::ostream &os) {
  os << " " << mnemonic;
  for (std::size_t i = 0; i < count; ++i) {
    os << (i == 0 ? " " : ", ");
    PrintOperand(operands[i], os);
  }
  os << "\n";
}

// Writes " <mnemonic> op0, op1, op2, <cond>" where the condition code is
// taken from the immediate in operands[3] (used by csel / csneg).
template <typename OperandList>
void PrintCondSelect(const char *mnemonic, const OperandList &operands,
                     std::ostream &os) {
  os << " " << mnemonic << " ";
  PrintOperand(operands[0], os);
  os << ", ";
  PrintOperand(operands[1], os);
  os << ", ";
  PrintOperand(operands[2], os);
  os << ", " << CondCodeToAsm(static_cast<CondCode>(operands[3].GetImm()))
     << "\n";
}

// Prints `operand`, except that a W register is printed as the register 31
// enumerators after it.
// NOTE(review): assumes PhysReg places each Xn exactly 31 enumerators after
// Wn, so that this yields the 64-bit view of the same register — confirm
// against the PhysReg definition.
void PrintOperandAsXReg(const Operand &operand, std::ostream &os) {
  if (operand.GetKind() == Operand::Kind::Reg && IsWReg(operand.GetReg())) {
    os << PhysRegName(
        static_cast<PhysReg>(static_cast<int>(operand.GetReg()) + 31));
  } else {
    PrintOperand(operand, os);
  }
}

// True when at least two operands exist and the first two have the given
// kinds.
template <typename OperandList>
bool StartsWithKinds(const OperandList &operands, Operand::Kind first,
                     Operand::Kind second) {
  return operands.size() >= 2 && operands[0].GetKind() == first &&
         operands[1].GetKind() == second;
}

// Prints the assembly for one machine instruction of `function`.
//
// Pseudo-instructions and instructions needing special operand handling are
// expanded case by case. Everything else falls through to a generic
// "mnemonic op0, op1, ..." form driven by OpcodeToAsm. Instructions whose
// operands do not match the expected shape are silently skipped.
void PrintInstr(const MachineFunction &function, const MachineInstr &instr,
                std::ostream &os) {
  const Opcode opcode = instr.GetOpcode();
  const auto &operands = instr.GetOperands();

  switch (opcode) {
    case Opcode::Prologue:
      os << " stp x29, x30, [sp, #-16]!\n";
      os << " mov x29, sp\n";
      if (function.GetFrameSize() > 0) {
        EmitStackAdjust("sub", function.GetFrameSize(), os);
      }
      EmitCalleeSavedAccess("str", function.GetCalleeSavedRegs(), os);
      return;

    case Opcode::Epilogue:
      // Mirror image of the prologue: restore registers, release the frame,
      // then pop the frame record and return.
      EmitCalleeSavedAccess("ldr", function.GetCalleeSavedRegs(), os);
      if (function.GetFrameSize() > 0) {
        EmitStackAdjust("add", function.GetFrameSize(), os);
      }
      os << " ldp x29, x30, [sp], #16\n";
      os << " ret\n";
      return;

    case Opcode::Ret:
      os << " ret\n";
      return;

    case Opcode::MovImm:
      if (operands.size() >= 2) {
        EmitLargeImmediate(operands[0].GetReg(), operands[1].GetImm(), os);
      }
      return;

    case Opcode::LoadStack:
    case Opcode::StoreStack:
      if (StartsWithKinds(operands, Operand::Kind::Reg,
                          Operand::Kind::FrameIndex)) {
        const int offset = ResolveFrameOffset(function, operands[1]);
        if (IsStackArgSlot(function, operands[1])) {
          PrintStackArgAccess(opcode, operands[0], offset, os);
        } else {
          PrintStackAccess(opcode, operands[0], offset, os);
        }
      }
      return;

    case Opcode::LoadStackAddr:
      if (StartsWithKinds(operands, Operand::Kind::Reg,
                          Operand::Kind::FrameIndex)) {
        EmitAddressFromBase(operands[0].GetReg(), PhysReg::X29,
                            ResolveFrameOffset(function, operands[1]), os);
      }
      return;

    case Opcode::LoadGlobal:
    case Opcode::StoreGlobal:
      if (StartsWithKinds(operands, Operand::Kind::Reg,
                          Operand::Kind::Symbol)) {
        PrintGlobalAccess(opcode, operands[0], operands[1].GetSymbol(), os);
      }
      return;

    case Opcode::LoadGlobalAddr:
      if (StartsWithKinds(operands, Operand::Kind::Reg,
                          Operand::Kind::Symbol)) {
        PrintGlobalAddr(operands[0], operands[1].GetSymbol(), os);
      }
      return;

    case Opcode::LoadMem:
    case Opcode::StoreMem:
      // The address operand must be an X register.
      if (StartsWithKinds(operands, Operand::Kind::Reg, Operand::Kind::Reg) &&
          IsXReg(operands[1].GetReg())) {
        PrintMemAccess(opcode, operands[0], operands[1], os);
      }
      return;

    case Opcode::Uxtw:
    case Opcode::Sxtw:
      if (operands.size() >= 2) {
        PrintSimpleInstr(opcode == Opcode::Uxtw ? "uxtw" : "sxtw", operands, 2,
                         os);
      }
      return;

    case Opcode::CSet:
      if (operands.size() >= 2) {
        os << " cset ";
        PrintOperand(operands[0], os);
        os << ", "
           << CondCodeToAsm(static_cast<CondCode>(operands[1].GetImm()))
           << "\n";
      }
      return;

    case Opcode::Br:
      if (!operands.empty() && operands[0].GetKind() == Operand::Kind::Label) {
        os << " b ";
        PrintBlockLabelRef(function, operands[0].GetLabel(), os);
        os << "\n";
      }
      return;

    case Opcode::CondBr:
      if (StartsWithKinds(operands, Operand::Kind::Imm,
                          Operand::Kind::Label)) {
        os << " b."
           << CondCodeToAsm(static_cast<CondCode>(operands[0].GetImm()))
           << " ";
        PrintBlockLabelRef(function, operands[1].GetLabel(), os);
        os << "\n";
      }
      return;

    case Opcode::ModRR:
      // dst = lhs - (lhs / rhs) * rhs, with the quotient in a temporary.
      if (operands.size() >= 3) {
        const PhysReg qreg =
            PickModQuotientReg(operands[0], operands[1], operands[2]);
        os << " sdiv " << PhysRegName(qreg) << ", ";
        PrintOperand(operands[1], os);
        os << ", ";
        PrintOperand(operands[2], os);
        os << "\n";
        os << " msub ";
        PrintOperand(operands[0], os);
        os << ", " << PhysRegName(qreg) << ", ";
        PrintOperand(operands[2], os);
        os << ", ";
        PrintOperand(operands[1], os);
        os << "\n";
      }
      return;

    case Opcode::ShlRR:
    case Opcode::ShrRR:
    case Opcode::AsrRR:
      if (operands.size() >= 3) {
        const char *shift_op = "lsl";
        if (opcode == Opcode::ShrRR) {
          shift_op = "lsr";
        } else if (opcode == Opcode::AsrRR) {
          shift_op = "asr";
        }
        // PrintOperand already renders an immediate shift amount as "#imm".
        PrintSimpleInstr(shift_op, operands, 3, os);
      }
      return;

    case Opcode::Asr64RR:
      if (operands.size() >= 3) {
        os << " asr ";
        PrintOperandAsXReg(operands[0], os);
        os << ", ";
        PrintOperandAsXReg(operands[1], os);
        os << ", ";
        PrintOperand(operands[2], os);
        os << "\n";
      }
      return;

    case Opcode::CmpImm:
      if (operands.size() >= 2) {
        os << " cmp ";
        PrintOperand(operands[0], os);
        os << ", #" << operands[1].GetImm() << "\n";
      }
      return;

    case Opcode::Csel:
      if (operands.size() >= 4) {
        PrintCondSelect("csel", operands, os);
      }
      return;

    case Opcode::Csneg:
      if (operands.size() >= 4) {
        PrintCondSelect("csneg", operands, os);
      }
      return;

    case Opcode::Smull:
      // Only the destination is printed through PrintOperandAsXReg.
      if (operands.size() >= 3) {
        os << " smull ";
        PrintOperandAsXReg(operands[0], os);
        os << ", ";
        PrintOperand(operands[1], os);
        os << ", ";
        PrintOperand(operands[2], os);
        os << "\n";
      }
      return;

    case Opcode::Msub:
      if (operands.size() >= 4) {
        PrintSimpleInstr("msub", operands, 4, os);
      }
      return;

    case Opcode::NegRR:
      if (operands.size() >= 2) {
        PrintSimpleInstr("neg", operands, 2, os);
      }
      return;

    case Opcode::MovReg:
      if (operands.size() >= 2) {
        // A move touching an S register on either side needs fmov.
        const bool is_float =
            IsSReg(operands[0].GetReg()) || IsSReg(operands[1].GetReg());
        PrintSimpleInstr(is_float ? "fmov" : "mov", operands, 2, os);
      }
      return;

    case Opcode::FMovWS:
      if (operands.size() >= 2) {
        PrintSimpleInstr("fmov", operands, 2, os);
      }
      return;

    case Opcode::Scvtf:
      if (operands.size() >= 2) {
        PrintSimpleInstr("scvtf", operands, 2, os);
      }
      return;

    case Opcode::FCvtzs:
      if (operands.size() >= 2) {
        PrintSimpleInstr("fcvtzs", operands, 2, os);
      }
      return;

    default:
      break;
  }

  // Generic form: mnemonic followed by every operand.
  const char *asm_op = OpcodeToAsm(opcode);
  if (asm_op && asm_op[0] != '\0' && !operands.empty()) {
    PrintSimpleInstr(asm_op, operands, operands.size(), os);
  }
}

// Emits the data definitions for all globals of `module`.
//
// Zero-initialised I32 scalars and I32 arrays whose initialisers are all zero
// (or absent) are placed in .bss; everything else is emitted in .data with
// explicit .word values and trailing .zero padding for partially initialised
// arrays. Nothing is printed when the module has no globals.
void PrintGlobals(const MachineModule &module, std::ostream &os) {
  if (module.GetGlobals().empty()) {
    return;
  }

  os << " .data\n";
  for (const auto &global : module.GetGlobals()) {
    const std::string asm_name = NormalizeAsmSymbol(global.name);
    const bool is_scalar = (global.kind == MachineGlobal::Kind::I32Scalar);

    bool is_zero_init = is_scalar && global.init_value == 0;
    if (global.kind == MachineGlobal::Kind::I32Array) {
      bool all_zero = true;
      for (const auto v : global.init_values) {
        if (v != 0) {
          all_zero = false;
          break;
        }
      }
      if (all_zero) {
        is_zero_init = true;
      }
    }

    if (is_zero_init) {
      // Switch to .bss for this symbol only, then return to .data so the
      // following globals land in the right section.
      os << " .bss\n";
      os << " .globl " << asm_name << "\n";
      os << " .p2align 2\n";
      os << asm_name << ":\n";
      if (is_scalar) {
        os << " .space 4\n";
      } else {
        os << " .space " << (global.array_size * 4) << "\n";
      }
      os << " .data\n";
      continue;
    }

    os << " .globl " << asm_name << "\n";
    os << " .p2align 2\n";
    os << asm_name << ":\n";
    if (is_scalar) {
      os << " .word " << global.init_value << "\n";
      continue;
    }

    const size_t init_count = global.init_values.size();
    for (size_t i = 0; i < init_count; ++i) {
      os << " .word " << global.init_values[i] << "\n";
    }
    if (global.array_size > init_count) {
      // Elements without an explicit initialiser are zero-filled.
      os << " .zero " << ((global.array_size - init_count) * 4) << "\n";
    }
  }
  os << "\n";
}

}  // namespace

// Prints one function: the .text header and global symbol, then every
// non-null basic block as a label followed by its instructions.
void PrintAsm(const MachineFunction &function, std::ostream &os) {
  const std::string asm_name = NormalizeAsmSymbol(function.GetName());
  os << " .text\n";
  os << " .globl " << asm_name << "\n";
  os << " .p2align 2\n";
  os << asm_name << ":\n";

  for (const auto &block_ptr : function.GetBlocks()) {
    if (!block_ptr) {
      continue;
    }
    const auto &block = *block_ptr;
    PrintBlockLabelRef(function, block.GetLabelId(), os);
    os << ":\n";
    for (const auto &instr : block.GetInstructions()) {
      PrintInstr(function, instr, os);
    }
  }
}

// Prints a whole module: its globals first, then every non-null function,
// separated by blank lines.
void PrintAsm(const MachineModule &module, std::ostream &os) {
  PrintGlobals(module, os);

  bool first = true;
  for (const auto &function_ptr : module.GetFunctions()) {
    if (!function_ptr) {
      continue;
    }
    if (!first) {
      os << "\n";
    }
    PrintAsm(*function_ptr, os);
    first = false;
  }
}

}  // namespace mir