You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nudt-compiler-cpp/src/mir/AsmPrinter.cpp

520 lines
18 KiB

#include "mir/MIR.h"
#include <ostream>
#include <stdexcept>
#include <iostream>
#include <vector>
#include <unordered_map>
#include "utils/Log.h"
// References the global-variable table (defined in Lowering.cpp).
extern std::vector<mir::GlobalVarInfo> g_globalVars;
namespace mir {
namespace {
// Resolves a FrameIndex operand to the frame slot it names in `function`.
// Throws std::runtime_error when the operand is of any other kind.
const FrameSlot& GetFrameSlot(const MachineFunction& function,
                              const Operand& operand) {
  const bool is_frame_index =
      operand.GetKind() == Operand::Kind::FrameIndex;
  if (is_frame_index) {
    return function.GetFrameSlot(operand.GetFrameIndex());
  }
  throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
}
// Prints an `lw` that reads the word at `offset(base)` into `dst`.
// Offsets within [-2048, 2047] are encoded directly in the instruction; any
// other offset is first added to `base` in t4, so t4 is overwritten on that
// path.
void EmitStackLoad(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::SP) {
  const bool needs_scratch = offset < -2048 || offset > 2047;
  if (needs_scratch) {
    os << " li t4, " << offset << "\n";
    os << " add t4, " << PhysRegName(base) << ", t4\n";
    os << " lw " << PhysRegName(dst) << ", 0(t4)\n";
    return;
  }
  os << " lw " << PhysRegName(dst) << ", " << offset << "("
     << PhysRegName(base) << ")\n";
}
// Prints an `sw` that writes `src` to the word at `offset(base)`.
// Offsets within [-2048, 2047] are encoded directly in the instruction; any
// other offset is first added to `base` in t4, so t4 is overwritten on that
// path (callers must not pass a value held in t4 as `src` in that case).
void EmitStackStore(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::SP) {
  const bool fits_immediate = -2048 <= offset && offset <= 2047;
  if (!fits_immediate) {
    os << " li t4, " << offset << "\n"
       << " add t4, " << PhysRegName(base) << ", t4\n"
       << " sw " << PhysRegName(src) << ", 0(t4)\n";
    return;
  }
  os << " sw " << PhysRegName(src) << ", " << offset << "("
     << PhysRegName(base) << ")\n";
}
// Floating-point counterpart of the integer stack load: prints an `flw` that
// reads from `offset(base)` into `dst`. When the offset lies outside
// [-2048, 2047] the address is computed in t4 first, overwriting t4.
void EmitStackLoadFloat(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::SP) {
  const auto dst_name = PhysRegName(dst);
  const auto base_name = PhysRegName(base);
  if (offset < -2048 || offset > 2047) {
    os << " li t4, " << offset << "\n";
    os << " add t4, " << base_name << ", t4\n";
    os << " flw " << dst_name << ", 0(t4)\n";
  } else {
    os << " flw " << dst_name << ", " << offset << "(" << base_name << ")\n";
  }
}
// Floating-point counterpart of the integer stack store: prints an `fsw` that
// writes `src` to `offset(base)`. When the offset lies outside [-2048, 2047]
// the address is computed in t4 first, overwriting t4.
void EmitStackStoreFloat(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::SP) {
  const auto src_name = PhysRegName(src);
  const auto base_name = PhysRegName(base);
  const bool in_range = !(offset < -2048) && !(offset > 2047);
  if (in_range) {
    os << " fsw " << src_name << ", " << offset << "(" << base_name << ")\n";
    return;
  }
  os << " li t4, " << offset << "\n";
  os << " add t4, " << base_name << ", t4\n";
  os << " fsw " << src_name << ", 0(t4)\n";
}
// 输出单个函数的汇编
void PrintAsmFunction(const MachineFunction& function, std::ostream& os) {
// 收集所有基本块名称
std::unordered_map<const MachineBasicBlock*, std::string> block_names;
for (const auto& block_ptr : function.GetBlocks()) {
block_names[block_ptr.get()] = block_ptr->GetName();
}
int total_frame_size = 16 + function.GetFrameSize();
bool prologue_done = false;
for (const auto& block_ptr : function.GetBlocks()) {
const auto& block = *block_ptr;
// 输出基本块标签(入口块不输出,因为函数名已经是标签)
if (block.GetName() != "entry") {
os << block.GetName() << ":\n";
}
for (const auto& inst : block.GetInstructions()) {
const auto& ops = inst.GetOperands();
// 在入口块的第一条指令前输出序言
if (!prologue_done && block.GetName() == "entry") {
// 处理大栈帧的情况
if (total_frame_size <= 2047) {
os << " addi sp, sp, -" << total_frame_size << "\n";
} else {
os << " li t4, -" << total_frame_size << "\n";
os << " add sp, sp, t4\n";
}
// 保存 ra 和 s0
int ra_offset = total_frame_size - 8;
int s0_offset = total_frame_size - 16;
if (ra_offset <= 2047) {
os << " sw ra, " << ra_offset << "(sp)\n";
} else {
os << " li t4, " << ra_offset << "\n";
os << " add t4, sp, t4\n";
os << " sw ra, 0(t4)\n";
}
if (s0_offset <= 2047) {
os << " sw s0, " << s0_offset << "(sp)\n";
} else {
os << " li t4, " << s0_offset << "\n";
os << " add t4, sp, t4\n";
os << " sw s0, 0(t4)\n";
}
prologue_done = true;
}
switch (inst.GetOpcode()) {
case Opcode::Prologue:
case Opcode::Epilogue:
break;
case Opcode::MovImm:
os << " li " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::Load: {
if (ops.size() == 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
} else {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackLoad(os, ops.at(0).GetReg(), slot.offset);
}
break;
}
case Opcode::Store: {
if (ops.size() == 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
} else {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackStore(os, ops.at(0).GetReg(), slot.offset);
}
break;
}
case Opcode::Add:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Sub:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Mul: {
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
} else {
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
}
break;
}
case Opcode::Div:
os << " div " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Rem:
os << " rem " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Slt:
os << " slt " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Slti:
os << " slti " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Sltu:
os << " sltu " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Sltiu: // <-- 添加这个
os << " sltiu " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Xori:
os << " xori " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::LoadGlobalAddr: {
std::string global_name = ops.at(1).GetGlobalName();
os << " la " << PhysRegName(ops.at(0).GetReg()) << ", " << global_name << "\n";
break;
}
case Opcode::LoadGlobal:
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::StoreGlobal: {
std::string global_name = ops.at(1).GetGlobalName();
os << " la t1, " << global_name << "\n";
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0(t1)\n";
break;
}
case Opcode::GEP:
break;
case Opcode::LoadIndirect:
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::Call: {
std::string func_name = "memset";
if (!ops.empty() && ops[0].GetKind() == Operand::Kind::Func) {
func_name = ops[0].GetFuncName();
}
os << " call " << func_name << "\n";
break;
}
case Opcode::LoadAddr: {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
if (slot.offset >= -2048 && slot.offset <= 2047) {
os << " addi " << PhysRegName(ops.at(0).GetReg()) << ", sp, " << slot.offset << "\n";
} else {
os << " li " << PhysRegName(ops.at(0).GetReg()) << ", " << slot.offset << "\n";
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", sp, "
<< PhysRegName(ops.at(0).GetReg()) << "\n";
}
break;
}
case Opcode::Slli:
os << " slli " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::StoreIndirect:
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::Ret:{
// 恢复 ra 和 s0
int ra_offset = total_frame_size - 8;
int s0_offset = total_frame_size - 16;
if (ra_offset <= 2047) {
os << " lw ra, " << ra_offset << "(sp)\n";
} else {
os << " li t3, " << ra_offset << "\n";
os << " add t3, sp, t3\n";
os << " lw ra, 0(t3)\n";
}
if (s0_offset <= 2047) {
os << " lw s0, " << s0_offset << "(sp)\n";
} else {
os << " li t3, " << s0_offset << "\n";
os << " add t3, sp, t3\n";
os << " lw s0, 0(t3)\n";
}
// 恢复 sp
if (total_frame_size <= 2047) {
os << " addi sp, sp, " << total_frame_size << "\n";
} else {
os << " li t3, " << total_frame_size << "\n";
os << " add sp, sp, t3\n";
}
os << " ret\n";
break;
}
case Opcode::Br: {
auto* target = reinterpret_cast<MachineBasicBlock*>(ops[0].GetImm64());
os << " j " << target->GetName() << "\n";
break;
}
case Opcode::CondBr: {
auto* true_target = reinterpret_cast<MachineBasicBlock*>(ops[1].GetImm64());
auto* false_target = reinterpret_cast<MachineBasicBlock*>(ops[2].GetImm64());
auto true_it = block_names.find(true_target);
auto false_it = block_names.find(false_target);
if (true_it == block_names.end() || false_it == block_names.end()) {
throw std::runtime_error(FormatError("mir", "CondBr: 找不到基本块名称"));
}
os << " bnez " << PhysRegName(ops[0].GetReg()) << ", "
<< true_it->second << "\n";
os << " j " << false_it->second << "\n";
break;
}
// 浮点运算
case Opcode::FAdd:
os << " fadd.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FSub:
os << " fsub.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FMul:
os << " fmul.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FDiv:
os << " fdiv.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FEq:
os << " feq.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FLt:
os << " flt.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FLe:
os << " fle.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FMov:
os << " fmv.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FMovWX:
os << " fmv.w.x " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FMovXW:
os << " fmv.x.w " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::SIToFP:
os << " fcvt.s.w " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvt.w.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::LoadFloat:
if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " flw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
} else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackLoadFloat(os, ops[0].GetReg(), slot.offset);
}
break;
case Opcode::StoreFloat:
if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " fsw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
} else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackStoreFloat(os, ops[0].GetReg(), slot.offset);
}
break;
default:
break;
}
}
}
os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
}
} // namespace
// 输出多个函数的汇编
void PrintAsm(const std::vector<std::unique_ptr<MachineFunction>>& functions, std::ostream& os) {
// ========== 输出全局变量 ==========
// 输出 .data 段(已初始化的全局变量)
bool hasData = false;
for (const auto& gv : g_globalVars) {
if (!gv.isConst) {
if (!hasData) {
os << ".data\n";
hasData = true;
}
os << " .global " << gv.name << "\n";
os << " .type " << gv.name << ", @object\n";
if (gv.isArray && gv.arraySize > 1) {
int totalSize = gv.arraySize * 4;
os << " .size " << gv.name << ", " << totalSize << "\n";
os << gv.name << ":\n";
if (!gv.arrayValues.empty()) {
for (int val : gv.arrayValues) {
os << " .word " << val << "\n";
}
} else {
for (int i = 0; i < gv.arraySize; i++) {
os << " .word 0\n";
}
}
} else {
os << " .size " << gv.name << ", 4\n";
os << gv.name << ":\n";
os << " .word " << gv.value << "\n";
}
}
}
// 输出 .rodata 段(只读常量)
bool hasRodata = false;
for (const auto& gv : g_globalVars) {
if (gv.isConst) {
if (!hasRodata) {
os << ".section .rodata\n";
hasRodata = true;
}
os << " .global " << gv.name << "\n";
os << " .type " << gv.name << ", @object\n";
if (gv.isArray && gv.arraySize > 1) {
int totalSize = gv.arraySize * 4;
os << " .size " << gv.name << ", " << totalSize << "\n";
os << gv.name << ":\n";
if (!gv.arrayValues.empty()) {
for (int val : gv.arrayValues) {
os << " .word " << val << "\n";
}
} else {
for (int i = 0; i < gv.arraySize; i++) {
os << " .word 0\n";
}
}
} else {
os << " .size " << gv.name << ", 4\n";
os << gv.name << ":\n";
os << " .word " << gv.value << "\n";
}
}
}
// ========== 输出代码段 ==========
os << ".text\n";
// 输出每个函数
for (const auto& func_ptr : functions) {
os << ".global " << func_ptr->GetName() << "\n";
os << ".type " << func_ptr->GetName() << ", @function\n";
os << func_ptr->GetName() << ":\n";
PrintAsmFunction(*func_ptr, os);
os << "\n"; // 函数之间加空行
}
}
} // namespace mir