forked from NUDT-compiler/nudt-compiler-cpp
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
328 lines
11 KiB
328 lines
11 KiB
#include "mir/MIR.h"
|
|
|
|
#include <ostream>
|
|
#include <stdexcept>
|
|
|
|
#include "utils/Log.h"
|
|
|
|
//#define DEBUG_Asm
|
|
|
|
#ifdef DEBUG_Asm
|
|
#include <iostream>
|
|
#define DEBUG_MSG(msg) std::cerr << "[Asm Debug] " << msg << std::endl
|
|
#else
|
|
#define DEBUG_MSG(msg)
|
|
#endif
|
|
|
|
namespace mir {
|
|
namespace {
|
|
|
|
// Resolves a FrameIndex operand to the frame slot it names in `function`.
// Throws std::runtime_error (message built with FormatError) when `operand`
// is of any other kind.
const FrameSlot& GetFrameSlot(const MachineFunction& function,
                              const Operand& operand) {
  const bool isFrameIndex =
      operand.GetKind() == Operand::Kind::FrameIndex;
  if (!isFrameIndex) {
    throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
  }
  const auto slotIndex = operand.GetFrameIndex();
  return function.GetFrameSlot(slotIndex);
}
|
|
|
|
// Emits one sp-relative frame-slot access:
//   "<mnemonic> <reg>, [sp, #<offset>]"
// Slots are addressed from sp; an earlier x29-relative form is no longer used.
// NOTE(review): the callers in this file pass "stur"/"ldur", whose unscaled
// immediate only spans -256..255 — confirm slot offsets stay in that range.
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
                      int offset) {
  os << ' ' << mnemonic << ' ' << PhysRegName(reg) << ", [sp, #" << offset
     << ']' << '\n';
}
|
|
|
|
// 打印单个操作数
|
|
void PrintOperand(std::ostream& os, const Operand& op) {
|
|
switch (op.GetKind()) {
|
|
case Operand::Kind::Reg:
|
|
os << PhysRegName(op.GetReg());
|
|
break;
|
|
case Operand::Kind::Imm:
|
|
os << "#" << op.GetImm();
|
|
break;
|
|
case Operand::Kind::FrameIndex:
|
|
os << "[sp, #" << op.GetFrameIndex() << "]";
|
|
break;
|
|
case Operand::Kind::Cond:
|
|
os << CondCodeName(op.GetCondCode());
|
|
break;
|
|
case Operand::Kind::Label:
|
|
DEBUG_MSG("label is" << op.GetLabel());
|
|
os << op.GetLabel();
|
|
break;
|
|
}
|
|
}
|
|
|
|
// 打印单条指令
|
|
void PrintInstruction(std::ostream& os, const MachineInstr& instr,
|
|
const MachineFunction& function) {
|
|
const auto& ops = instr.GetOperands();
|
|
|
|
switch (instr.GetOpcode()) {
|
|
case Opcode::Prologue:
|
|
os << " stp x29, x30, [sp, #-16]!\n";
|
|
os << " mov x29, sp\n";
|
|
if (function.GetFrameSize() > 0) {
|
|
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
|
|
}
|
|
break;
|
|
case Opcode::Epilogue:
|
|
if (function.GetFrameSize() > 0) {
|
|
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
|
|
}
|
|
os << " ldp x29, x30, [sp], #16\n";
|
|
break;
|
|
case Opcode::MovImm:
|
|
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
|
|
<< ops.at(1).GetImm() << "\n";
|
|
break;
|
|
case Opcode::MovReg:
|
|
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::StoreStack: {
|
|
// 检查第二个操作数的类型
|
|
if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) {
|
|
// 存储到栈槽
|
|
const auto& slot = GetFrameSlot(function, ops.at(1));
|
|
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
|
|
} else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
|
|
// 间接存储:存储到寄存器指向的地址
|
|
// STR W9, [X8]
|
|
os << " str " << PhysRegName(ops.at(0).GetReg()) << ", ["
|
|
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
|
|
} else {
|
|
throw std::runtime_error("StoreStack: 无效的操作数类型");
|
|
}
|
|
break;
|
|
}
|
|
case Opcode::LoadStack: {
|
|
// 检查第二个操作数的类型
|
|
if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) {
|
|
// 从栈槽加载
|
|
const auto& slot = GetFrameSlot(function, ops.at(1));
|
|
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
|
|
} else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
|
|
// 间接加载:从寄存器指向的地址加载
|
|
// LDR W9, [X8]
|
|
os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", ["
|
|
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
|
|
} else {
|
|
throw std::runtime_error("LoadStack: 无效的操作数类型");
|
|
}
|
|
break;
|
|
}
|
|
case Opcode::StoreStackPair:
|
|
// stp x29, x30, [sp, #-16]!
|
|
os << " stp " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", [sp";
|
|
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
|
|
int offset = ops.at(2).GetImm();
|
|
os << ", #" << offset;
|
|
}
|
|
os << "]!\n"; // 注意添加 ! 表示 pre-index
|
|
break;
|
|
case Opcode::LoadStackPair:
|
|
// ldp x29, x30, [sp], #16
|
|
os << " ldp " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", [sp]";
|
|
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
|
|
int offset = ops.at(2).GetImm();
|
|
os << ", #" << offset;
|
|
}
|
|
os << "\n";
|
|
break;
|
|
case Opcode::AddRR:
|
|
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::AddRI:
|
|
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", #"
|
|
<< ops.at(2).GetImm() << "\n";
|
|
break;
|
|
case Opcode::SubRR:
|
|
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::SubRI:
|
|
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", #"
|
|
<< ops.at(2).GetImm() << "\n";
|
|
break;
|
|
case Opcode::MulRR:
|
|
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::SDivRR:
|
|
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::UDivRR:
|
|
os << " udiv " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::FAddRR:
|
|
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::FSubRR:
|
|
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::FMulRR:
|
|
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::FDivRR:
|
|
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::CmpRR:
|
|
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::CmpRI:
|
|
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", #"
|
|
<< ops.at(1).GetImm() << "\n";
|
|
break;
|
|
case Opcode::FCmpRR:
|
|
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::SIToFP:
|
|
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::FPToSI:
|
|
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::ZExt:
|
|
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", #1\n";
|
|
break;
|
|
case Opcode::AndRR:
|
|
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::OrRR:
|
|
os << " orr " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::EorRR:
|
|
os << " eor " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::LslRR:
|
|
os << " lsl " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::LsrRR:
|
|
os << " lsr " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::AsrRR:
|
|
os << " asr " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::B:
|
|
os << " b ";
|
|
PrintOperand(os, ops.at(0));
|
|
os << "\n";
|
|
break;
|
|
case Opcode::BCond:
|
|
os << " b.";
|
|
PrintOperand(os, ops.at(0));
|
|
os << " ";
|
|
PrintOperand(os, ops.at(1));
|
|
os << "\n";
|
|
break;
|
|
case Opcode::Call:
|
|
os << " bl ";
|
|
PrintOperand(os, ops.at(0));
|
|
os << "\n";
|
|
break;
|
|
case Opcode::Ret:
|
|
os << " ret\n";
|
|
break;
|
|
case Opcode::Nop:
|
|
os << " nop\n";
|
|
break;
|
|
case Opcode::Label:
|
|
os << ops.at(0).GetLabel() << ":\n";
|
|
break;
|
|
default:
|
|
os << " // unknown instruction\n";
|
|
break;
|
|
}
|
|
}
|
|
|
|
// 打印单个函数(单函数版本)
|
|
void PrintAsm(const MachineFunction& function, std::ostream& os) {
|
|
// 输出函数标签
|
|
os << ".text\n";
|
|
os << ".global " << function.GetName() << "\n";
|
|
os << ".type " << function.GetName() << ", %function\n";
|
|
os << function.GetName() << ":\n";
|
|
|
|
// 计算栈帧大小
|
|
int frameSize = function.GetFrameSize();
|
|
|
|
// 输出每个基本块
|
|
const auto& blocks = function.GetBasicBlocks();
|
|
bool firstBlock = true;
|
|
|
|
for (const auto& bb : blocks) {
|
|
DEBUG_MSG("block");
|
|
// 输出基本块标签(非第一个基本块)
|
|
if (!firstBlock) {
|
|
os << bb->GetName() << ":\n";
|
|
}
|
|
firstBlock = false;
|
|
|
|
// 输出基本块中的指令
|
|
for (const auto& inst : bb->GetInstructions()) {
|
|
DEBUG_MSG("inst");
|
|
PrintInstruction(os, inst, function);
|
|
}
|
|
}
|
|
|
|
os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// 打印模块(模块版本)
|
|
void PrintAsm(const MachineModule& module, std::ostream& os) {
|
|
// 输出文件头
|
|
os << ".arch armv8-a\n";
|
|
os << ".text\n";
|
|
|
|
DEBUG_MSG("module");
|
|
// 遍历所有函数,输出汇编
|
|
for (const auto& func : module.GetFunctions()) {
|
|
DEBUG_MSG("func");
|
|
PrintAsm(*func, os);
|
|
os << "\n";
|
|
}
|
|
}
|
|
|
|
} // namespace mir
|