You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/mir/AsmPrinter.cpp

904 lines
30 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "mir/MIR.h"
#include <iostream>
#include <stdexcept>
#include <string>
#include "utils/Log.h"
namespace mir
{
namespace
{
// Strips every leading IR sigil ('@' for globals, '%' for locals) so the name
// can be used as an assembler symbol. Sigils after the first ordinary
// character are kept. A name made only of sigils becomes empty.
std::string NormalizeAsmSymbol(std::string name)
{
    // find_first_not_of yields npos for an all-sigil (or empty) name, and
    // erase(0, npos) then clears the whole string.
    name.erase(0, name.find_first_not_of("@%"));
    return name;
}
// Returns the mnemonic that PrintInstr's generic "op a, b, c" path prints for
// an opcode. Opcodes that PrintInstr always expands itself (pseudo
// instructions, stack/global/memory accesses, extensions), as well as any
// opcode not listed here, map to the empty string.
const char *OpcodeToAsm(Opcode opcode)
{
    switch (opcode)
    {
    // Register / immediate moves.
    case Opcode::MovImm:
    case Opcode::MovReg:
        return "mov";
    case Opcode::FMovWS:
        return "fmov";

    // Frame-slot accesses (unscaled-offset forms).
    case Opcode::LoadStack:
        return "ldur";
    case Opcode::StoreStack:
        return "stur";

    // Integer arithmetic and logic.
    case Opcode::AddRR:
        return "add";
    case Opcode::SubRR:
        return "sub";
    case Opcode::MulRR:
        return "mul";
    case Opcode::DivRR:
        return "sdiv";
    case Opcode::ModRR:
        return "msub";
    case Opcode::AndRR:
        return "and";
    case Opcode::OrRR:
        return "orr";
    case Opcode::XorRR:
        return "eor";
    case Opcode::ShlRR:
        return "lsl";
    case Opcode::ShrRR:
        return "lsr";

    // Floating-point arithmetic.
    case Opcode::FAddRR:
        return "fadd";
    case Opcode::FSubRR:
        return "fsub";
    case Opcode::FMulRR:
        return "fmul";
    case Opcode::FDivRR:
        return "fdiv";

    // Comparisons and conditional selects.
    case Opcode::CmpRR:
    case Opcode::CmpImm:
        return "cmp";
    case Opcode::FCmpRR:
        return "fcmp";
    case Opcode::CSet:
        return "cset";
    case Opcode::Csneg:
        return "csneg";

    // Int <-> float conversions.
    case Opcode::Scvtf:
        return "scvtf";
    case Opcode::FCvtzs:
        return "fcvtzs";

    // Control flow.
    case Opcode::Br:
    case Opcode::CondBr:
        return "b";
    case Opcode::Call:
        return "bl";
    case Opcode::Ret:
        return "ret";

    // Address materialisation.
    case Opcode::LoadAddr:
        return "adrp";

    // Expanded entirely by PrintInstr, plus anything unlisted.
    case Opcode::Prologue:
    case Opcode::Epilogue:
    case Opcode::LoadStackAddr:
    case Opcode::LoadGlobal:
    case Opcode::StoreGlobal:
    case Opcode::LoadGlobalAddr:
    case Opcode::LoadMem:
    case Opcode::StoreMem:
    case Opcode::Uxtw:
    case Opcode::Sxtw:
    default:
        return "";
    }
}
// Returns the AArch64 condition suffix ("eq", "ne", "lt", "le", "gt", "ge")
// for a MIR condition code. Any other value yields the empty string.
const char *CondCodeToAsm(CondCode cond)
{
    if (cond == CondCode::EQ)
        return "eq";
    if (cond == CondCode::NE)
        return "ne";
    if (cond == CondCode::LT)
        return "lt";
    if (cond == CondCode::LE)
        return "le";
    if (cond == CondCode::GT)
        return "gt";
    if (cond == CondCode::GE)
        return "ge";
    return "";
}
// True when `reg` lies in the contiguous PhysReg range X0..X30 (inclusive).
bool IsXReg(PhysReg reg)
{
    return PhysReg::X0 <= reg && reg <= PhysReg::X30;
}
// True when `reg` lies in the contiguous PhysReg range W0..W30 (inclusive).
bool IsWReg(PhysReg reg)
{
    return PhysReg::W0 <= reg && reg <= PhysReg::W30;
}
// True when `reg` lies in the contiguous PhysReg range S0..S31 (inclusive).
bool IsSReg(PhysReg reg)
{
    return PhysReg::S0 <= reg && reg <= PhysReg::S31;
}
// Chooses the temporary W register that receives the quotient when ModRR is
// expanded at print time into
//     sdiv q, lhs, rhs
//     msub dst, q, rhs, lhs
// The register must differ from all three operands.
//
// w13 is tried first. x13 is the register reserved for printer-time
// expansions (PrinterScratchXReg), and lowering does not use it. x14/x15 may
// hold array/global base addresses placed by lowering, so writing w14/w15
// here could silently corrupt a live base. The older candidates are kept only
// as fallbacks for the case where w13 is itself an operand.
PhysReg PickModQuotientReg(const Operand &dst,
                           const Operand &lhs,
                           const Operand &rhs)
{
    const PhysReg candidates[] = {
        PhysReg::W13, PhysReg::W14, PhysReg::W12, PhysReg::W15,
        PhysReg::W11, PhysReg::W10, PhysReg::W9, PhysReg::W8};

    // True when `op` is a physical register operand equal to `reg`.
    // Every candidate is a W register, so no separate W-class check is needed.
    const auto occupies = [](const Operand &op, PhysReg reg) {
        return op.GetKind() == Operand::Kind::Reg && op.GetReg() == reg;
    };

    for (PhysReg reg : candidates)
    {
        if (!occupies(dst, reg) && !occupies(lhs, reg) && !occupies(rhs, reg))
        {
            return reg;
        }
    }
    // Unreachable: three operands cannot occupy all eight candidates.
    return candidates[0];
}
// Builds the assembler label for a basic block: ".L.<function>.<label_id>".
// The ".L" prefix keeps the label local to the object file. Qualifying with
// the normalised function name lets label ids repeat across functions.
std::string BlockLabel(const MachineFunction &function, int label_id)
{
    std::string label = ".L.";
    label += NormalizeAsmSymbol(function.GetName());
    label += '.';
    label += std::to_string(label_id);
    return label;
}
// Writes the label of block `label_id` in `function` (see BlockLabel) to `os`.
// Nothing else is written: no trailing colon and no newline.
void PrintBlockLabelRef(const MachineFunction &function, int label_id,
                        std::ostream &os)
{
    const std::string label = BlockLabel(function, label_id);
    os << label;
}
void PrintOperand(const Operand &operand, std::ostream &os)
{
switch (operand.GetKind())
{
case Operand::Kind::Reg:
os << PhysRegName(operand.GetReg());
break;
case Operand::Kind::VReg:
os << "%vreg" << operand.GetVRegId();
break;
case Operand::Kind::Imm:
os << "#" << operand.GetImm();
break;
case Operand::Kind::FrameIndex:
os << "<fi#" << operand.GetFrameIndex() << ">";
break;
case Operand::Kind::Label:
os << ".L" << operand.GetLabel();
break;
case Operand::Kind::Symbol:
os << NormalizeAsmSymbol(operand.GetSymbol());
break;
}
}
// Materialises the 32-bit constant `value` into `target` (a W or X register).
//
// Values in [-32768, 65535] fit the single `mov` alias, which the assembler
// encodes as movz or movn. Other values are assembled from two 16-bit halves:
//   - non-negative: `movz` of the first non-zero half, then `movk` of the
//     high half when the low half was non-zero;
//   - negative: `movn` of the inverted low half, which also sets every bit
//     above it to one, then `movk` of the high half unless it is 0xFFFF.
// Starting negatives from movn keeps them sign-extended when `target` is an
// X register, matching what `mov xN, #-small` does. The previous movz/movk
// sequence zero-extended them: -100000 became 0x00000000FFFE7960 in an X
// register. W results are unchanged.
void EmitLargeImmediate(PhysReg target, int value, std::ostream &os)
{
    if (value >= -32768 && value <= 65535)
    {
        os << " mov " << PhysRegName(target) << ", #" << value << "\n";
        return;
    }

    const unsigned int bits = static_cast<unsigned int>(value);
    const unsigned int low = bits & 0xFFFFu;
    const unsigned int high = (bits >> 16) & 0xFFFFu;

    if (value < 0)
    {
        // movn writes ~imm16 across the whole register, so the low half ends
        // up as `low` and every higher bit is one.
        os << " movn " << PhysRegName(target) << ", #" << (~low & 0xFFFFu)
           << "\n";
        if (high != 0xFFFFu)
        {
            os << " movk " << PhysRegName(target) << ", #" << high
               << ", lsl #16\n";
        }
        return;
    }

    // value > 65535 here, so the high half is non-zero.
    if (low == 0)
    {
        os << " movz " << PhysRegName(target) << ", #" << high
           << ", lsl #16\n";
        return;
    }
    os << " movz " << PhysRegName(target) << ", #" << low << "\n";
    os << " movk " << PhysRegName(target) << ", #" << high << ", lsl #16\n";
}
// Emits "<op> sp, sp, #imm" repeatedly until `amount` bytes are applied.
// add/sub immediates are 12-bit, so each step is at most 4095 bytes.
// A non-positive `amount` emits nothing.
void EmitStackAdjust(const char *op, int amount, std::ostream &os)
{
    constexpr int kMaxImm12 = 4095;
    for (int remaining = amount; remaining > 0;)
    {
        const int step = remaining < kMaxImm12 ? remaining : kMaxImm12;
        os << " " << op << " sp, sp, #" << step << "\n";
        remaining -= step;
    }
}
// Returns the scratch register reserved for pseudo-instruction expansion
// during assembly printing.
// Printer-time expansions need one globally reserved scratch register.
// x14/x15 must not be reused for this, because doing so would overwrite
// array/global base addresses that lowering has already placed there.
// Current lowering never uses x13, so x13 is reserved for the printer.
PhysReg PrinterScratchXReg()
{
    constexpr PhysReg kPrinterScratch = PhysReg::X13;
    return kPrinterScratch;
}
// Emits code computing target_xreg = base_reg + offset.
// The base is first copied with `mov`. The offset is then applied in add
// steps (offset > 0) or sub steps (offset < 0), each at most 4095 because
// add/sub immediates are 12-bit. A zero offset emits only the `mov`.
void EmitAddressFromBase(PhysReg target_xreg, PhysReg base_reg, int offset,
                         std::ostream &os)
{
    os << " mov " << PhysRegName(target_xreg) << ", "
       << PhysRegName(base_reg) << "\n";

    const bool going_down = offset < 0;
    const char *step_op = going_down ? " sub " : " add ";
    int remaining = going_down ? -offset : offset;
    while (remaining > 0)
    {
        const int step = remaining < 4095 ? remaining : 4095;
        os << step_op << PhysRegName(target_xreg) << ", "
           << PhysRegName(target_xreg) << ", #" << step << "\n";
        remaining -= step;
    }
}
// Loads or stores `reg` at frame-pointer offset `offset` (x29-relative).
// Offsets in [-256, 255] use the unscaled ldur/stur form directly. Larger
// offsets first build the address in the printer scratch register, then
// use plain ldr/str through it.
void PrintStackAccess(Opcode opcode, const Operand &reg, int offset,
                      std::ostream &os)
{
    const bool is_load = opcode == Opcode::LoadStack;
    const bool fits_unscaled = offset >= -256 && offset <= 255;

    if (!fits_unscaled)
    {
        const PhysReg scratch = PrinterScratchXReg();
        EmitAddressFromBase(scratch, PhysReg::X29, offset, os);
        os << " " << (is_load ? "ldr" : "str") << " ";
        PrintOperand(reg, os);
        os << ", [" << PhysRegName(scratch) << "]\n";
        return;
    }

    os << " " << (is_load ? "ldur" : "stur") << " ";
    PrintOperand(reg, os);
    os << ", [x29, #" << offset << "]\n";
}
// Materialises the address of global `symbol` into `dst` using the
// adrp (page address) + add :lo12: (offset within the page) pair.
void PrintGlobalAddr(const Operand &dst, const std::string &symbol,
                     std::ostream &os)
{
    const std::string sym = NormalizeAsmSymbol(symbol);
    const auto print_dst = [&dst, &os]() { PrintOperand(dst, os); };

    os << " adrp ";
    print_dst();
    os << ", " << sym << "\n";

    os << " add ";
    print_dst();
    os << ", ";
    print_dst();
    os << ", :lo12:" << sym << "\n";
}
// Loads (LoadGlobal) or stores (any other opcode) `reg` from/to global `symbol`.
// The page address goes into the printer scratch register via adrp. The
// :lo12: page offset is folded into the ldr/str addressing mode.
void PrintGlobalAccess(Opcode opcode, const Operand &reg,
                       const std::string &symbol, std::ostream &os)
{
    const PhysReg page_reg = PrinterScratchXReg();
    const std::string sym = NormalizeAsmSymbol(symbol);
    const char *mnemonic = (opcode == Opcode::LoadGlobal) ? "ldr " : "str ";

    os << " adrp " << PhysRegName(page_reg) << ", " << sym << "\n";
    os << " " << mnemonic;
    PrintOperand(reg, os);
    os << ", [" << PhysRegName(page_reg) << ", #:lo12:" << sym << "]\n";
}
// Register-indirect access with no offset: "ldr|str data, [addr]".
// LoadMem selects ldr; any other opcode selects str.
void PrintMemAccess(Opcode opcode, const Operand &data_reg,
                    const Operand &addr_reg, std::ostream &os)
{
    const bool is_load = (opcode == Opcode::LoadMem);
    const char *mnemonic = is_load ? "ldr " : "str ";
    os << " " << mnemonic;
    PrintOperand(data_reg, os);
    os << ", [";
    PrintOperand(addr_reg, os);
    os << "]\n";
}
// Returns the frame-slot offset recorded for a FrameIndex operand.
// Throws std::runtime_error when `operand` is not a FrameIndex.
int ResolveFrameOffset(const MachineFunction &function, const Operand &operand)
{
    if (operand.GetKind() == Operand::Kind::FrameIndex)
    {
        const auto &slot = function.GetFrameSlot(operand.GetFrameIndex());
        return slot.offset;
    }
    throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
}
// True when `operand` is a FrameIndex whose slot is a stack-argument slot
// that is not flagged is_callee_stack_arg. Such slots are addressed from sp
// by PrintStackArgAccess. Non-FrameIndex operands yield false.
bool IsStackArgSlot(const MachineFunction &function, const Operand &operand)
{
    if (operand.GetKind() == Operand::Kind::FrameIndex)
    {
        const auto &slot = function.GetFrameSlot(operand.GetFrameIndex());
        return slot.is_stack_arg && !slot.is_callee_stack_arg;
    }
    return false;
}
// Loads or stores `reg` at sp-relative offset `offset`.
// The scaled unsigned-offset form of ldr/str encodes imm12 * access size:
// up to 16380 for 32-bit (W/S) registers and 32760 otherwise. Offsets outside
// [0, limit] are first materialised into the printer scratch register.
// Assumes `reg` is a Reg operand; the caller checks this.
void PrintStackArgAccess(Opcode opcode, const Operand &reg, int offset,
                         std::ostream &os)
{
    const char *mnemonic = (opcode == Opcode::LoadStack) ? "ldr" : "str";
    const PhysReg data_reg = reg.GetReg();
    const bool narrow = IsWReg(data_reg) || IsSReg(data_reg);
    const int limit = narrow ? 16380 : 32760;

    if (offset < 0 || offset > limit)
    {
        const PhysReg scratch = PrinterScratchXReg();
        EmitAddressFromBase(scratch, PhysReg::SP, offset, os);
        os << " " << mnemonic << " ";
        PrintOperand(reg, os);
        os << ", [" << PhysRegName(scratch) << "]\n";
        return;
    }

    os << " " << mnemonic << " ";
    PrintOperand(reg, os);
    os << ", [sp, #" << offset << "]\n";
}
// Prints one machine instruction as AArch64 assembly text.
//
// The switch handles two groups of opcodes:
//   - pseudo-instructions (prologue/epilogue, stack/global/memory accesses,
//     extensions);
//   - opcodes whose operand layout differs from the plain "op a, b, c" form.
// Each explicit case returns. When its operand count/kind check fails, it
// prints nothing. Every other opcode falls through to the generic path at the
// end, which prints the OpcodeToAsm mnemonic followed by the comma-separated
// operands.
//
//   function - owning function. Supplies frame slots, frame size,
//              callee-saved registers and the block-label naming scheme.
//   instr    - the instruction to print.
//   os       - output stream.
void PrintInstr(const MachineFunction &function, const MachineInstr &instr,
std::ostream &os)
{
const char *asm_op = OpcodeToAsm(instr.GetOpcode());
const auto &operands = instr.GetOperands();
switch (instr.GetOpcode())
{
case Opcode::Prologue:
{
// Prologue sequence:
//   1. Push the frame record (x29, x30).
//   2. Make x29 the frame pointer.
//   3. Reserve GetFrameSize() bytes below it.
//   4. Save callee-saved registers at increasing sp offsets: X registers
//      take 8 bytes, S registers take 4. Other register classes are skipped.
const auto &cs_regs = function.GetCalleeSavedRegs();
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0)
{
EmitStackAdjust("sub", function.GetFrameSize(), os);
}
int cs_offset = 0;
for (auto r : cs_regs)
{
if (r >= PhysReg::X0 && r <= PhysReg::X30)
{
os << " str " << PhysRegName(r) << ", [sp, #" << cs_offset << "]\n";
cs_offset += 8;
}
else if (r >= PhysReg::S0 && r <= PhysReg::S31)
{
os << " str " << PhysRegName(r) << ", [sp, #" << cs_offset << "]\n";
cs_offset += 4;
}
}
return;
}
case Opcode::Epilogue:
{
// Mirror of the prologue:
//   1. Reload callee-saved registers from the same sp offsets.
//   2. Release the frame.
//   3. Pop x29/x30.
//   4. Return. This case emits its own `ret`.
const auto &cs_regs = function.GetCalleeSavedRegs();
int cs_offset = 0;
for (auto r : cs_regs)
{
if (r >= PhysReg::X0 && r <= PhysReg::X30)
{
os << " ldr " << PhysRegName(r) << ", [sp, #" << cs_offset << "]\n";
cs_offset += 8;
}
else if (r >= PhysReg::S0 && r <= PhysReg::S31)
{
os << " ldr " << PhysRegName(r) << ", [sp, #" << cs_offset << "]\n";
cs_offset += 4;
}
}
if (function.GetFrameSize() > 0)
{
EmitStackAdjust("add", function.GetFrameSize(), os);
}
os << " ldp x29, x30, [sp], #16\n";
os << " ret\n";
return;
}
case Opcode::Ret:
os << " ret\n";
return;
case Opcode::MovImm:
// operands[0] is used as a register without checking its kind.
// EmitLargeImmediate picks the instruction sequence based on the value.
if (operands.size() >= 2)
{
EmitLargeImmediate(operands[0].GetReg(), operands[1].GetImm(), os);
}
return;
case Opcode::LoadStack:
case Opcode::StoreStack:
// Stack-argument slots (IsStackArgSlot) are addressed from sp. All other
// frame slots are addressed from x29. NOTE(review): the sp-based slots are
// presumably outgoing call arguments; confirm against the frame layout code.
if (operands.size() >= 2 &&
operands[0].GetKind() == Operand::Kind::Reg &&
operands[1].GetKind() == Operand::Kind::FrameIndex)
{
if (IsStackArgSlot(function, operands[1]))
{
PrintStackArgAccess(instr.GetOpcode(), operands[0],
ResolveFrameOffset(function, operands[1]), os);
}
else
{
PrintStackAccess(instr.GetOpcode(), operands[0],
ResolveFrameOffset(function, operands[1]), os);
}
}
return;
case Opcode::LoadStackAddr:
// reg = x29 + frame-slot offset.
if (operands.size() >= 2 &&
operands[0].GetKind() == Operand::Kind::Reg &&
operands[1].GetKind() == Operand::Kind::FrameIndex)
{
EmitAddressFromBase(operands[0].GetReg(), PhysReg::X29,
ResolveFrameOffset(function, operands[1]), os);
}
return;
case Opcode::LoadGlobal:
case Opcode::StoreGlobal:
// adrp into the printer scratch register, then ldr/str with :lo12:.
if (operands.size() >= 2 &&
operands[0].GetKind() == Operand::Kind::Reg &&
operands[1].GetKind() == Operand::Kind::Symbol)
{
PrintGlobalAccess(instr.GetOpcode(), operands[0],
operands[1].GetSymbol(), os);
}
return;
case Opcode::LoadGlobalAddr:
// Destination register receives the symbol address (adrp + add :lo12:).
if (operands.size() >= 2 &&
operands[0].GetKind() == Operand::Kind::Reg &&
operands[1].GetKind() == Operand::Kind::Symbol)
{
PrintGlobalAddr(operands[0], operands[1].GetSymbol(), os);
}
return;
case Opcode::LoadMem:
case Opcode::StoreMem:
// Only printed when the address operand is an X register. Otherwise
// nothing is emitted.
if (operands.size() >= 2 &&
operands[0].GetKind() == Operand::Kind::Reg &&
operands[1].GetKind() == Operand::Kind::Reg &&
IsXReg(operands[1].GetReg()))
{
PrintMemAccess(instr.GetOpcode(), operands[0], operands[1], os);
}
return;
case Opcode::Uxtw:
case Opcode::Sxtw:
if (operands.size() >= 2)
{
os << " " << (instr.GetOpcode() == Opcode::Uxtw ? "uxtw" : "sxtw") << " ";
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);
os << "\n";
}
return;
case Opcode::CSet:
// operands[1] carries a CondCode encoded as an immediate.
if (operands.size() >= 2)
{
os << " cset ";
PrintOperand(operands[0], os);
os << ", "
<< CondCodeToAsm(static_cast<CondCode>(operands[1].GetImm()))
<< "\n";
}
return;
case Opcode::Br:
// Targets use the per-function BlockLabel scheme, which matches the labels
// PrintAsm emits. PrintOperand's ".L<id>" form would not match them.
if (!operands.empty() && operands[0].GetKind() == Operand::Kind::Label)
{
os << " b ";
PrintBlockLabelRef(function, operands[0].GetLabel(), os);
os << "\n";
}
return;
case Opcode::CondBr:
// operands[0] is the CondCode (as an immediate); operands[1] is the target.
if (operands.size() >= 2 &&
operands[0].GetKind() == Operand::Kind::Imm &&
operands[1].GetKind() == Operand::Kind::Label)
{
os << " b."
<< CondCodeToAsm(static_cast<CondCode>(operands[0].GetImm()))
<< " ";
PrintBlockLabelRef(function, operands[1].GetLabel(), os);
os << "\n";
}
return;
case Opcode::ModRR:
// dst = lhs - (lhs / rhs) * rhs. sdiv writes the quotient into a temporary
// W register that is not one of the operands. msub then computes
// dst = lhs - q * rhs.
if (operands.size() >= 3)
{
const PhysReg qreg = PickModQuotientReg(operands[0], operands[1], operands[2]);
os << " sdiv " << PhysRegName(qreg) << ", ";
PrintOperand(operands[1], os);
os << ", ";
PrintOperand(operands[2], os);
os << "\n";
os << " msub ";
PrintOperand(operands[0], os);
os << ", " << PhysRegName(qreg) << ", ";
PrintOperand(operands[2], os);
os << ", ";
PrintOperand(operands[1], os);
os << "\n";
}
return;
case Opcode::ShlRR:
case Opcode::ShrRR:
case Opcode::AsrRR:
// The shift amount may be a register or an immediate. The Imm branch
// prints the same "#<value>" text that PrintOperand would.
if (operands.size() >= 3)
{
const char *shift_op = "lsl";
if (instr.GetOpcode() == Opcode::ShrRR)
shift_op = "lsr";
else if (instr.GetOpcode() == Opcode::AsrRR)
shift_op = "asr";
os << " " << shift_op << " ";
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);
os << ", ";
if (operands[2].GetKind() == Operand::Kind::Imm)
{
os << "#" << operands[2].GetImm();
}
else
{
PrintOperand(operands[2], os);
}
os << "\n";
}
return;
case Opcode::Asr64RR:
// 64-bit arithmetic shift. W-register operands are printed as the register
// 31 PhysReg entries later. NOTE(review): this assumes each Xn sits
// exactly 31 enum entries after Wn; confirm against the PhysReg definition.
if (operands.size() >= 3)
{
os << " asr ";
if (operands[0].GetKind() == Operand::Kind::Reg && IsWReg(operands[0].GetReg()))
os << PhysRegName(static_cast<PhysReg>(static_cast<int>(operands[0].GetReg()) + 31));
else
PrintOperand(operands[0], os);
os << ", ";
if (operands[1].GetKind() == Operand::Kind::Reg && IsWReg(operands[1].GetReg()))
os << PhysRegName(static_cast<PhysReg>(static_cast<int>(operands[1].GetReg()) + 31));
else
PrintOperand(operands[1], os);
os << ", ";
if (operands[2].GetKind() == Operand::Kind::Imm)
os << "#" << operands[2].GetImm();
else
PrintOperand(operands[2], os);
os << "\n";
}
return;
case Opcode::CmpImm:
// The immediate is printed as-is. NOTE(review): presumably lowering keeps
// it within cmp's encodable immediate range; confirm.
if (operands.size() >= 2)
{
os << " cmp ";
PrintOperand(operands[0], os);
os << ", #" << operands[1].GetImm() << "\n";
}
return;
case Opcode::Csel:
// csel dst, a, b, cond. operands[3] carries the CondCode.
if (operands.size() >= 4)
{
os << " csel ";
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);
os << ", ";
PrintOperand(operands[2], os);
os << ", " << CondCodeToAsm(static_cast<CondCode>(operands[3].GetImm())) << "\n";
}
return;
case Opcode::Csneg:
// csneg dst, a, b, cond. operands[3] carries the CondCode.
if (operands.size() >= 4)
{
os << " csneg ";
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);
os << ", ";
PrintOperand(operands[2], os);
os << ", " << CondCodeToAsm(static_cast<CondCode>(operands[3].GetImm())) << "\n";
}
return;
case Opcode::Smull:
// Signed widening multiply. A W destination is printed using the same +31
// PhysReg mapping as Asr64RR. The sources are printed unchanged.
if (operands.size() >= 3)
{
os << " smull ";
if (operands[0].GetKind() == Operand::Kind::Reg && IsWReg(operands[0].GetReg()))
os << PhysRegName(static_cast<PhysReg>(static_cast<int>(operands[0].GetReg()) + 31));
else
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);
os << ", ";
PrintOperand(operands[2], os);
os << "\n";
}
return;
case Opcode::Msub:
// msub d, n, m, a computes d = a - n * m.
if (operands.size() >= 4)
{
os << " msub ";
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);
os << ", ";
PrintOperand(operands[2], os);
os << ", ";
PrintOperand(operands[3], os);
os << "\n";
}
return;
case Opcode::NegRR:
if (operands.size() >= 2)
{
os << " neg ";
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);
os << "\n";
}
return;
case Opcode::MovReg:
// Uses fmov when either side is an S register, and mov otherwise.
// Both operands are assumed to be Reg operands; the kind is not checked.
if (operands.size() >= 2)
{
bool is_float = IsSReg(operands[0].GetReg()) || IsSReg(operands[1].GetReg());
os << " " << (is_float ? "fmov" : "mov") << " ";
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);
os << "\n";
}
return;
case Opcode::FMovWS:
if (operands.size() >= 2)
{
os << " fmov ";
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);
os << "\n";
}
return;
case Opcode::Scvtf:
if (operands.size() >= 2)
{
os << " scvtf ";
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);
os << "\n";
}
return;
case Opcode::FCvtzs:
if (operands.size() >= 2)
{
os << " fcvtzs ";
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);
os << "\n";
}
return;
default:
break;
}
// Generic path: "<mnemonic> op0, op1, ...". Nothing is printed for opcodes
// that OpcodeToAsm maps to "" or for instructions without operands.
if (asm_op && asm_op[0] != '\0' && !operands.empty())
{
os << " " << asm_op;
for (size_t i = 0; i < operands.size(); ++i)
{
os << (i == 0 ? " " : ", ");
PrintOperand(operands[i], os);
}
os << "\n";
}
}
// Emits the module's global variables. Prints nothing when there are none.
// - Zero-initialised globals go to .bss and are sized with .space. This
//   covers an I32Scalar with init 0, and an I32Array whose init_values are
//   all zero (or empty).
// - Other globals are emitted in .data: one .word per initialiser, then
//   .zero padding up to array_size elements.
// Every global is exported (.globl) and 4-byte aligned (.p2align 2). The
// section is switched back to .data after each .bss entry.
void PrintGlobals(const MachineModule &module, std::ostream &os)
{
    const auto &globals = module.GetGlobals();
    if (globals.empty())
    {
        return;
    }

    os << " .data\n";
    for (const auto &global : globals)
    {
        const std::string asm_name = NormalizeAsmSymbol(global.name);
        const bool is_scalar = global.kind == MachineGlobal::Kind::I32Scalar;

        const auto emit_symbol_header = [&os, &asm_name]() {
            os << " .globl " << asm_name << "\n";
            os << " .p2align 2\n";
            os << asm_name << ":\n";
        };

        bool zero_init = is_scalar && global.init_value == 0;
        if (global.kind == MachineGlobal::Kind::I32Array)
        {
            zero_init = true;
            for (const auto &element : global.init_values)
            {
                if (element != 0)
                {
                    zero_init = false;
                    break;
                }
            }
        }

        if (zero_init)
        {
            os << " .bss\n";
            emit_symbol_header();
            if (is_scalar)
            {
                os << " .space 4\n";
            }
            else
            {
                os << " .space " << (global.array_size * 4) << "\n";
            }
            os << " .data\n";
            continue;
        }

        emit_symbol_header();
        if (is_scalar)
        {
            os << " .word " << global.init_value << "\n";
            continue;
        }

        for (const auto &element : global.init_values)
        {
            os << " .word " << element << "\n";
        }
        const size_t init_count = global.init_values.size();
        if (global.array_size > init_count)
        {
            os << " .zero " << ((global.array_size - init_count) * 4) << "\n";
        }
    }
    os << "\n";
}
} // namespace
// Emits one function.
// The header comes first: switch to .text, export the symbol, 4-byte
// alignment, entry label. Then each non-null basic block is printed as its
// label followed by its instructions, in block order.
void PrintAsm(const MachineFunction &function, std::ostream &os)
{
    const std::string symbol = NormalizeAsmSymbol(function.GetName());
    os << " .text\n"
       << " .globl " << symbol << "\n"
       << " .p2align 2\n"
       << symbol << ":\n";

    for (const auto &block_ptr : function.GetBlocks())
    {
        if (block_ptr)
        {
            const auto &block = *block_ptr;
            PrintBlockLabelRef(function, block.GetLabelId(), os);
            os << ":\n";
            for (const auto &instr : block.GetInstructions())
            {
                PrintInstr(function, instr, os);
            }
        }
    }
}
// Emits a whole module: globals first, then every non-null function.
// Consecutive printed functions are separated by a blank line.
void PrintAsm(const MachineModule &module, std::ostream &os)
{
    PrintGlobals(module, os);

    bool need_separator = false;
    for (const auto &function_ptr : module.GetFunctions())
    {
        if (!function_ptr)
        {
            continue;
        }
        if (need_separator)
        {
            os << "\n";
        }
        PrintAsm(*function_ptr, os);
        need_separator = true;
    }
}
} // namespace mir