You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

428 lines
16 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "mir/MIR.h"
#include <cstddef>
#include <cstdint>
#include <ostream>
#include <stdexcept>
#include "ir/IR.h"
#include "utils/Log.h"
namespace mir {
namespace {
// Resolves a FrameIndex operand to the stack slot it names in `function`.
// Throws std::runtime_error if `operand` is not a FrameIndex operand.
const FrameSlot& GetFrameSlot(const MachineFunction& function,
                              const Operand& operand) {
  const bool is_frame_index = operand.GetKind() == Operand::Kind::FrameIndex;
  if (!is_frame_index) {
    throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
  }
  const auto slot_index = operand.GetFrameIndex();
  return function.GetFrameSlot(slot_index);
}
// Builds the assembler label for basic block `block_name` of `function`.
//
// The ".L" prefix makes GNU as treat the label as assembler-local: it is not
// written to the object's symbol table. Without it, every block label becomes
// a local symbol and tools such as objdump/gdb/perf attribute the following
// code to that pseudo-symbol instead of to the enclosing function. The
// function name is embedded so equally named blocks in different functions
// get distinct labels.
std::string LocalBlockLabel(const MachineFunction& function,
                            const std::string& block_name) {
  return ".L" + function.GetName() + "." + block_name;
}
// Materializes the 32-bit two's-complement pattern of `imm` in `reg` using
// movz (+ movk for the high half).
//
// movz clears every bit it does not write, so for an x register the result is
// the 32-bit pattern zero-extended to 64 bits; a negative `imm` is NOT
// sign-extended.
//
// When the low half is zero but the high half is not (common for float bit
// patterns, e.g. 1.0f = 0x3F800000), a single "movz reg, #hi, lsl #16"
// produces the same register value and saves an instruction.
void PrintMoveImm32(std::ostream& os, PhysReg reg, int imm) {
  const std::uint32_t u = static_cast<std::uint32_t>(imm);
  const std::uint32_t lo = u & 0xFFFFu;
  const std::uint32_t hi = (u >> 16) & 0xFFFFu;
  if (lo == 0 && hi != 0) {
    os << " movz " << PhysRegName(reg) << ", #" << hi << ", lsl #16\n";
    return;
  }
  os << " movz " << PhysRegName(reg) << ", #" << lo << "\n";
  if (hi != 0) {
    os << " movk " << PhysRegName(reg) << ", #" << hi << ", lsl #16\n";
  }
}
// Emits "<mnemonic> sp, sp, <size>" where `mnemonic` is "add" or "sub".
// Sizes that do not fit the 12-bit unsigned immediate field (0..4095) are
// first loaded into x10, which is clobbered in that case.
void PrintStackAdjust(std::ostream& os, const char* mnemonic, int size) {
  const bool fits_imm12 = size >= 0 && size <= 4095;
  if (!fits_imm12) {
    PrintMoveImm32(os, PhysReg::X10, size);
    os << " " << mnemonic << " sp, sp, x10\n";
    return;
  }
  os << " " << mnemonic << " sp, sp, #" << size << "\n";
}
// Emits code computing dst = x29 + offset.
// Offsets whose magnitude fits the 12-bit immediate field use a single
// add/sub. Larger magnitudes are first loaded into x11 (clobbered); x10 is
// deliberately avoided because it may hold an array-index offset.
void PrintAddrFromX29(std::ostream& os, PhysReg dst, int offset) {
  const bool is_negative = offset < 0;
  const char* op = is_negative ? "sub" : "add";
  const int magnitude = is_negative ? -offset : offset;
  if (magnitude > 4095) {
    PrintMoveImm32(os, PhysReg::X11, magnitude);
    os << " " << op << " " << PhysRegName(dst) << ", x29, x11\n";
    return;
  }
  os << " " << op << " " << PhysRegName(dst) << ", x29, #" << magnitude
     << "\n";
}
// Emits a frame-relative load or store of `reg` at x29 + `offset`.
// `mnemonic` is the unscaled form ("ldur" or "stur"), whose immediate is a
// signed 9-bit value (-256..255). For offsets outside that range the address
// is first built in x11 (x10 is avoided because it may hold an array-index
// offset), then the plain "ldr"/"str" is used on [x11].
// NOTE(review): on that path a value living in w11/x11 would be overwritten
// before the store — presumably the allocator never hands out x11; confirm.
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
                      int offset) {
  const bool fits_simm9 = offset >= -256 && offset <= 255;
  if (!fits_simm9) {
    // "ldur" -> "ldr", "stur" -> "str".
    const char* plain_mnemonic = (mnemonic[0] == 'l') ? "ldr" : "str";
    PrintAddrFromX29(os, PhysReg::X11, offset);
    os << " " << plain_mnemonic << " " << PhysRegName(reg) << ", [x11]\n";
    return;
  }
  os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset
     << "]\n";
}
// Maps an IR integer comparison to the AArch64 condition-code suffix tested
// after a `cmp` (lt/le/gt/ge are the signed conditions).
// Values outside the enumerators fall back to "eq".
const char* CondSuffix(ir::CmpOp cmp_op) {
  const char* suffix = "eq";
  switch (cmp_op) {
    case ir::CmpOp::Eq: suffix = "eq"; break;
    case ir::CmpOp::Ne: suffix = "ne"; break;
    case ir::CmpOp::Lt: suffix = "lt"; break;
    case ir::CmpOp::Le: suffix = "le"; break;
    case ir::CmpOp::Gt: suffix = "gt"; break;
    case ir::CmpOp::Ge: suffix = "ge"; break;
  }
  return suffix;
}
// Maps an IR comparison to the AArch64 condition-code suffix tested after an
// `fcmp`, chosen to be IEEE 754 compliant with respect to NaN.
// An unordered result (either operand NaN) sets NZCV = 0011 (N=0, Z=0, C=1,
// V=1). With the codes below, Eq/Lt/Le/Gt/Ge are false for unordered operands
// and Ne is true, matching C semantics.
// Values outside the enumerators fall back to "eq".
const char* FloatCondSuffix(ir::CmpOp cmp_op) {
  const char* suffix = "eq";
  switch (cmp_op) {
    case ir::CmpOp::Eq: suffix = "eq"; break;  // Z==1
    case ir::CmpOp::Ne: suffix = "ne"; break;  // Z==0
    // N==1. Plain "lt" (N!=V) would be true for unordered operands.
    case ir::CmpOp::Lt: suffix = "mi"; break;
    // C==0 || Z==1. Plain "le" (Z==1 || N!=V) would be true for unordered.
    case ir::CmpOp::Le: suffix = "ls"; break;
    case ir::CmpOp::Gt: suffix = "gt"; break;  // Z==0 && N==V
    case ir::CmpOp::Ge: suffix = "ge"; break;  // N==V
  }
  return suffix;
}
} // namespace
void PrintAsm(const MachineModule& module, std::ostream& os) {
// 输出全局变量定义
if (!module.GetGlobalVars().empty()) {
os << ".data\n";
for (const auto& [name, init_val, count, is_float, init_elems] :
module.GetGlobalVars()) {
(void)is_float;
os << ".global " << name << "\n";
os << ".type " << name << ", %object\n";
os << name << ":\n";
if (count == 1) {
// 标量全局变量
os << " .word " << init_val << "\n";
} else {
// 数组全局变量:优先输出显式初始化元素,剩余部分补零。
int emitted = 0;
for (int elem : init_elems) {
if (emitted >= count) {
break;
}
os << " .word " << elem << "\n";
++emitted;
}
if (emitted == 0) {
os << " .zero " << (count * 4) << "\n";
} else if (emitted < count) {
os << " .zero " << ((count - emitted) * 4) << "\n";
}
}
}
os << "\n";
}
os << ".text\n";
for (const auto& func_ptr : module.GetFunctions()) {
const auto& function = *func_ptr;
os << ".global " << function.GetName() << "\n";
os << ".type " << function.GetName() << ", %function\n";
os << function.GetName() << ":\n";
// 遍历所有基本块
for (const auto& bb_ptr : function.GetBlocks()) {
const auto& bb = *bb_ptr;
// 打印块标签entry 块不需要标签,因为函数名已经是标签了)
if (bb.GetName() != "entry") {
os << LocalBlockLabel(function, bb.GetName()) << ":\n";
}
for (const auto& inst : bb.GetInstructions()) {
const auto& ops = inst.GetOperands();
switch (inst.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
PrintStackAdjust(os, "sub", function.GetFrameSize());
}
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
PrintStackAdjust(os, "add", function.GetFrameSize());
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
PrintMoveImm32(os, ops.at(0).GetReg(), ops.at(1).GetImm());
break;
case Opcode::MovReg:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FMovImm:
// 通用浮点立即数:先装载 bit pattern再位级移动到 s 寄存器。
PrintMoveImm32(os, PhysReg::W10, ops.at(1).GetImm());
os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", w10\n";
break;
case Opcode::FMovReg:
os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::LoadStackOffset: {
// ops: reg, frame_index, imm_offset
const auto& slot = GetFrameSlot(function, ops.at(1));
int final_offset = slot.offset + ops.at(2).GetImm();
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), final_offset);
break;
}
case Opcode::StoreStackOffset: {
// ops: reg, frame_index, imm_offset
const auto& slot = GetFrameSlot(function, ops.at(1));
int final_offset = slot.offset + ops.at(2).GetImm();
PrintStackAccess(os, "stur", ops.at(0).GetReg(), final_offset);
break;
}
case Opcode::LoadStackAddr: {
// ops: xN, frame_index
// add xN, x29, #offset
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintAddrFromX29(os, ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::LoadIndirect: {
// ops: wN, xM
// ldr wN, [xM]
os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
break;
}
case Opcode::StoreIndirect: {
// ops: wN, xM
// str wN, [xM]
os << " str " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
break;
}
case Opcode::LoadGlobal: {
// adrp x9, global_var
// add x9, x9, :lo12:global_var
// ldr wN, [x9]
const std::string& name = ops.at(1).GetSymbol();
os << " adrp x9, " << name << "\n";
os << " add x9, x9, :lo12:" << name << "\n";
os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", [x9]\n";
break;
}
case Opcode::StoreGlobal: {
// adrp x9, global_var
// add x9, x9, :lo12:global_var
// str wN, [x9]
const std::string& name = ops.at(1).GetSymbol();
os << " adrp x9, " << name << "\n";
os << " add x9, x9, :lo12:" << name << "\n";
os << " str " << PhysRegName(ops.at(0).GetReg()) << ", [x9]\n";
break;
}
case Opcode::LoadGlobalAddr: {
// adrp xN, global_var
// add xN, xN, :lo12:global_var
const std::string& name = ops.at(1).GetSymbol();
os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", " << name << "\n";
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(0).GetReg()) << ", :lo12:" << name << "\n";
break;
}
case Opcode::AddRI:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::SubRI:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::AddRR_UXTW:
// add xN, xM, wK, uxtw — 零扩展W寄存器后加到X寄存器
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << ", uxtw\n";
break;
case Opcode::SubRR:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::MulRR:
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::DivRR:
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FAddRR:
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FSubRR:
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FMulRR:
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FDivRR:
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FSqrtRR:
os << " fsqrt " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::SIToFP:
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::ModRR:
// 不应该出现Mod 在 lowering 时已展开为 div+mul+sub
throw std::runtime_error(FormatError("mir", "ModRR 不应被打印"));
case Opcode::LsrRI:
os << " lsr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::LslRI:
os << " lsl " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::LslRR:
os << " lsl " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::CmpOnlyRR:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FCmpOnlyRR:
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::CmpRR: {
// ops: dst, lhs, rhs, cmpop(imm)
auto cmp_op = static_cast<ir::CmpOp>(ops.at(3).GetImm());
os << " cmp " << PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
<< CondSuffix(cmp_op) << "\n";
break;
}
case Opcode::FCmpRR: {
// ops: dst(wN), lhs(sN), rhs(sN), cmpop(imm)
auto cmp_op = static_cast<ir::CmpOp>(ops.at(3).GetImm());
os << " fcmp " << PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
<< FloatCondSuffix(cmp_op) << "\n";
break;
}
case Opcode::Bl:
os << " bl " << ops.at(0).GetSymbol() << "\n";
break;
case Opcode::B:
os << " b " << LocalBlockLabel(function, ops.at(0).GetSymbol())
<< "\n";
break;
case Opcode::Cbnz:
os << " cbnz " << PhysRegName(ops.at(0).GetReg())
<< ", " << LocalBlockLabel(function, ops.at(1).GetSymbol())
<< "\n";
break;
case Opcode::Cbz:
os << " cbz " << PhysRegName(ops.at(0).GetReg())
<< ", " << LocalBlockLabel(function, ops.at(1).GetSymbol())
<< "\n";
break;
case Opcode::Bcond:
// ops: symbol, cmpop(imm)
os << " b." << CondSuffix(static_cast<ir::CmpOp>(ops.at(1).GetImm()))
<< " " << LocalBlockLabel(function, ops.at(0).GetSymbol())
<< "\n";
break;
case Opcode::FBcond:
// ops: symbol, cmpop(imm) - 浮点条件分支
os << " b." << FloatCondSuffix(static_cast<ir::CmpOp>(ops.at(1).GetImm()))
<< " " << LocalBlockLabel(function, ops.at(0).GetSymbol())
<< "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
}
}
}
os << ".size " << function.GetName() << ", .-" << function.GetName()
<< "\n\n";
}
}
} // namespace mir