You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nudt-compiler-cpp/src/mir/Lowering.cpp

697 lines
31 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "mir/MIR.h"
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include "ir/IR.h"
#include "utils/Log.h"
namespace mir {
namespace {
// Maps an IR value to the index of the frame slot that holds it.
using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
// Result of lowering a GEP: (base_slot, byte_offset, global_symbol)
// - base_slot >= 0:   local array; base_slot is the frame slot index
// - base_slot == -1:  global array; global_symbol is the global variable name
// - byte_offset >= 0: constant index, already scaled to a byte offset
// - byte_offset < 0:  variable index, encoded as -1 - index_slot, where
//                     index_slot is the frame slot holding the index value
struct GepInfo {
int base_slot;
int byte_offset;
std::string global_symbol;
};
// Maps a GEP instruction (as an IR value) to its recorded addressing info.
using GepMap = std::unordered_map<const ir::Value*, GepInfo>;
// Materializes `value` in the physical register `target` by appending one
// instruction to `block`:
//   - integer constant  -> MovImm with the constant's value
//   - global variable   -> LoadGlobal from the variable's symbol
//   - anything else     -> LoadStack from the frame slot recorded in `slots`
// Throws std::runtime_error when `value` is neither a constant nor a global
// and has no entry in `slots`.
void EmitValueToReg(const ir::Value* value, PhysReg target,
                    const ValueSlotMap& slots, MachineBasicBlock& block) {
  if (const auto* imm = dynamic_cast<const ir::ConstantInt*>(value)) {
    block.Append(Opcode::MovImm,
                 {Operand::Reg(target), Operand::Imm(imm->GetValue())});
  } else if (const auto* global =
                 dynamic_cast<const ir::GlobalVariable*>(value)) {
    block.Append(Opcode::LoadGlobal,
                 {Operand::Reg(target), Operand::Symbol(global->GetName())});
  } else {
    const auto slot_entry = slots.find(value);
    if (slot_entry != slots.end()) {
      block.Append(Opcode::LoadStack, {Operand::Reg(target),
                                       Operand::FrameIndex(slot_entry->second)});
    } else {
      throw std::runtime_error(
          FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
    }
  }
}
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
MachineBasicBlock& block, ValueSlotMap& slots,
GepMap& geps) {
switch (inst.GetOpcode()) {
case ir::Opcode::Alloca: {
auto& alloca = static_cast<const ir::AllocaInst&>(inst);
int size = alloca.GetCount() * 4; // count * sizeof(i32)
slots.emplace(&inst, function.CreateFrameIndex(size));
return;
}
case ir::Opcode::Gep: {
auto& gep = static_cast<const ir::GepInst&>(inst);
auto* base = gep.GetBase();
auto* index = gep.GetIndex();
// 为 GEP 结果分配一个栈槽(用于存储指针值)
int ptr_slot = function.CreateFrameIndex(8); // 64-bit pointer
// 检查 base 是什么类型:全局数组、本地数组、还是指针参数
if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(base)) {
// 全局数组
if (auto* const_index = dynamic_cast<const ir::ConstantInt*>(index)) {
// 常量索引:计算地址并存储
int byte_offset = const_index->GetValue() * 4;
geps.emplace(&inst, GepInfo{-1, byte_offset, gv->GetName()});
// 计算地址x9 = &global_array + offset
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::X9), Operand::Symbol(gv->GetName())});
if (byte_offset > 0) {
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(byte_offset)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
}
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)});
} else {
// 变量索引
int index_slot = function.CreateFrameIndex();
EmitValueToReg(index, PhysReg::W8, slots, block);
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(index_slot)});
geps.emplace(&inst, GepInfo{-1, -1 - index_slot, gv->GetName()});
// 计算地址x9 = &global_array + (index * 4)
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)});
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W8), Operand::Imm(2)});
block.Append(Opcode::LslRR, {Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W8)});
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::X9), Operand::Symbol(gv->GetName())});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)});
}
slots.emplace(&inst, ptr_slot);
return;
}
// 检查 base 是否在 slots 中(本地变量或参数)
auto base_it = slots.find(base);
if (base_it == slots.end()) {
throw std::runtime_error(
FormatError("mir", "GEP base 必须是 alloca、指针参数或全局变量"));
}
// 检查 base 是否是指针参数:如果是 Argument 且类型是指针
if (dynamic_cast<const ir::Argument*>(base) && base->GetType()->IsPtrInt32()) {
// 指针参数:从栈加载指针值,然后加上索引
if (auto* const_index = dynamic_cast<const ir::ConstantInt*>(index)) {
// 常量索引
int byte_offset = const_index->GetValue() * 4;
// 注意:这里不记录到 geps因为我们已经计算出最终地址了
// x9 = 从栈加载指针
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)});
if (byte_offset > 0) {
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(byte_offset)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
}
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)});
} else {
// 变量索引
int index_slot = function.CreateFrameIndex();
EmitValueToReg(index, PhysReg::W8, slots, block);
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(index_slot)});
// x9 = 从栈加载指针
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)});
// w10 = index * 4
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)});
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W8), Operand::Imm(2)});
block.Append(Opcode::LslRR, {Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W8)});
// x9 = x9 + w10
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)});
}
slots.emplace(&inst, ptr_slot);
return;
}
// 本地数组alloca 的结果)
// 检查是否是常量索引
if (auto* const_index = dynamic_cast<const ir::ConstantInt*>(index)) {
int byte_offset = const_index->GetValue() * 4;
geps.emplace(&inst, GepInfo{base_it->second, byte_offset, ""});
// 计算地址x9 = &array_base + byte_offset
block.Append(Opcode::LoadStackAddr,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)});
if (byte_offset > 0) {
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(byte_offset)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
}
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)});
} else {
// 变量索引
int index_slot = function.CreateFrameIndex();
EmitValueToReg(index, PhysReg::W8, slots, block);
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(index_slot)});
geps.emplace(&inst, GepInfo{base_it->second, -1 - index_slot, ""});
// 计算地址x9 = x29 + base_offset + (index * 4)
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)});
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W8), Operand::Imm(2)});
block.Append(Opcode::LslRR, {Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W8)});
block.Append(Opcode::LoadStackAddr,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)});
}
slots.emplace(&inst, ptr_slot);
return;
}
case ir::Opcode::Store: {
auto& store = static_cast<const ir::StoreInst&>(inst);
auto* ptr = store.GetPtr();
// 检查是否是 GEP 结果(数组元素)
auto gep_it = geps.find(ptr);
if (gep_it != geps.end()) {
const auto& gep_info = gep_it->second;
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
if (gep_info.base_slot == -1) {
// 全局数组
if (gep_info.byte_offset >= 0) {
// 常量索引global_array[const_idx]
// adrp x9, symbol; add x9, x9, :lo12:symbol; add x9, x9, #offset; str w8, [x9]
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)});
if (gep_info.byte_offset > 0) {
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(gep_info.byte_offset)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
}
block.Append(Opcode::StoreIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
} else {
// 变量索引global_array[var_idx]
int index_slot = -1 - gep_info.byte_offset;
// 1. 加载 index
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)});
// 2. index * 4
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(2)});
block.Append(Opcode::LslRR, {Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W9)});
// 3. 获取全局数组基址
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)});
// 4. x9 + offset
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
// 5. 存储
block.Append(Opcode::StoreIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
}
} else if (gep_info.byte_offset >= 0) {
// 本地数组,常量索引
block.Append(Opcode::StoreStackOffset,
{Operand::Reg(PhysReg::W8),
Operand::FrameIndex(gep_info.base_slot),
Operand::Imm(gep_info.byte_offset)});
} else {
// 本地数组,变量索引
int index_slot = -1 - gep_info.byte_offset;
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)});
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(2)});
block.Append(Opcode::LslRR, {Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::LoadStackAddr,
{Operand::Reg(PhysReg::X9),
Operand::FrameIndex(gep_info.base_slot)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
block.Append(Opcode::StoreIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
}
return;
}
// 检查是否是全局变量
if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(ptr)) {
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
block.Append(Opcode::StoreGlobal,
{Operand::Reg(PhysReg::W8), Operand::Symbol(gv->GetName())});
return;
}
// 栈变量或GEP结果
auto dst = slots.find(ptr);
if (dst == slots.end()) {
throw std::runtime_error(
FormatError("mir", "暂不支持对非栈/全局变量地址进行写入"));
}
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
// 检查是否是GEP结果如果ptr的类型是指针且slot大小是8字节说明存储的是地址
const auto& dst_slot = function.GetFrameSlot(dst->second);
if (ptr->GetType()->IsPtrInt32() && dst_slot.size == 8) {
// GEP结果先加载指针地址再通过指针存储值
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(dst->second)});
block.Append(Opcode::StoreIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
} else {
// 普通栈变量:直接存储
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)});
}
return;
}
case ir::Opcode::Load: {
auto& load = static_cast<const ir::LoadInst&>(inst);
auto* ptr = load.GetPtr();
// 检查是否是 GEP 结果(数组元素)
auto gep_it = geps.find(ptr);
if (gep_it != geps.end()) {
const auto& gep_info = gep_it->second;
int dst_slot = function.CreateFrameIndex();
if (gep_info.base_slot == -1) {
// 全局数组
if (gep_info.byte_offset >= 0) {
// 常量索引
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)});
if (gep_info.byte_offset > 0) {
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(gep_info.byte_offset)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
}
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
} else {
// 变量索引
int index_slot = -1 - gep_info.byte_offset;
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)});
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(2)});
block.Append(Opcode::LslRR, {Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
}
} else if (gep_info.byte_offset >= 0) {
// 本地数组,常量索引
block.Append(Opcode::LoadStackOffset,
{Operand::Reg(PhysReg::W8),
Operand::FrameIndex(gep_info.base_slot),
Operand::Imm(gep_info.byte_offset)});
} else {
// 本地数组,变量索引
int index_slot = -1 - gep_info.byte_offset;
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)});
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(2)});
block.Append(Opcode::LslRR, {Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::LoadStackAddr,
{Operand::Reg(PhysReg::X9),
Operand::FrameIndex(gep_info.base_slot)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
}
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
// 检查是否是全局变量
if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(ptr)) {
int dst_slot = function.CreateFrameIndex();
block.Append(Opcode::LoadGlobal,
{Operand::Reg(PhysReg::W8), Operand::Symbol(gv->GetName())});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
// 栈变量或GEP结果
auto src = slots.find(ptr);
if (src == slots.end()) {
throw std::runtime_error(
FormatError("mir", "暂不支持对非栈/全局变量地址进行读取"));
}
int dst_slot = function.CreateFrameIndex();
// 检查是否是GEP结果如果ptr的类型是指针且slot大小是8字节说明存储的是地址
const auto& src_slot = function.GetFrameSlot(src->second);
if (ptr->GetType()->IsPtrInt32() && src_slot.size == 8) {
// GEP结果先加载指针地址再通过指针加载值
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(src->second)});
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
} else {
// 普通栈变量:直接加载
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)});
}
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Add: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Sub: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Mul: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Div: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Mod: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
// AArch64 没有模运算指令,使用 a - (a/b)*b
// w8 = a, w9 = b
block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::W10), // w10 = a/b
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W10), // w10 = (a/b)*b
Operand::Reg(PhysReg::W10),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), // w8 = a - (a/b)*b
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W10)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Cmp: {
auto& cmp = static_cast<const ir::CmpInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block);
// cmp 操作符通过 operand 传递
block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9),
Operand::Imm(static_cast<int>(cmp.GetCmpOp()))});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
if (ret.GetValue()) {
// int/float 返回值
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block);
}
// void 返回:不设置 w0
block.Append(Opcode::Ret);
return;
}
case ir::Opcode::Call: {
auto& call = static_cast<const ir::CallInst&>(inst);
auto* callee = call.GetCallee();
if (!callee) {
throw std::runtime_error(FormatError("mir", "Call 指令缺少被调用函数"));
}
// 参数传递:根据类型使用 w0-w7整数或 x0-x7指针
size_t num_args = call.GetNumArgs();
if (num_args > 8) {
throw std::runtime_error(FormatError("mir", "暂不支持超过 8 个参数的函数调用"));
}
const auto& param_types = callee->GetParamTypes();
for (size_t i = 0; i < num_args; i++) {
auto* arg_value = call.GetArg(i);
// 检查参数类型是否是指针
bool is_ptr = (i < param_types.size() && param_types[i]->IsPtrInt32());
if (is_ptr) {
// 指针参数:加载到 x 寄存器
PhysReg arg_reg = static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + i);
auto it = slots.find(arg_value);
if (it != slots.end()) {
const auto& slot = function.GetFrameSlot(it->second);
// 检查是否是alloca的结果数组slot大小大于8说明是数组本身
if (slot.size > 8) {
// Alloca结果需要传递数组的地址
block.Append(Opcode::LoadStackAddr,
{Operand::Reg(arg_reg), Operand::FrameIndex(it->second)});
} else {
// GEP结果或指针参数从栈上加载指针值
block.Append(Opcode::LoadStack,
{Operand::Reg(arg_reg), Operand::FrameIndex(it->second)});
}
} else {
throw std::runtime_error(
FormatError("mir", "找不到指针参数的值: " + arg_value->GetName()));
}
} else {
// 整数参数:加载到 w 寄存器
PhysReg arg_reg = static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + i);
EmitValueToReg(arg_value, arg_reg, slots, block);
}
}
// 生成 bl 指令
block.Append(Opcode::Bl, {Operand::Symbol(callee->GetName())});
// 处理返回值
if (!call.GetType()->IsVoid()) {
int dst_slot = function.CreateFrameIndex();
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W0), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
}
return;
}
}
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令"));
}
} // namespace
// Lowers an IR module to a MachineModule.
//
// Global variables are copied over (name, initial value, element count).
// For each non-external function:
//   1. one MachineBasicBlock is created per IR basic block (the IR block named
//      "entry" maps to the machine function's entry block),
//   2. each parameter gets a frame slot (8 bytes for pointers, 4 for integers)
//      and is stored there from x0-x7 / w0-w7 at the top of the entry block,
//   3. every instruction is lowered; Br and CondBr are handled here because
//      they need the block map, everything else goes to LowerInstruction.
//
// Throws std::runtime_error when a function has more than 8 parameters, when
// a branch targets a block that does not belong to the function, or when
// LowerInstruction / EmitValueToReg reject an instruction.
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
  // NOTE(review): return value unused; presumably this makes sure the default
  // context exists before lowering -- confirm against its definition.
  DefaultContext();
  auto machine_module = std::make_unique<MachineModule>();

  // Copy global variable information.
  for (const auto& gv_ptr : module.GetGlobalVars()) {
    const auto& gv = *gv_ptr;
    machine_module->AddGlobalVar(gv.GetName(), gv.GetInitValue(),
                                 gv.GetCount());
  }

  for (const auto& func_ptr : module.GetFunctions()) {
    const auto& func = *func_ptr;
    // External declarations (runtime library) have no body to lower.
    if (func.IsExternal()) continue;

    auto* machine_func = machine_module->CreateFunction(func.GetName());
    ValueSlotMap slots;
    GepMap geps;  // tracks GEP results

    // Create the MachineBasicBlock for every IR basic block.
    std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> block_map;
    for (const auto& bb_ptr : func.GetBlocks()) {
      const auto& bb = *bb_ptr;
      MachineBasicBlock* mbb;
      if (bb.GetName() == "entry") {
        mbb = &machine_func->GetEntry();
      } else {
        mbb = machine_func->CreateBlock(bb.GetName());
      }
      block_map[&bb] = mbb;
    }

    // Checked lookup: operator[] would insert a null entry for an unknown
    // block and the caller would then dereference a null pointer.
    auto find_block =
        [&block_map](const ir::BasicBlock* bb) -> MachineBasicBlock* {
      auto it = block_map.find(bb);
      if (it == block_map.end() || it->second == nullptr) {
        throw std::runtime_error(
            FormatError("mir", "跳转目标基本块不存在"));
      }
      return it->second;
    };

    // Give every parameter a frame slot and store the incoming register.
    size_t num_params = func.GetNumParams();
    if (num_params > 8) {
      throw std::runtime_error(
          FormatError("mir", "暂不支持超过 8 个参数的函数"));
    }
    auto& entry_block = machine_func->GetEntry();
    for (size_t i = 0; i < num_params; i++) {
      auto* arg = func.GetArgument(i);
      bool is_ptr = arg->GetType()->IsPtrInt32();
      int slot_size = is_ptr ? 8 : 4;  // pointer: 8 bytes, integer: 4 bytes
      int slot = machine_func->CreateFrameIndex(slot_size);
      slots.emplace(arg, slot);
      // Pointers arrive in x0-x7, integers in w0-w7.
      PhysReg param_reg;
      if (is_ptr) {
        param_reg = static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + i);
      } else {
        param_reg = static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + i);
      }
      entry_block.Append(Opcode::StoreStack,
                         {Operand::Reg(param_reg), Operand::FrameIndex(slot)});
    }

    // Lower the instructions of every basic block.
    for (const auto& bb_ptr : func.GetBlocks()) {
      const auto& bb = *bb_ptr;
      MachineBasicBlock* current_mbb = find_block(&bb);
      for (const auto& inst : bb.GetInstructions()) {
        auto opcode = inst->GetOpcode();
        // Branches need block_map, so they are handled here.
        if (opcode == ir::Opcode::Br) {
          auto& br = static_cast<const ir::BranchInst&>(*inst);
          auto* target_mbb = find_block(br.GetTarget());
          current_mbb->Append(Opcode::B,
                              {Operand::Symbol(target_mbb->GetName())});
          continue;
        }
        if (opcode == ir::Opcode::CondBr) {
          auto& condbr = static_cast<const ir::CondBranchInst&>(*inst);
          auto* cond = condbr.GetCond();
          auto* true_mbb = find_block(condbr.GetTrueBlock());
          auto* false_mbb = find_block(condbr.GetFalseBlock());
          // Load the condition into w8.
          EmitValueToReg(cond, PhysReg::W8, slots, *current_mbb);
          // cbnz: non-zero condition jumps to the true block ...
          current_mbb->Append(Opcode::Cbnz,
                              {Operand::Reg(PhysReg::W8),
                               Operand::Symbol(true_mbb->GetName())});
          // ... otherwise fall into an unconditional jump to the false block.
          current_mbb->Append(Opcode::B,
                              {Operand::Symbol(false_mbb->GetName())});
          continue;
        }
        // Everything else is handled by LowerInstruction.
        LowerInstruction(*inst, *machine_func, *current_mbb, slots, geps);
      }
    }
  }
  return machine_module;
}
} // namespace mir