diff --git a/include/mir/MIR.h b/include/mir/MIR.h
index 47b8959..55da51e 100644
--- a/include/mir/MIR.h
+++ b/include/mir/MIR.h
@@ -19,7 +19,17 @@ class MIRContext {
 
 MIRContext& DefaultContext();
 
-enum class PhysReg { W0, W8, W9, X29, X30, SP };
+// AArch64 physical registers
+enum class PhysReg {
+  W0, W1, W2, W3, W4, W5, W6, W7,
+  W8, W9, W10, W11, W12, W13, W14, W15,
+  X0, X1, X2, X3, X4, X5, X6, X7,
+  X8, X9, X10, X11, X12, X13, X14, X15,
+  X16, X17,
+  S0, S1, S2, S3, S4, S5, S6, S7,
+  S8, S9, S10, S11, S12, S13, S14, S15,
+  X29, X30, SP, WZR, XZR
+};
 
 const char* PhysRegName(PhysReg reg);
 
@@ -27,31 +37,69 @@ enum class Opcode {
   Prologue,
   Epilogue,
   MovImm,
+  MovRR,
   LoadStack,
   StoreStack,
+  AddrStack,
+  LoadGlobal,
+  StoreGlobal,
   AddRR,
+  AddRRI,
+  AddRRR_LSL,
+  SubRR,
+  MulRR,
+  SDivRR,
+  MSubRRR,
+  Sxtw,
+  NegR,
+  CmpRR,
+  CSet,
+  FAdd,
+  FSub,
+  FMUL,
+  FDiv,
+  FNeg,
+  FCmp,
+  FCvtSI2FP,
+  FCvtFP2SI,
+  LoadR,
+  StoreR,
+  Call,
+  B,
+  BCond,
   Ret,
 };
 
+// Condition codes carried by the Cond operand of CSet and BCond.
+enum class CondCode { EQ, NE, LT, LE, GT, GE };
+
 class Operand {
  public:
-  enum class Kind { Reg, Imm, FrameIndex };
+  enum class Kind { Reg, Imm, FrameIndex, Label, Global, Cond };
 
   static Operand Reg(PhysReg reg);
   static Operand Imm(int value);
   static Operand FrameIndex(int index);
+  static Operand Label(const std::string& name);
+  static Operand Global(const std::string& name);
+  static Operand Cond(CondCode cc);
 
   Kind GetKind() const { return kind_; }
   PhysReg GetReg() const { return reg_; }
   int GetImm() const { return imm_; }
   int GetFrameIndex() const { return imm_; }
+  // Label and Global operands share label_; Cond operands are stored in imm_.
+  const std::string& GetLabel() const { return label_; }
+  const std::string& GetGlobal() const { return label_; }
+  CondCode GetCond() const { return static_cast<CondCode>(imm_); }
 
  private:
-  Operand(Kind kind, PhysReg reg, int imm);
+  Operand(Kind kind, PhysReg reg, int imm, std::string label = "");
 
   Kind kind_;
   PhysReg reg_;
   int imm_;
+  std::string label_;
 };
 
 class MachineInstr {
@@ -93,8 +141,10 @@ class MachineFunction {
   explicit MachineFunction(std::string name);
 
   const std::string& GetName() const { return name_; }
-  MachineBasicBlock& GetEntry() { return entry_; }
-  const MachineBasicBlock& GetEntry() const { return entry_; }
+
+  MachineBasicBlock& CreateBlock(const std::string& name);
+  std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() { return blocks_; }
+  const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() const { return blocks_; }
 
   int CreateFrameIndex(int size = 4);
   FrameSlot& GetFrameSlot(int index);
@@ -106,14 +156,37 @@ class MachineFunction {
 
  private:
   std::string name_;
-  MachineBasicBlock entry_;
+  std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
   std::vector<FrameSlot> frame_slots_;
   int frame_size_ = 0;
 };
 
-std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
+// A module-level variable; PrintAsm emits it into the .data section.
+struct GlobalVariable {
+  std::string name;
+  int init_value = 0;
+  size_t size = 4;
+  bool is_const = false;
+};
+
+// Holds the machine functions and global variables of a lowered module.
+class MachineModule {
+ public:
+  MachineModule() = default;
+  std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() { return functions_; }
+  const std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() const { return functions_; }
+
+  std::vector<GlobalVariable>& GetGlobals() { return globals_; }
+  const std::vector<GlobalVariable>& GetGlobals() const { return globals_; }
+
+ private:
+  std::vector<std::unique_ptr<MachineFunction>> functions_;
+  std::vector<GlobalVariable> globals_;
+};
+
+std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
 void RunRegAlloc(MachineFunction& function);
 void RunFrameLowering(MachineFunction& function);
-void PrintAsm(const MachineFunction& function, std::ostream& os);
+void PrintAsm(const MachineModule& module, std::ostream& os);
 
 } // namespace mir
diff --git a/src/main.cpp b/src/main.cpp
index 88ed747..2b2ad62 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -46,13 +46,15 @@ int main(int argc, char** argv) {
   }
 
   if (opts.emit_asm) {
-    auto machine_func = mir::LowerToMIR(*module);
-    mir::RunRegAlloc(*machine_func);
-    mir::RunFrameLowering(*machine_func);
+    auto machine_module = mir::LowerToMIR(*module);
+    for (auto& func : machine_module->GetFunctions()) {
+      mir::RunRegAlloc(*func);
+      mir::RunFrameLowering(*func);
+    }
     if (need_blank_line) {
       std::cout << "\n";
     }
-    mir::PrintAsm(*machine_func, std::cout);
+    mir::PrintAsm(*machine_module, std::cout);
   }
 #else
   if (opts.emit_ir || opts.emit_asm) {
diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp
index 4d1f65f..71ce7f8 100644
--- a/src/mir/AsmPrinter.cpp
+++ b/src/mir/AsmPrinter.cpp
@@ -16,63 +16,294 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
   return function.GetFrameSlot(operand.GetFrameIndex());
 }
 
+// Materialises `imm` into `reg`: a single mov when the value lies in
+// [-32768, 65535], otherwise a mov of bits 0..15 followed by a movk of
+// bits 16..31.
+void PrintMovImm(std::ostream& os, PhysReg reg, int imm) {
+  const char* reg_name = PhysRegName(reg);
+  if (imm >= -32768 && imm <= 65535) {
+    os << " mov " << reg_name << ", #" << imm << "\n";
+  } else {
+    uint32_t uimm = static_cast<uint32_t>(imm);
+    os << " mov " << reg_name << ", #" << (uimm & 0xFFFF) << "\n";
+    os << " movk " << reg_name << ", #" << ((uimm >> 16) & 0xFFFF) << ", lsl #16\n";
+  }
+}
+
+// Emits an x29-relative access. Offsets outside [-256, 255] are first turned
+// into an absolute address in x16, which is then accessed with ldr/str.
 void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
                       int offset) {
-  os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset
-     << "]\n";
+  if (offset >= -256 && offset <= 255) {
+    os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset
+       << "]\n";
+  } else {
+    // Offset out of range for ldur/stur
+    if (offset < 0) {
+      PrintMovImm(os, PhysReg::X16, -offset);
+      os << " sub x16, x29, x16\n";
+    } else {
+      PrintMovImm(os, PhysReg::X16, offset);
+      os << " add x16, x29, x16\n";
+    }
+
+    if (mnemonic[0] == 'l') { // load
+      os << " ldr " << PhysRegName(reg) << ", [x16]\n";
+    } else { // store
+      os << " str " << PhysRegName(reg) << ", [x16]\n";
+    }
+  }
+}
+
+// Maps a CondCode to the condition suffix printed after cset / b.
+const char* CondCodeName(CondCode cc) {
+  switch (cc) {
+    case CondCode::EQ: return "eq";
+    case CondCode::NE: return "ne";
+    case CondCode::LT: return "lt";
+    case CondCode::LE: return "le";
+    case CondCode::GT: return "gt";
+    case CondCode::GE: return "ge";
+  }
+  return "??";
 }
 
 } // namespace
 
