You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nudt-compiler-cpp/src/mir/AsmPrinter.cpp

622 lines
20 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "mir/MIR.h"
#include <ostream>
#include <stdexcept>
#include <set>
#include "utils/Log.h"
// Debug tracing for the assembly printer.
// Uncomment the #define below to enable it: DEBUG_MSG(msg) then streams
// "[Asm Debug] <msg>" followed by std::endl to std::cerr. When DEBUG_Asm is
// not defined, DEBUG_MSG expands to nothing, so its argument is not evaluated.
//#define DEBUG_Asm
#ifdef DEBUG_Asm
#include <iostream>
#define DEBUG_MSG(msg) std::cerr << "[Asm Debug] " << msg << std::endl
#else
#define DEBUG_MSG(msg)
#endif
namespace mir {
namespace {
static void PrintLoadImm64(std::ostream& os, PhysReg reg, uint64_t imm);
// Resolves a FrameIndex operand to the FrameSlot it refers to in `function`.
// Throws std::runtime_error when `operand` is not a FrameIndex operand.
const FrameSlot& GetFrameSlot(const MachineFunction& function,
                              const Operand& operand) {
  const bool isFrameIndex = operand.GetKind() == Operand::Kind::FrameIndex;
  if (!isFrameIndex)
    throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
  const auto index = operand.GetFrameIndex();
  return function.GetFrameSlot(index);
}
// Emits one load/store `insn` (the callers pass "ldur"/"stur") of `reg` at
// the frame-pointer-relative address x29 + offset.
//
// Offsets in [-256, 255] fit the signed 9-bit immediate of the unscaled
// LDUR/STUR forms and are encoded directly. Larger offsets are first
// materialized into the scratch register x16 (with x17 as a second scratch
// when the magnitude does not fit a 12-bit ADD/SUB immediate), and the access
// then goes through [x16].
// NOTE(review): this clobbers x16 (and possibly x17); assumes `reg` is never
// one of them — confirm against the register allocator's reserved set.
void PrintStackAccess(std::ostream& os, const char* insn, PhysReg reg, int64_t offset) {
  // Frame offsets are typically negative (e.g. -8, -24, -40).
  if (offset >= -256 && offset <= 255) {
    os << " " << insn << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n";
    return;
  }
  const bool negative = offset < 0;
  // Computed in unsigned arithmetic so that INT64_MIN does not overflow.
  const uint64_t magnitude = negative ? 0 - static_cast<uint64_t>(offset)
                                      : static_cast<uint64_t>(offset);
  const char* op = negative ? "sub" : "add";
  if (magnitude <= 4095) {
    // x16 = x29 +/- imm12 in one instruction (no separate `mov x16, x29`).
    os << " " << op << " x16, x29, #" << magnitude << "\n";
  } else {
    PrintLoadImm64(os, PhysReg::X17, magnitude);
    os << " " << op << " x16, x29, x17\n";
  }
  os << " " << insn << " " << PhysRegName(reg) << ", [x16]\n";
}
// Prints a single operand in AArch64 assembly syntax: physical register
// name, "#imm", condition-code name, or label text.
// Throws std::runtime_error for operands that cannot be printed standalone:
//  - virtual registers (register allocation has not completed);
//  - frame indices (a frame index is a slot id, not a byte offset; frame
//    accesses must go through GetFrameSlot, as PrintInstruction does).
void PrintOperand(std::ostream& os, const Operand& op) {
  switch (op.GetKind()) {
    case Operand::Kind::Reg:
      os << PhysRegName(op.GetReg());
      break;
    case Operand::Kind::VReg:
      throw std::runtime_error(
          FormatError("asm", "寄存器分配未完成: 存在虚拟寄存器 #" +
                                 std::to_string(op.GetVReg())));
    case Operand::Kind::Imm:
      os << "#" << op.GetImm();
      break;
    case Operand::Kind::FrameIndex:
      // Previously printed "[sp, #<index>]", which treats the slot id as an
      // sp-relative byte offset and silently addresses the wrong memory.
      throw std::runtime_error(
          FormatError("asm", "FrameIndex 操作数无法直接打印: #" +
                                 std::to_string(op.GetFrameIndex())));
    case Operand::Kind::Cond:
      os << CondCodeName(op.GetCondCode());
      break;
    case Operand::Kind::Label:
      DEBUG_MSG("label is" << op.GetLabel());
      os << op.GetLabel();
      break;
  }
}
// Returns true if the magnitude of `imm` can be encoded as the immediate of
// an AArch64 ADD/SUB instruction: an unsigned 12-bit value (0..4095),
// optionally shifted left by 12 (a multiple of 4096 up to 4095 * 4096 =
// 16773120). The sign is ignored; only |imm| is checked.
static bool IsLegalAddSubImm(int64_t imm) {
  const int64_t magnitude = imm < 0 ? -imm : imm;
  const bool fitsUnshifted = magnitude <= 4095;
  const bool fitsShiftedBy12 =
      (magnitude & 0xFFF) == 0 && magnitude <= 4095 * 4096;
  return fitsUnshifted || fitsShiftedBy12;
}
// ---- Register width normalization ----
// Register-bank predicates. Each checks membership in a contiguous
// enumerator range of PhysReg (W0..W30, X0..X30, S0..S31), so they rely on
// each bank being declared contiguously in the PhysReg enum.
static bool IsWReg(PhysReg reg) {
  const bool aboveFirst = !(reg < PhysReg::W0);
  const bool belowLast = !(PhysReg::W30 < reg);
  return aboveFirst && belowLast;
}
static bool IsXReg(PhysReg reg) {
  const bool aboveFirst = !(reg < PhysReg::X0);
  const bool belowLast = !(PhysReg::X30 < reg);
  return aboveFirst && belowLast;
}
static bool IsSReg(PhysReg reg) {
  const bool aboveFirst = !(reg < PhysReg::S0);
  const bool belowLast = !(PhysReg::S31 < reg);
  return aboveFirst && belowLast;
}
// Returns the 32-bit W alias of an X register (Xn -> Wn). Any register that
// is not an X register (Wn, Sn, ...) is returned unchanged.
static PhysReg ToW(PhysReg reg) {
  if (!IsXReg(reg)) return reg;
  const int bankIndex = static_cast<int>(reg) - static_cast<int>(PhysReg::X0);
  return static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + bankIndex);
}
// Returns the 64-bit X alias of a W register (Wn -> Xn). Any register that
// is not a W register (Xn, Sn, ...) is returned unchanged.
static PhysReg ToX(PhysReg reg) {
  if (!IsWReg(reg)) return reg;
  const int bankIndex = static_cast<int>(reg) - static_cast<int>(PhysReg::W0);
  return static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + bankIndex);
}
// 检查一组操作数是否全是同一宽度W/X/S
static bool AllSameRegWidth(const std::vector<Operand>& ops) {
int kind = -1;
for (const auto& op : ops) {
if (op.GetKind() != Operand::Kind::Reg) continue;
PhysReg r = op.GetReg();
if (IsWReg(r)) { if (kind == -1) kind = 0; else if (kind != 0) return false; }
else if (IsXReg(r)) { if (kind == -1) kind = 1; else if (kind != 1) return false; }
else if (IsSReg(r)) { if (kind == -1) kind = 2; else if (kind != 2) return false; }
}
return true;
}
// 根据目的地宽度规范化所有寄存器操作数
static void NormalizeRegOps(std::vector<Operand>& ops, PhysReg dst) {
PhysReg base = dst;
bool wantW = IsWReg(base);
bool wantX = IsXReg(base);
for (auto& op : ops) {
if (op.GetKind() != Operand::Kind::Reg) continue;
if (wantW) op = Operand::Reg(ToW(op.GetReg()));
else if (wantX) op = Operand::Reg(ToX(op.GetReg()));
}
}
// Materializes the 64-bit constant `imm` into `reg` with a MOVZ/MOVK sequence.
// `reg` is expected to be an X register (shifts up to lsl #48 are emitted);
// the callers in this file pass x16/x17.
//
// MOVZ is placed on the lowest non-zero 16-bit chunk (chunk 0 when imm == 0),
// which also zeroes every other chunk. Each remaining non-zero chunk is then
// inserted with MOVK. The result is one instruction per non-zero chunk
// (minimum one). The previous version always started with `movz #chunk0`,
// even when that chunk was zero.
static void PrintLoadImm64(std::ostream& os, PhysReg reg, uint64_t imm) {
  uint16_t chunks[4];
  for (int i = 0; i < 4; ++i) {
    chunks[i] = static_cast<uint16_t>((imm >> (16 * i)) & 0xFFFF);
  }
  int first = 0;
  if (imm != 0) {
    // Terminates: at least one chunk is non-zero.
    while (chunks[first] == 0) ++first;
  }
  os << " movz " << PhysRegName(reg) << ", #" << chunks[first];
  if (first != 0) {
    os << ", lsl #" << 16 * first;
  }
  os << "\n";
  for (int i = first + 1; i < 4; ++i) {
    if (chunks[i] == 0) continue;
    os << " movk " << PhysRegName(reg) << ", #" << chunks[i] << ", lsl #"
       << 16 * i << "\n";
  }
}
// 打印单条指令
void PrintInstruction(std::ostream& os, const MachineInstr& instr,
const MachineFunction& function) {
const auto& ops = instr.GetOperands();
switch (instr.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
int64_t size = function.GetFrameSize();
if (IsLegalAddSubImm(size)) {
os << " sub sp, sp, #" << size << "\n";
} else {
PrintLoadImm64(os, PhysReg::X16, size);
os << " sub sp, sp, x16\n";
}
}
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
int64_t size = function.GetFrameSize();
if (IsLegalAddSubImm(size)) {
os << " add sp, sp, #" << size << "\n";
} else {
PrintLoadImm64(os, PhysReg::X16, size);
os << " add sp, sp, x16\n";
}
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::MovReg:{
PhysReg dst = ops.at(0).GetReg();
PhysReg src = ops.at(1).GetReg();
if (IsSReg(dst) || IsSReg(src)) {
// 涉及 S 寄存器的 move使用 fmov
if (!IsSReg(dst)) dst = ToW(dst); // 确保是 W 寄存器
if (!IsSReg(src)) src = ToW(src);
os << " fmov " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n";
} else {
// GPR move规范化宽度
if (IsWReg(dst) && IsXReg(src)) {
src = ToW(src);
} else if (IsXReg(dst) && IsWReg(src)) {
src = ToX(src);
}
os << " mov " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n";
}
break;
}
case Opcode::StoreStack: {
if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
} else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
// 间接存储:基址必须是 X 寄存器
PhysReg base = ToX(ops.at(1).GetReg());
os << " str " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(base) << "]\n";
} else {
throw std::runtime_error("StoreStack: 无效的操作数类型");
}
break;
}
case Opcode::LoadStack: {
if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
} else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
// 间接加载:基址必须是 X 寄存器
PhysReg base = ToX(ops.at(1).GetReg());
os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(base) << "]\n";
} else {
throw std::runtime_error("LoadStack: 无效的操作数类型");
}
break;
}
case Opcode::StoreStackPair:
// stp x29, x30, [sp, #-16]!
os << " stp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", [sp";
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
int offset = ops.at(2).GetImm();
os << ", #" << offset;
}
os << "]!\n"; // 注意添加 ! 表示 pre-index
break;
case Opcode::LoadStackPair:
// ldp x29, x30, [sp], #16
os << " ldp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", [sp]";
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
int offset = ops.at(2).GetImm();
os << ", #" << offset;
}
os << "\n";
break;
case Opcode::AddRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " add " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::AddRI: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " add " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", #"
<< nops[2].GetImm() << "\n";
break;
}
case Opcode::SubRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " sub " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::SubRI: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " sub " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", #"
<< nops[2].GetImm() << "\n";
break;
}
case Opcode::MulRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " mul " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::SDivRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " sdiv " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::UDivRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " udiv " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::FAddRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " fadd " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::FSubRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " fsub " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::FMulRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " fmul " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::FDivRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " fdiv " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::CmpRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " cmp " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << "\n";
break;
}
case Opcode::CmpRI: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " cmp " << PhysRegName(nops[0].GetReg()) << ", #"
<< nops[1].GetImm() << "\n";
break;
}
case Opcode::FCmpRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " fcmp " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << "\n";
break;
}
case Opcode::SIToFP: {
PhysReg dst = ops.at(0).GetReg();
PhysReg src = ops.at(1).GetReg();
if (!IsWReg(src)) src = ToW(src);
os << " scvtf " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n";
break;
}
case Opcode::FPToSI: {
PhysReg dst = ops.at(0).GetReg();
PhysReg src = ops.at(1).GetReg();
if (!IsWReg(dst)) dst = ToW(dst);
os << " fcvtzs " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n";
break;
}
case Opcode::ZExt:
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #1\n";
break;
case Opcode::AndRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " and " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::OrRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " orr " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::EorRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " eor " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::LslRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " lsl " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::LsrRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " lsr " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::AsrRR: {
std::vector<Operand> nops = ops;
NormalizeRegOps(nops, nops[0].GetReg());
os << " asr " << PhysRegName(nops[0].GetReg()) << ", "
<< PhysRegName(nops[1].GetReg()) << ", "
<< PhysRegName(nops[2].GetReg()) << "\n";
break;
}
case Opcode::B:
os << " b ";
PrintOperand(os, ops.at(0));
os << "\n";
break;
case Opcode::BCond:
os << " b.";
PrintOperand(os, ops.at(0));
os << " ";
PrintOperand(os, ops.at(1));
os << "\n";
break;
case Opcode::Call:
os << " bl ";
PrintOperand(os, ops.at(0));
os << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
case Opcode::Nop:
os << " nop\n";
break;
case Opcode::Label:
os << ops.at(0).GetLabel() << ":\n";
break;
case Opcode::Movk:
os << " movk " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << ", lsl #" << ops.at(2).GetImm() << "\n";
break;
case Opcode::LoadStackAddr: {
const FrameSlot& slot = GetFrameSlot(function, ops.at(1));
int64_t offset = slot.offset;
PhysReg dst = ToX(ops.at(0).GetReg()); // 地址必须是 X 寄存器
auto tryEmitSimple = [&]() -> bool {
if (offset >= 0 && offset <= 4095) {
os << " add " << PhysRegName(dst) << ", x29, #" << offset << "\n";
return true;
} else if (offset < 0 && offset >= -4095) {
os << " sub " << PhysRegName(dst) << ", x29, #" << (-offset) << "\n";
return true;
}
return false;
};
if (tryEmitSimple()) break;
// 复杂偏移
uint64_t absOffset = (offset >= 0) ? offset : -offset;
PrintLoadImm64(os, PhysReg::X16, absOffset);
if (offset >= 0) {
os << " add " << PhysRegName(dst) << ", x29, x16\n";
} else {
os << " sub " << PhysRegName(dst) << ", x29, x16\n";
}
break;
}
case Opcode::Adrp: {
// adrp Xd, label
os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetLabel() << "\n";
break;
}
case Opcode::AddLabel: {
// add Xd, Xn, :lo12:label
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", :lo12:"
<< ops.at(2).GetLabel() << "\n";
break;
}
case Opcode::Sxtw: {
PhysReg dst = ops.at(0).GetReg();
PhysReg src = ops.at(1).GetReg();
// sxtw 要求 X 目标W 源
if (!IsXReg(dst)) dst = ToX(dst);
if (!IsWReg(src)) src = ToW(src);
os << " sxtw " << PhysRegName(dst) << ", " << PhysRegName(src) << "\n";
break;
}
default:
os << " // unknown instruction\n";
break;
}
}
// Prints one machine function (single-function version): .text / .global /
// .type directives, the function label, every basic block's instructions, and
// a trailing .size directive.
// The first basic block gets no label of its own. Its code directly follows
// the function label.
// NOTE(review): a branch that targets the entry block by name would therefore
// not resolve; presumably the entry block never has predecessors.
void PrintAsm(const MachineFunction& function, std::ostream& os) {
  const auto& name = function.GetName();
  os << ".text\n";
  os << ".global " << name << "\n";
  os << ".type " << name << ", %function\n";
  os << name << ":\n";
  bool firstBlock = true;
  for (const auto& bb : function.GetBasicBlocks()) {
    DEBUG_MSG("block");
    if (!firstBlock) {
      os << bb->GetName() << ":\n";
    }
    firstBlock = false;
    for (const auto& inst : bb->GetInstructions()) {
      DEBUG_MSG("inst");
      PrintInstruction(os, inst, function);
    }
  }
  os << ".size " << name << ", .-" << name << "\n";
}
} // namespace
// Prints a whole machine module (module version):
//  1. the ".arch armv8-a" header;
//  2. if there are globals, a .data section with one symbol per global
//     (zero-filled, or a .word/.quad for 4/8-byte scalar initializers);
//  3. every function, except those whose names are in `externalFuncs`
//     (SysY runtime / libc stubs that are skipped here), each followed by a
//     blank line.
void PrintAsm(const MachineModule& module, std::ostream& os) {
  // Library functions that must not be emitted from this module.
  static const std::set<std::string> externalFuncs = {
      "getint", "getch", "getarray", "putint", "putch", "putarray", "puts",
      "_sysy_starttime", "_sysy_stoptime", "starttime", "stoptime",
      "getfloat", "putfloat", "getfarray", "putfarray", "memset",
      "sysy_alloc_i32", "sysy_alloc_f32", "sysy_free_i32", "sysy_free_f32",
      "sysy_zero_i32", "sysy_zero_f32"
  };

  os << ".arch armv8-a\n";

  const auto& globals = module.GetGlobals();
  if (!globals.empty()) {
    os << "\n.data\n";
  }
  for (const auto& g : globals) {
    // NOTE(review): on AArch64 GNU as, `.align N` aligns to 2^N bytes. If
    // g.alignment is a byte count, `.balign` may be what is intended — confirm.
    os << ".global " << g.name << "\n"
       << ".type " << g.name << ", %object\n"
       << ".align " << g.alignment << "\n"
       << g.name << ":\n";
    if (g.is_zero_init) {
      os << " .zero " << g.size << "\n";
    } else if (!g.has_init_data) {
      // Has an initializer that could not be extracted (e.g. array/struct).
      os << " .zero " << g.size << " // unhandled initializer\n";
    } else if (g.size == 4) {
      os << " .word " << static_cast<uint32_t>(g.init_data) << "\n";
    } else if (g.size == 8) {
      os << " .quad " << g.init_data << "\n";
    } else {
      // Unsupported scalar size: fall back to zero-fill.
      os << " .zero " << g.size << " // unhandled init size\n";
    }
    os << ".size " << g.name << ", " << g.size << "\n\n";
  }

  DEBUG_MSG("module");
  for (const auto& func : module.GetFunctions()) {
    const bool isLibraryStub = externalFuncs.find(func->GetName()) != externalFuncs.end();
    if (isLibraryStub) continue;
    DEBUG_MSG("func");
    PrintAsm(*func, os);
    os << "\n";
  }
}
} // namespace mir