-void PrintAsm(const MachineFunction& function, std::ostream& os) {
+// Prints the globals as a .data section, then every function under .text.
+void PrintAsm(const MachineModule& module, std::ostream& os) {
+  // Print global variables
+  if (!module.GetGlobals().empty()) {
+    os << ".data\n";
+    for (const auto& gv : module.GetGlobals()) {
+      os << ".global " << gv.name << "\n";
+      os << ".align 4\n";
+      os << gv.name << ":\n";
+      if (gv.size > 4 || gv.init_value == 0) {
+        os << " .zero " << gv.size << "\n";
+      } else {
+        os << " .word " << gv.init_value << "\n";
+      }
+    }
+    os << "\n";
+  }
+
   os << ".text\n";
-  os << ".global " << function.GetName() << "\n";
-  os << ".type " << function.GetName() << ", %function\n";
-  os << function.GetName() << ":\n";
+  for (const auto& function : module.GetFunctions()) {
+    os << ".global " << function->GetName() << "\n";
+    os << ".type " << function->GetName() << ", %function\n";
+    os << function->GetName() << ":\n";
 
-  for (const auto& inst : function.GetEntry().GetInstructions()) {
-    const auto& ops = inst.GetOperands();
-    switch (inst.GetOpcode()) {
-      case Opcode::Prologue:
-        os << " stp x29, x30, [sp, #-16]!\n";
-        os << " mov x29, sp\n";
-        if (function.GetFrameSize() > 0) {
-          os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
-        }
-        break;
-      case Opcode::Epilogue:
-        if (function.GetFrameSize() > 0) {
-          os << " add sp, sp, #" << function.GetFrameSize() << "\n";
+    for (const auto& block : function->GetBlocks()) {
+      os << ".L" << function->GetName() << "_" << block->GetName() << ":\n";
+
+      for (const auto& inst : block->GetInstructions()) {
+        const auto& ops = inst.GetOperands();
+        switch (inst.GetOpcode()) {
+          case Opcode::Prologue:
+            os << " stp x29, x30, [sp, #-16]!\n";
+            os << " mov x29, sp\n";
+            if (function->GetFrameSize() > 0) {
+              if (function->GetFrameSize() <= 4095) {
+                os << " sub sp, sp, #" << function->GetFrameSize() << "\n";
+              } else {
+                PrintMovImm(os, PhysReg::X11, function->GetFrameSize());
+                os << " sub sp, sp, x11\n";
+              }
+            }
+            break;
+          case Opcode::Epilogue:
+            if (function->GetFrameSize() > 0) {
+              if (function->GetFrameSize() <= 4095) {
+                os << " add sp, sp, #" << function->GetFrameSize() << "\n";
+              } else {
+                PrintMovImm(os, PhysReg::X11, function->GetFrameSize());
+                os << " add sp, sp, x11\n";
+              }
+            }
+            os << " ldp x29, x30, [sp], #16\n";
+            break;
+          case Opcode::MovImm:
+            if (ops.at(1).GetKind() == Operand::Kind::Global) {
+              os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", " << ops.at(1).GetGlobal() << "\n";
+              os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(0).GetReg())
+                 << ", :lo12:" << ops.at(1).GetGlobal() << "\n";
+            } else {
+              PrintMovImm(os, ops.at(0).GetReg(), ops.at(1).GetImm());
+            }
+            break;
+          case Opcode::MovRR: {
+            const char* dst = PhysRegName(ops.at(0).GetReg());
+            const char* src = PhysRegName(ops.at(1).GetReg());
+            // s<-w, w<-s and s<-s transfers are printed as fmov; every other
+            // register pair is printed as a plain mov.
+            const bool use_fmov =
+                (dst[0] == 's' && (src[0] == 'w' || src[0] == 's')) ||
+                (dst[0] == 'w' && src[0] == 's');
+            os << (use_fmov ? " fmov " : " mov ") << dst << ", " << src << "\n";
+            break;
+          }
+          case Opcode::LoadStack: {
+            const auto& slot = GetFrameSlot(*function, ops.at(1));
+            PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
+            break;
+          }
+          case Opcode::StoreStack: {
+            const auto& slot = GetFrameSlot(*function, ops.at(1));
+            PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
+            break;
+          }
+          case Opcode::AddrStack: {
+            const auto& slot = GetFrameSlot(*function, ops.at(1));
+            int offset = slot.offset;
+            if (offset >= 0) {
+              if (offset <= 4095) {
+                os << " add " << PhysRegName(ops.at(0).GetReg()) << ", x29, #" << offset << "\n";
+              } else {
+                PrintMovImm(os, PhysReg::X16, offset);
+                os << " add " << PhysRegName(ops.at(0).GetReg()) << ", x29, x16\n";
+              }
+            } else {
+              int abs_offset = -offset;
+              if (abs_offset <= 4095) {
+                os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", x29, #" << abs_offset << "\n";
+              } else {
+                PrintMovImm(os, PhysReg::X16, abs_offset);
+                os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", x29, x16\n";
+              }
+            }
+            break;
+          }
+          case Opcode::LoadGlobal:
+            os << " adrp x16, " << ops.at(1).GetGlobal() << "\n";
+            os << " add x16, x16, :lo12:" << ops.at(1).GetGlobal() << "\n";
+            os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", [x16]\n";
+            break;
+          case Opcode::StoreGlobal:
+            os << " adrp x16, " << ops.at(1).GetGlobal() << "\n";
+            os << " add x16, x16, :lo12:" << ops.at(1).GetGlobal() << "\n";
+            os << " str " << PhysRegName(ops.at(0).GetReg()) << ", [x16]\n";
+            break;
+          case Opcode::AddRR:
+            os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << ", "
+               << PhysRegName(ops.at(2).GetReg()) << "\n";
+            break;
+          case Opcode::AddRRI:
+            os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << ", #" << ops.at(2).GetImm() << "\n";
+            break;
+          case Opcode::AddRRR_LSL: {
+            const char* reg2_name = PhysRegName(ops.at(2).GetReg());
+            std::string reg2_str = reg2_name;
+            std::string extension = "lsl";
+            if (reg2_name[0] == 'w') {
+              extension = "sxtw";
+            }
+            os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << ", "
+               << reg2_str << ", " << extension << " #" << ops.at(3).GetImm() << "\n";
+            break;
+          }
+          case Opcode::SubRR:
+            os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << ", "
+               << PhysRegName(ops.at(2).GetReg()) << "\n";
+            break;
+          case Opcode::MulRR:
+            os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << ", "
+               << PhysRegName(ops.at(2).GetReg()) << "\n";
+            break;
+          case Opcode::SDivRR:
+            os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << ", "
+               << PhysRegName(ops.at(2).GetReg()) << "\n";
+            break;
+          case Opcode::MSubRRR:
+            os << " msub " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << ", "
+               << PhysRegName(ops.at(2).GetReg()) << ", "
+               << PhysRegName(ops.at(3).GetReg()) << "\n";
+            break;
+          case Opcode::Sxtw:
+            os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << "\n";
+            break;
+          case Opcode::NegR:
+            os << " neg " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << "\n";
+            break;
+          case Opcode::CmpRR:
+            os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << "\n";
+            break;
+          case Opcode::CSet:
+            os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << CondCodeName(ops.at(1).GetCond()) << "\n";
+            break;
+          case Opcode::FAdd:
+            os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << ", "
+               << PhysRegName(ops.at(2).GetReg()) << "\n";
+            break;
+          case Opcode::FSub:
+            os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << ", "
+               << PhysRegName(ops.at(2).GetReg()) << "\n";
+            break;
+          case Opcode::FMUL:
+            os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << ", "
+               << PhysRegName(ops.at(2).GetReg()) << "\n";
+            break;
+          case Opcode::FDiv:
+            os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << ", "
+               << PhysRegName(ops.at(2).GetReg()) << "\n";
+            break;
+          case Opcode::FNeg:
+            os << " fneg " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << "\n";
+            break;
+          case Opcode::FCmp:
+            os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << "\n";
+            break;
+          case Opcode::FCvtSI2FP:
+            os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << "\n";
+            break;
+          case Opcode::FCvtFP2SI:
+            os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
+               << PhysRegName(ops.at(1).GetReg()) << "\n";
+            break;
+          case Opcode::LoadR:
+            os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", ["
+               << PhysRegName(ops.at(1).GetReg()) << "]\n";
+            break;
+          case Opcode::StoreR:
+            os << " str " << PhysRegName(ops.at(0).GetReg()) << ", ["
+               << PhysRegName(ops.at(1).GetReg()) << "]\n";
+            break;
+          case Opcode::Call:
+            os << " bl " << ops.at(0).GetLabel() << "\n";
+            break;
+          case Opcode::B:
+            os << " b .L" << function->GetName() << "_" << ops.at(0).GetLabel() << "\n";
+            break;
+          case Opcode::BCond:
+            os << " cmp " << PhysRegName(ops.at(1).GetReg()) << ", #0\n";
+            os << " b." << CondCodeName(ops.at(0).GetCond()) << " .L" << function->GetName() << "_" << ops.at(2).GetLabel() << "\n";
+            break;
+          case Opcode::Ret:
+            os << " ret\n";
+            break;
         }
-        os << " ldp x29, x30, [sp], #16\n";
-        break;
-      case Opcode::MovImm:
-        os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
-           << ops.at(1).GetImm() << "\n";
-        break;
-      case Opcode::LoadStack: {
-        const auto& slot = GetFrameSlot(function, ops.at(1));
-        PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
-        break;
       }
-      case Opcode::StoreStack: {
-        const auto& slot = GetFrameSlot(function, ops.at(1));
-        PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
-        break;
-      }
-      case Opcode::AddRR:
-        os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
-           << PhysRegName(ops.at(1).GetReg()) << ", "
-           << PhysRegName(ops.at(2).GetReg()) << "\n";
-        break;
-      case Opcode::Ret:
-        os << " ret\n";
-        break;
     }
+    os << ".size " << function->GetName() << ", .-" << function->GetName() << "\n\n";
   }
-
-  os << ".size " << function.GetName() << ", .-" << function.GetName()
-     << "\n";
 }
 
 } // namespace mir
diff --git a/src/mir/FrameLowering.cpp b/src/mir/FrameLowering.cpp
index 679ab68..5f1bba4 100644
--- a/src/mir/FrameLowering.cpp
+++ b/src/mir/FrameLowering.cpp
@@ -19,7 +19,8 @@ void RunFrameLowering(MachineFunction& function) {
   for (const auto& slot : function.GetFrameSlots()) {
     cursor += slot.size;
     if (-cursor < -256) {
-      throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
+      // Frames past the ldur/stur immediate range are no longer rejected: the
+      // asm printer reaches such slots through an address computed in x16.
     }
   }
 
@@ -30,16 +31,24 @@ void RunFrameLowering(MachineFunction& function) {
   }
   function.SetFrameSize(AlignTo(cursor, 16));
 
-  auto& insts = function.GetEntry().GetInstructions();
-  std::vector<MachineInstr> lowered;
-  lowered.emplace_back(Opcode::Prologue);
-  for (const auto& inst : insts) {
-    if (inst.GetOpcode() == Opcode::Ret) {
-      lowered.emplace_back(Opcode::Epilogue);
+  // Add Prologue to the first block
+  if (!function.GetBlocks().empty()) {
+    auto& entry_insts = function.GetBlocks().front()->GetInstructions();
+    entry_insts.insert(entry_insts.begin(), MachineInstr(Opcode::Prologue));
+  }
+
+  // Add Epilogue before every Ret
+  for (auto& block : function.GetBlocks()) {
+    auto& insts = block->GetInstructions();
+    std::vector<MachineInstr> lowered;
+    for (const auto& inst : insts) {
+      if (inst.GetOpcode() == Opcode::Ret) {
+        lowered.emplace_back(Opcode::Epilogue);
+      }
+      lowered.push_back(inst);
     }
-    lowered.push_back(inst);
+    insts = std::move(lowered);
   }
-  insts = std::move(lowered);
 }
 
 } // namespace mir