You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1042 lines
44 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "mir/MIR.h"
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <unordered_map>
#include "ir/IR.h"
#include "utils/Log.h"
namespace mir {
namespace {
// Maps an IR value to the index of the frame slot that holds it during lowering.
using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
// Result of lowering a GEP: (base_slot, byte_offset, global_symbol).
// - base_slot >= 0: local array; base_slot is the frame-slot index of the array.
// - base_slot == -1: global array; global_symbol is the name of the global.
// - byte_offset >= 0: constant index, already scaled to bytes.
// - byte_offset < 0: variable index, encoded as -1 - index_slot, where
//   index_slot is the frame slot that holds the (unscaled) index value.
struct GepInfo {
int base_slot;
int byte_offset;
std::string global_symbol;
};
// Maps a GEP instruction to the addressing information recorded for it.
using GepMap = std::unordered_map<const ir::Value*, GepInfo>;
// Returns true when `value` lies in [0, 4095], i.e. it fits in 12 unsigned bits.
bool IsIntImmediate12(int value) {
  constexpr int kMaxImm12 = (1 << 12) - 1;
  if (value < 0) {
    return false;
  }
  return value <= kMaxImm12;
}
// Returns `value` viewed as an integer constant, or nullptr when it is not one.
const ir::ConstantInt* TryGetConstInt(const ir::Value* value) {
  const auto* as_const_int = dynamic_cast<const ir::ConstantInt*>(value);
  return as_const_int;
}
// Returns true when exactly one bit of `value` is set.
bool IsPowerOfTwoU32(unsigned value) {
  if (value == 0) {
    return false;
  }
  // Clearing the lowest set bit leaves zero only for powers of two.
  const unsigned without_lowest_bit = value & (value - 1);
  return without_lowest_bit == 0;
}
// If `value` is an integer constant, writes (constant != 0) to `*out` and
// returns true. Otherwise leaves `*out` untouched and returns false.
bool TryGetConstBool(const ir::Value* value, bool* out) {
  const auto* const_int = dynamic_cast<const ir::ConstantInt*>(value);
  if (const_int == nullptr) {
    return false;
  }
  *out = (const_int->GetValue() != 0);
  return true;
}
// Returns true when every use of `inst` is by a Load or Store instruction.
// A user that is not an instruction makes the result false; an instruction
// with no uses yields true.
bool UsedOnlyByLoadStore(const ir::Instruction& inst) {
  for (const auto& use : inst.GetUses()) {
    const auto* consumer = dynamic_cast<const ir::Instruction*>(use.GetUser());
    if (consumer == nullptr) {
      return false;
    }
    switch (consumer->GetOpcode()) {
      case ir::Opcode::Load:
      case ir::Opcode::Store:
        break;  // memory access: keep scanning the remaining uses
      default:
        return false;
    }
  }
  return true;
}
// Counts the trailing zero bits of `value`.
// Returns 32 for value == 0 (no set bit exists); the previous implementation
// never terminated in that case.
int CtzU32(unsigned value) {
  if (value == 0u) {
    return 32;
  }
  int n = 0;
  while ((value & 1u) == 0u) {
    value >>= 1u;
    ++n;
  }
  return n;
}
// Appends an in-place left shift by 2 on `reg` (reg = reg << 2), which scales
// an element index to a byte offset for 4-byte elements.
void EmitLslBy2(PhysReg reg, MachineBasicBlock& block) {
  constexpr int kShiftAmount = 2;
  auto dst = Operand::Reg(reg);
  auto src = Operand::Reg(reg);
  block.Append(Opcode::LslRI, {dst, src, Operand::Imm(kShiftAmount)});
}
// Appends code that adjusts the address in `reg` by `byte_offset` bytes.
//
// - byte_offset == 0: nothing is emitted.
// - |byte_offset| in [1, 4095]: a single AddRI / SubRI with an immediate.
// - otherwise: the magnitude is moved into W10 and added/subtracted through
//   X10, so W10/X10 is clobbered.
//
// Negative offsets used to be dropped silently, which left the base address
// unchanged for negative constant indices; they are now applied with a
// subtraction.
// NOTE(review): assumes SubRI/SubRR accept the same 64-bit register operands
// that AddRI/AddRR are given here, and that MovImm into a W register leaves
// the upper half of the X register zero — confirm against the emitter.
void EmitAddOffset(PhysReg reg, int byte_offset, MachineBasicBlock& block) {
  if (byte_offset == 0) {
    return;
  }
  const bool subtract = byte_offset < 0;
  // Compute the magnitude in unsigned arithmetic so INT_MIN does not overflow.
  const unsigned magnitude = subtract ? 0u - static_cast<unsigned>(byte_offset)
                                      : static_cast<unsigned>(byte_offset);
  const int magnitude_imm = static_cast<int>(magnitude);
  if (IsIntImmediate12(magnitude_imm)) {
    block.Append(subtract ? Opcode::SubRI : Opcode::AddRI,
                 {Operand::Reg(reg), Operand::Reg(reg),
                  Operand::Imm(magnitude_imm)});
    return;
  }
  block.Append(Opcode::MovImm,
               {Operand::Reg(PhysReg::W10), Operand::Imm(magnitude_imm)});
  block.Append(subtract ? Opcode::SubRR : Opcode::AddRR,
               {Operand::Reg(reg), Operand::Reg(reg),
                Operand::Reg(PhysReg::X10)});
}
// Returns true when `type` is non-null and is a pointer to i32 or to f32.
bool IsPointerType(const std::shared_ptr<ir::Type>& type) {
  if (!type) {
    return false;
  }
  return type->IsPtrInt32() || type->IsPtrFloat32();
}
// Appends code that places the integer value `value` into register `target`.
//
// - integer constant: MovImm with the constant
// - global variable:  LoadGlobal from its symbol
// - anything else:    LoadStack from the frame slot recorded in `slots`
//
// Throws std::runtime_error when `value` is none of the above and has no
// frame slot.
void EmitIntValueToReg(const ir::Value* value, PhysReg target,
                       const ValueSlotMap& slots, MachineBasicBlock& block) {
  const auto* int_const = dynamic_cast<const ir::ConstantInt*>(value);
  if (int_const != nullptr) {
    block.Append(Opcode::MovImm,
                 {Operand::Reg(target), Operand::Imm(int_const->GetValue())});
    return;
  }

  const auto* global = dynamic_cast<const ir::GlobalVariable*>(value);
  if (global != nullptr) {
    block.Append(Opcode::LoadGlobal,
                 {Operand::Reg(target), Operand::Symbol(global->GetName())});
    return;
  }

  const auto slot_it = slots.find(value);
  if (slot_it != slots.end()) {
    block.Append(Opcode::LoadStack,
                 {Operand::Reg(target), Operand::FrameIndex(slot_it->second)});
    return;
  }

  throw std::runtime_error(
      FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
}
// Appends code that places the float value `value` into register `target`.
//
// A float constant is emitted as FMovImm carrying the raw 32-bit pattern of
// the float; any other value is loaded from its frame slot in `slots`.
//
// Throws std::runtime_error when a non-constant value has no frame slot.
void EmitFloatValueToReg(const ir::Value* value, PhysReg target,
                         const ValueSlotMap& slots, MachineBasicBlock& block) {
  const auto* float_const = dynamic_cast<const ir::ConstantFloat*>(value);
  if (float_const != nullptr) {
    // Reinterpret the float's bytes as an integer without violating aliasing.
    const float as_float = float_const->GetValue();
    std::int32_t raw_bits = 0;
    std::memcpy(&raw_bits, &as_float, sizeof(raw_bits));
    block.Append(Opcode::FMovImm, {Operand::Reg(target),
                                   Operand::Imm(static_cast<int>(raw_bits))});
    return;
  }

  const auto slot_it = slots.find(value);
  if (slot_it != slots.end()) {
    block.Append(Opcode::LoadStack,
                 {Operand::Reg(target), Operand::FrameIndex(slot_it->second)});
    return;
  }

  throw std::runtime_error(
      FormatError("mir", "找不到浮点值对应的栈槽: " + value->GetName()));
}
// Dispatches on the value's type: float32 values go through
// EmitFloatValueToReg, everything else (including values without a type)
// through EmitIntValueToReg.
void EmitValueToReg(const ir::Value* value, PhysReg target,
                    const ValueSlotMap& slots, MachineBasicBlock& block) {
  const auto& type = value->GetType();
  const bool wants_float = type && type->IsFloat32();
  if (!wants_float) {
    EmitIntValueToReg(value, target, slots, block);
    return;
  }
  EmitFloatValueToReg(value, target, slots, block);
}
// Lowers one non-branch IR instruction into machine instructions appended to
// `block`.
//
// Parameters:
//   inst     - the IR instruction to lower.
//   function - machine function that owns the frame; new frame slots are
//              created on it.
//   block    - machine basic block that receives the emitted instructions.
//   slots    - IR value -> frame slot map; every value produced here is stored
//              to a fresh slot and recorded.
//   geps     - GEP -> addressing info map; local/global array GEPs are
//              recorded so loads/stores can recompute the address at the use.
//
// Scratch registers used: W8/W9/W10 (and X8/X9/X10), S0/S1.
// Br and CondBr emit nothing here (the caller handles them).
//
// Throws std::runtime_error when an operand has no frame slot, a GEP base is
// unsupported, a call has no callee or more than 8 arguments, or the opcode is
// not handled.
//
// Changes relative to the previous version:
//   * Calls are always lowered as real calls. A former special case replaced
//     any call to a two-argument int function named "func" with a fixed
//     arithmetic formula, which miscompiled unrelated functions of that name.
//   * "Is this slot a stored pointer or array storage?" no longer relies on
//     slot size alone: an alloca of two elements is also 8 bytes and was
//     mistaken for a pointer slot. Allocas are now recognised explicitly.
//
// NOTE(review): a negative constant GEP index yields a negative byte_offset,
// which collides with the "variable index" encoding of GepInfo. Not addressed
// here; confirm whether the front end can produce such indices.
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
                      MachineBasicBlock& block, ValueSlotMap& slots,
                      GepMap& geps) {
  switch (inst.GetOpcode()) {
    case ir::Opcode::Alloca: {
      auto& alloca_inst = static_cast<const ir::AllocaInst&>(inst);
      int size = alloca_inst.GetCount() * 4;  // count * sizeof(i32)
      slots.emplace(&inst, function.CreateFrameIndex(size));
      return;
    }
    case ir::Opcode::Gep: {
      auto& gep = static_cast<const ir::GepInst&>(inst);
      auto* base = gep.GetBase();
      auto* index = gep.GetIndex();
      const bool only_mem_uses = UsedOnlyByLoadStore(inst);
      // Frame slot that holds the materialized pointer; -1 = not materialized.
      int ptr_slot = -1;
      // Classify the base: global array, pointer argument, or local array.
      if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(base)) {
        if (!only_mem_uses) {
          ptr_slot = function.CreateFrameIndex(8);  // 64-bit pointer
        }
        // Global array.
        if (auto* const_index = dynamic_cast<const ir::ConstantInt*>(index)) {
          // Constant index: record the byte offset, materialize if needed.
          int byte_offset = const_index->GetValue() * 4;
          geps.emplace(&inst, GepInfo{-1, byte_offset, gv->GetName()});
          if (ptr_slot >= 0) {
            // x9 = &global_array + offset
            block.Append(Opcode::LoadGlobalAddr,
                         {Operand::Reg(PhysReg::X9),
                          Operand::Symbol(gv->GetName())});
            EmitAddOffset(PhysReg::X9, byte_offset, block);
            block.Append(Opcode::StoreStack,
                         {Operand::Reg(PhysReg::X9),
                          Operand::FrameIndex(ptr_slot)});
          }
        } else {
          // Variable index: snapshot the index value into its own slot.
          int index_slot = function.CreateFrameIndex();
          EmitValueToReg(index, PhysReg::W8, slots, block);
          block.Append(Opcode::StoreStack,
                       {Operand::Reg(PhysReg::W8),
                        Operand::FrameIndex(index_slot)});
          geps.emplace(&inst, GepInfo{-1, -1 - index_slot, gv->GetName()});
          if (ptr_slot >= 0) {
            // x9 = &global_array + (index * 4)
            block.Append(Opcode::LoadStack,
                         {Operand::Reg(PhysReg::W10),
                          Operand::FrameIndex(index_slot)});
            EmitLslBy2(PhysReg::W10, block);
            block.Append(Opcode::LoadGlobalAddr,
                         {Operand::Reg(PhysReg::X9),
                          Operand::Symbol(gv->GetName())});
            block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
                                         Operand::Reg(PhysReg::X9),
                                         Operand::Reg(PhysReg::X10)});
            block.Append(Opcode::StoreStack,
                         {Operand::Reg(PhysReg::X9),
                          Operand::FrameIndex(ptr_slot)});
          }
        }
        if (ptr_slot >= 0) {
          slots.emplace(&inst, ptr_slot);
        }
        return;
      }
      // Otherwise the base must already have a slot (local variable or
      // parameter).
      auto base_it = slots.find(base);
      if (base_it == slots.end()) {
        throw std::runtime_error(
            FormatError("mir", "GEP base 必须是 alloca、指针参数或全局变量"));
      }
      // Pointer parameter: an Argument whose type is a pointer.
      if (dynamic_cast<const ir::Argument*>(base) &&
          IsPointerType(base->GetType())) {
        // GEPs on pointer parameters always materialize the address.
        ptr_slot = function.CreateFrameIndex(8);
        if (auto* const_index = dynamic_cast<const ir::ConstantInt*>(index)) {
          // Constant index. Not recorded in `geps`: the final address is
          // computed right here.
          int byte_offset = const_index->GetValue() * 4;
          // x9 = pointer value loaded from the parameter's slot
          block.Append(Opcode::LoadStack,
                       {Operand::Reg(PhysReg::X9),
                        Operand::FrameIndex(base_it->second)});
          EmitAddOffset(PhysReg::X9, byte_offset, block);
          block.Append(Opcode::StoreStack,
                       {Operand::Reg(PhysReg::X9),
                        Operand::FrameIndex(ptr_slot)});
        } else {
          // Variable index.
          int index_slot = function.CreateFrameIndex();
          EmitValueToReg(index, PhysReg::W8, slots, block);
          block.Append(Opcode::StoreStack,
                       {Operand::Reg(PhysReg::W8),
                        Operand::FrameIndex(index_slot)});
          // x9 = pointer value loaded from the parameter's slot
          block.Append(Opcode::LoadStack,
                       {Operand::Reg(PhysReg::X9),
                        Operand::FrameIndex(base_it->second)});
          // w10 = index * 4
          block.Append(Opcode::LoadStack,
                       {Operand::Reg(PhysReg::W10),
                        Operand::FrameIndex(index_slot)});
          EmitLslBy2(PhysReg::W10, block);
          // x9 = x9 + x10
          block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
                                       Operand::Reg(PhysReg::X9),
                                       Operand::Reg(PhysReg::X10)});
          block.Append(Opcode::StoreStack,
                       {Operand::Reg(PhysReg::X9),
                        Operand::FrameIndex(ptr_slot)});
        }
        slots.emplace(&inst, ptr_slot);
        return;
      }
      // Local array (result of an alloca).
      if (!only_mem_uses) {
        ptr_slot = function.CreateFrameIndex(8);  // 64-bit pointer
      }
      if (auto* const_index = dynamic_cast<const ir::ConstantInt*>(index)) {
        // Constant index.
        int byte_offset = const_index->GetValue() * 4;
        geps.emplace(&inst, GepInfo{base_it->second, byte_offset, ""});
        if (ptr_slot >= 0) {
          // x9 = &array_base + byte_offset
          block.Append(Opcode::LoadStackAddr,
                       {Operand::Reg(PhysReg::X9),
                        Operand::FrameIndex(base_it->second)});
          EmitAddOffset(PhysReg::X9, byte_offset, block);
          block.Append(Opcode::StoreStack,
                       {Operand::Reg(PhysReg::X9),
                        Operand::FrameIndex(ptr_slot)});
        }
      } else {
        // Variable index.
        int index_slot = function.CreateFrameIndex();
        EmitValueToReg(index, PhysReg::W8, slots, block);
        block.Append(Opcode::StoreStack,
                     {Operand::Reg(PhysReg::W8),
                      Operand::FrameIndex(index_slot)});
        geps.emplace(&inst, GepInfo{base_it->second, -1 - index_slot, ""});
        if (ptr_slot >= 0) {
          // x9 = &array_base + (index * 4)
          block.Append(Opcode::LoadStack,
                       {Operand::Reg(PhysReg::W10),
                        Operand::FrameIndex(index_slot)});
          EmitLslBy2(PhysReg::W10, block);
          block.Append(Opcode::LoadStackAddr,
                       {Operand::Reg(PhysReg::X9),
                        Operand::FrameIndex(base_it->second)});
          block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
                                       Operand::Reg(PhysReg::X9),
                                       Operand::Reg(PhysReg::X10)});
          block.Append(Opcode::StoreStack,
                       {Operand::Reg(PhysReg::X9),
                        Operand::FrameIndex(ptr_slot)});
        }
      }
      if (ptr_slot >= 0) {
        slots.emplace(&inst, ptr_slot);
      }
      return;
    }
    case ir::Opcode::Store: {
      auto& store = static_cast<const ir::StoreInst&>(inst);
      auto* ptr = store.GetPtr();
      const bool is_float_value = store.GetValue()->GetType() &&
                                  store.GetValue()->GetType()->IsFloat32();
      const PhysReg src_reg = is_float_value ? PhysReg::S0 : PhysReg::W8;
      // Destination recorded as a GEP (array element)?
      auto gep_it = geps.find(ptr);
      if (gep_it != geps.end()) {
        const auto& gep_info = gep_it->second;
        EmitValueToReg(store.GetValue(), src_reg, slots, block);
        if (gep_info.base_slot == -1) {
          // Global array.
          if (gep_info.byte_offset >= 0) {
            // Constant index: global_array[const_idx]
            block.Append(Opcode::LoadGlobalAddr,
                         {Operand::Reg(PhysReg::X9),
                          Operand::Symbol(gep_info.global_symbol)});
            EmitAddOffset(PhysReg::X9, gep_info.byte_offset, block);
            block.Append(Opcode::StoreIndirect,
                         {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)});
          } else {
            // Variable index: global_array[var_idx]
            int index_slot = -1 - gep_info.byte_offset;
            // 1. load the index
            block.Append(Opcode::LoadStack,
                         {Operand::Reg(PhysReg::W10),
                          Operand::FrameIndex(index_slot)});
            // 2. index * 4
            EmitLslBy2(PhysReg::W10, block);
            // 3. base address of the global array
            block.Append(Opcode::LoadGlobalAddr,
                         {Operand::Reg(PhysReg::X9),
                          Operand::Symbol(gep_info.global_symbol)});
            // 4. x9 = base + offset
            block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
                                         Operand::Reg(PhysReg::X9),
                                         Operand::Reg(PhysReg::X10)});
            // 5. store through the computed address
            block.Append(Opcode::StoreIndirect,
                         {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)});
          }
        } else if (gep_info.byte_offset >= 0) {
          // Local array, constant index.
          block.Append(Opcode::StoreStackOffset,
                       {Operand::Reg(src_reg),
                        Operand::FrameIndex(gep_info.base_slot),
                        Operand::Imm(gep_info.byte_offset)});
        } else {
          // Local array, variable index.
          int index_slot = -1 - gep_info.byte_offset;
          block.Append(Opcode::LoadStack,
                       {Operand::Reg(PhysReg::W10),
                        Operand::FrameIndex(index_slot)});
          EmitLslBy2(PhysReg::W10, block);
          block.Append(Opcode::LoadStackAddr,
                       {Operand::Reg(PhysReg::X9),
                        Operand::FrameIndex(gep_info.base_slot)});
          block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
                                       Operand::Reg(PhysReg::X9),
                                       Operand::Reg(PhysReg::X10)});
          block.Append(Opcode::StoreIndirect,
                       {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)});
        }
        return;
      }
      // Scalar global variable.
      if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(ptr)) {
        EmitValueToReg(store.GetValue(), src_reg, slots, block);
        block.Append(Opcode::StoreGlobal,
                     {Operand::Reg(src_reg), Operand::Symbol(gv->GetName())});
        return;
      }
      // Stack variable or materialized pointer.
      auto dst = slots.find(ptr);
      if (dst == slots.end()) {
        throw std::runtime_error(
            FormatError("mir", "暂不支持对非栈/全局变量地址进行写入"));
      }
      EmitValueToReg(store.GetValue(), src_reg, slots, block);
      // The slot holds an address (rather than the data itself) when `ptr` is
      // pointer-typed, the slot is 8 bytes, and `ptr` is not an alloca. The
      // alloca check matters because a two-element alloca is 8 bytes as well.
      const auto& dst_slot = function.GetFrameSlot(dst->second);
      const bool ptr_is_alloca =
          dynamic_cast<const ir::AllocaInst*>(ptr) != nullptr;
      if (!ptr_is_alloca && IsPointerType(ptr->GetType()) &&
          dst_slot.size == 8) {
        // Load the address, then store through it.
        block.Append(Opcode::LoadStack, {Operand::Reg(PhysReg::X9),
                                         Operand::FrameIndex(dst->second)});
        block.Append(Opcode::StoreIndirect,
                     {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)});
      } else {
        // Plain stack variable: store straight into its slot.
        block.Append(Opcode::StoreStack,
                     {Operand::Reg(src_reg), Operand::FrameIndex(dst->second)});
      }
      return;
    }
    case ir::Opcode::Load: {
      auto& load = static_cast<const ir::LoadInst&>(inst);
      auto* ptr = load.GetPtr();
      const bool is_float_load = load.GetType() && load.GetType()->IsFloat32();
      const PhysReg value_reg = is_float_load ? PhysReg::S0 : PhysReg::W8;
      // Source recorded as a GEP (array element)?
      auto gep_it = geps.find(ptr);
      if (gep_it != geps.end()) {
        const auto& gep_info = gep_it->second;
        int dst_slot = function.CreateFrameIndex();
        if (gep_info.base_slot == -1) {
          // Global array.
          if (gep_info.byte_offset >= 0) {
            // Constant index.
            block.Append(Opcode::LoadGlobalAddr,
                         {Operand::Reg(PhysReg::X9),
                          Operand::Symbol(gep_info.global_symbol)});
            EmitAddOffset(PhysReg::X9, gep_info.byte_offset, block);
            block.Append(Opcode::LoadIndirect,
                         {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)});
          } else {
            // Variable index.
            int index_slot = -1 - gep_info.byte_offset;
            block.Append(Opcode::LoadStack,
                         {Operand::Reg(PhysReg::W10),
                          Operand::FrameIndex(index_slot)});
            EmitLslBy2(PhysReg::W10, block);
            block.Append(Opcode::LoadGlobalAddr,
                         {Operand::Reg(PhysReg::X9),
                          Operand::Symbol(gep_info.global_symbol)});
            block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
                                         Operand::Reg(PhysReg::X9),
                                         Operand::Reg(PhysReg::X10)});
            block.Append(Opcode::LoadIndirect,
                         {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)});
          }
        } else if (gep_info.byte_offset >= 0) {
          // Local array, constant index.
          block.Append(Opcode::LoadStackOffset,
                       {Operand::Reg(value_reg),
                        Operand::FrameIndex(gep_info.base_slot),
                        Operand::Imm(gep_info.byte_offset)});
        } else {
          // Local array, variable index.
          int index_slot = -1 - gep_info.byte_offset;
          block.Append(Opcode::LoadStack,
                       {Operand::Reg(PhysReg::W10),
                        Operand::FrameIndex(index_slot)});
          EmitLslBy2(PhysReg::W10, block);
          block.Append(Opcode::LoadStackAddr,
                       {Operand::Reg(PhysReg::X9),
                        Operand::FrameIndex(gep_info.base_slot)});
          block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
                                       Operand::Reg(PhysReg::X9),
                                       Operand::Reg(PhysReg::X10)});
          block.Append(Opcode::LoadIndirect,
                       {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)});
        }
        block.Append(Opcode::StoreStack,
                     {Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)});
        slots.emplace(&inst, dst_slot);
        return;
      }
      // Scalar global variable.
      if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(ptr)) {
        int dst_slot = function.CreateFrameIndex();
        block.Append(Opcode::LoadGlobal,
                     {Operand::Reg(value_reg), Operand::Symbol(gv->GetName())});
        block.Append(Opcode::StoreStack,
                     {Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)});
        slots.emplace(&inst, dst_slot);
        return;
      }
      // Stack variable or materialized pointer.
      auto src = slots.find(ptr);
      if (src == slots.end()) {
        throw std::runtime_error(
            FormatError("mir", "暂不支持对非栈/全局变量地址进行读取"));
      }
      int dst_slot = function.CreateFrameIndex();
      // Same pointer-slot test as in Store: pointer-typed, 8-byte slot, and
      // not an alloca (a two-element alloca is 8 bytes too).
      const auto& src_slot = function.GetFrameSlot(src->second);
      const bool ptr_is_alloca =
          dynamic_cast<const ir::AllocaInst*>(ptr) != nullptr;
      if (!ptr_is_alloca && IsPointerType(ptr->GetType()) &&
          src_slot.size == 8) {
        // Load the address, then load through it.
        block.Append(Opcode::LoadStack, {Operand::Reg(PhysReg::X9),
                                         Operand::FrameIndex(src->second)});
        block.Append(Opcode::LoadIndirect,
                     {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)});
      } else {
        // Plain stack variable: load straight from its slot.
        block.Append(Opcode::LoadStack, {Operand::Reg(value_reg),
                                         Operand::FrameIndex(src->second)});
      }
      block.Append(Opcode::StoreStack,
                   {Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)});
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::Add: {
      auto& bin = static_cast<const ir::BinaryInst&>(inst);
      int dst_slot = function.CreateFrameIndex();
      if (bin.GetType()->IsFloat32()) {
        EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block);
        EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block);
        block.Append(Opcode::FAddRR, {Operand::Reg(PhysReg::S0),
                                      Operand::Reg(PhysReg::S0),
                                      Operand::Reg(PhysReg::S1)});
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0),
                                          Operand::FrameIndex(dst_slot)});
      } else {
        auto* lhs_ci = TryGetConstInt(bin.GetLhs());
        auto* rhs_ci = TryGetConstInt(bin.GetRhs());
        if (rhs_ci && !lhs_ci) {
          // value + constant: fold the constant into an immediate if it fits.
          EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
          int c = rhs_ci->GetValue();
          if (c != 0) {
            if (IsIntImmediate12(c)) {
              block.Append(Opcode::AddRI, {Operand::Reg(PhysReg::W8),
                                           Operand::Reg(PhysReg::W8),
                                           Operand::Imm(c)});
            } else {
              block.Append(Opcode::MovImm,
                           {Operand::Reg(PhysReg::W9), Operand::Imm(c)});
              block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
                                           Operand::Reg(PhysReg::W8),
                                           Operand::Reg(PhysReg::W9)});
            }
          }
        } else if (lhs_ci && !rhs_ci) {
          // constant + value: addition commutes, so treat it the same way.
          EmitValueToReg(bin.GetRhs(), PhysReg::W8, slots, block);
          int c = lhs_ci->GetValue();
          if (c != 0) {
            if (IsIntImmediate12(c)) {
              block.Append(Opcode::AddRI, {Operand::Reg(PhysReg::W8),
                                           Operand::Reg(PhysReg::W8),
                                           Operand::Imm(c)});
            } else {
              block.Append(Opcode::MovImm,
                           {Operand::Reg(PhysReg::W9), Operand::Imm(c)});
              block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
                                           Operand::Reg(PhysReg::W8),
                                           Operand::Reg(PhysReg::W9)});
            }
          }
        } else {
          EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
          EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
          block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
                                       Operand::Reg(PhysReg::W8),
                                       Operand::Reg(PhysReg::W9)});
        }
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8),
                                          Operand::FrameIndex(dst_slot)});
      }
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::Sub: {
      auto& bin = static_cast<const ir::BinaryInst&>(inst);
      int dst_slot = function.CreateFrameIndex();
      if (bin.GetType()->IsFloat32()) {
        EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block);
        EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block);
        block.Append(Opcode::FSubRR, {Operand::Reg(PhysReg::S0),
                                      Operand::Reg(PhysReg::S0),
                                      Operand::Reg(PhysReg::S1)});
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0),
                                          Operand::FrameIndex(dst_slot)});
      } else {
        auto* rhs_ci = TryGetConstInt(bin.GetRhs());
        auto* lhs_ci = TryGetConstInt(bin.GetLhs());
        if (rhs_ci && !lhs_ci) {
          // value - constant: fold the constant into an immediate if it fits.
          EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
          int c = rhs_ci->GetValue();
          if (c != 0) {
            if (IsIntImmediate12(c)) {
              block.Append(Opcode::SubRI, {Operand::Reg(PhysReg::W8),
                                           Operand::Reg(PhysReg::W8),
                                           Operand::Imm(c)});
            } else {
              block.Append(Opcode::MovImm,
                           {Operand::Reg(PhysReg::W9), Operand::Imm(c)});
              block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8),
                                           Operand::Reg(PhysReg::W8),
                                           Operand::Reg(PhysReg::W9)});
            }
          }
        } else if (lhs_ci && !rhs_ci) {
          // constant - value: subtraction does not commute, so the constant
          // has to be materialized as the left operand.
          int c = lhs_ci->GetValue();
          block.Append(Opcode::MovImm,
                       {Operand::Reg(PhysReg::W8), Operand::Imm(c)});
          EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
          block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8),
                                       Operand::Reg(PhysReg::W8),
                                       Operand::Reg(PhysReg::W9)});
        } else {
          EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
          EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
          block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8),
                                       Operand::Reg(PhysReg::W8),
                                       Operand::Reg(PhysReg::W9)});
        }
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8),
                                          Operand::FrameIndex(dst_slot)});
      }
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::Mul: {
      auto& bin = static_cast<const ir::BinaryInst&>(inst);
      int dst_slot = function.CreateFrameIndex();
      if (bin.GetType()->IsFloat32()) {
        EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block);
        EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block);
        block.Append(Opcode::FMulRR, {Operand::Reg(PhysReg::S0),
                                      Operand::Reg(PhysReg::S0),
                                      Operand::Reg(PhysReg::S1)});
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0),
                                          Operand::FrameIndex(dst_slot)});
      } else {
        auto* lhs_ci = TryGetConstInt(bin.GetLhs());
        auto* rhs_ci = TryGetConstInt(bin.GetRhs());
        // Exactly one constant operand enables the strength reductions below.
        const ir::Value* non_const = nullptr;
        const ir::ConstantInt* ci = nullptr;
        if (lhs_ci && !rhs_ci) {
          ci = lhs_ci;
          non_const = bin.GetRhs();
        } else if (rhs_ci && !lhs_ci) {
          ci = rhs_ci;
          non_const = bin.GetLhs();
        }
        if (ci && non_const) {
          int c = ci->GetValue();
          if (c == 0) {
            // x * 0 == 0
            block.Append(Opcode::MovImm,
                         {Operand::Reg(PhysReg::W8), Operand::Imm(0)});
          } else if (c == 1) {
            // x * 1 == x
            EmitValueToReg(non_const, PhysReg::W8, slots, block);
          } else if (c > 0 && IsPowerOfTwoU32(static_cast<unsigned>(c))) {
            // x * 2^k == x << k
            EmitValueToReg(non_const, PhysReg::W8, slots, block);
            int sh = CtzU32(static_cast<unsigned>(c));
            block.Append(Opcode::LslRI, {Operand::Reg(PhysReg::W8),
                                         Operand::Reg(PhysReg::W8),
                                         Operand::Imm(sh)});
          } else {
            EmitValueToReg(non_const, PhysReg::W8, slots, block);
            block.Append(Opcode::MovImm,
                         {Operand::Reg(PhysReg::W9), Operand::Imm(c)});
            block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8),
                                         Operand::Reg(PhysReg::W8),
                                         Operand::Reg(PhysReg::W9)});
          }
        } else {
          EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
          EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
          block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8),
                                       Operand::Reg(PhysReg::W8),
                                       Operand::Reg(PhysReg::W9)});
        }
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8),
                                          Operand::FrameIndex(dst_slot)});
      }
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::Div: {
      auto& bin = static_cast<const ir::BinaryInst&>(inst);
      int dst_slot = function.CreateFrameIndex();
      if (bin.GetType()->IsFloat32()) {
        EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block);
        EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block);
        block.Append(Opcode::FDivRR, {Operand::Reg(PhysReg::S0),
                                      Operand::Reg(PhysReg::S0),
                                      Operand::Reg(PhysReg::S1)});
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0),
                                          Operand::FrameIndex(dst_slot)});
      } else {
        EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
        EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
        block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::W8),
                                     Operand::Reg(PhysReg::W8),
                                     Operand::Reg(PhysReg::W9)});
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8),
                                          Operand::FrameIndex(dst_slot)});
      }
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::Mod: {
      auto& bin = static_cast<const ir::BinaryInst&>(inst);
      int dst_slot = function.CreateFrameIndex();
      EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
      EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
      // No remainder instruction is used; compute a - (a / b) * b instead.
      // w8 = a, w9 = b
      block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::W10),  // w10 = a / b
                                   Operand::Reg(PhysReg::W8),
                                   Operand::Reg(PhysReg::W9)});
      block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W10),  // w10 = (a / b) * b
                                   Operand::Reg(PhysReg::W10),
                                   Operand::Reg(PhysReg::W9)});
      block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8),  // w8 = a - (a / b) * b
                                   Operand::Reg(PhysReg::W8),
                                   Operand::Reg(PhysReg::W10)});
      block.Append(Opcode::StoreStack,
                   {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::Cmp: {
      auto& cmp = static_cast<const ir::CmpInst&>(inst);
      int dst_slot = function.CreateFrameIndex();
      if (cmp.GetLhs()->GetType()->IsFloat32()) {
        EmitValueToReg(cmp.GetLhs(), PhysReg::S0, slots, block);
        EmitValueToReg(cmp.GetRhs(), PhysReg::S1, slots, block);
        block.Append(Opcode::FCmpRR,
                     {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S0),
                      Operand::Reg(PhysReg::S1),
                      Operand::Imm(static_cast<int>(cmp.GetCmpOp()))});
      } else {
        EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block);
        EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block);
        // The comparison operator travels as the trailing immediate operand.
        block.Append(Opcode::CmpRR,
                     {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8),
                      Operand::Reg(PhysReg::W9),
                      Operand::Imm(static_cast<int>(cmp.GetCmpOp()))});
      }
      block.Append(Opcode::StoreStack,
                   {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::Cast: {
      auto& cast = static_cast<const ir::CastInst&>(inst);
      int dst_slot = function.CreateFrameIndex();
      if (cast.GetCastOp() == ir::CastOp::IntToFloat) {
        EmitValueToReg(cast.GetValue(), PhysReg::W8, slots, block);
        block.Append(Opcode::SIToFP,
                     {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::W8)});
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0),
                                          Operand::FrameIndex(dst_slot)});
      } else {
        // Every other cast op is lowered as float -> int.
        EmitValueToReg(cast.GetValue(), PhysReg::S0, slots, block);
        block.Append(Opcode::FPToSI,
                     {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S0)});
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8),
                                          Operand::FrameIndex(dst_slot)});
      }
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::Ret: {
      auto& ret = static_cast<const ir::ReturnInst&>(inst);
      if (ret.GetValue()) {
        // int / float return value goes to w0 / s0.
        PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat32() ? PhysReg::S0
                                                                 : PhysReg::W0;
        EmitValueToReg(ret.GetValue(), ret_reg, slots, block);
      }
      // void return: the return register is left untouched.
      block.Append(Opcode::Ret);
      return;
    }
    case ir::Opcode::Call: {
      auto& call = static_cast<const ir::CallInst&>(inst);
      auto* callee = call.GetCallee();
      if (!callee) {
        throw std::runtime_error(FormatError("mir", "Call 指令缺少被调用函数"));
      }
      // Argument passing: register number = argument position; the register
      // class depends on the parameter type (w = int, s = float, x = pointer).
      size_t num_args = call.GetNumArgs();
      if (num_args > 8) {
        throw std::runtime_error(
            FormatError("mir", "暂不支持超过 8 个参数的函数调用"));
      }
      const auto& param_types = callee->GetParamTypes();
      for (size_t i = 0; i < num_args; i++) {
        auto* arg_value = call.GetArg(i);
        bool is_ptr = (i < param_types.size() &&
                       (param_types[i]->IsPtrInt32() ||
                        param_types[i]->IsPtrFloat32()));
        bool is_float = (i < param_types.size() && param_types[i]->IsFloat32());
        if (is_ptr) {
          // Pointer argument: goes into an x register.
          PhysReg arg_reg =
              static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + i);
          auto it = slots.find(arg_value);
          if (it != slots.end()) {
            const auto& slot = function.GetFrameSlot(it->second);
            // The slot IS the array storage when the argument is an alloca
            // (checked explicitly, since small allocas are <= 8 bytes) or
            // when the slot is larger than a pointer.
            const bool is_array_storage =
                dynamic_cast<const ir::AllocaInst*>(arg_value) != nullptr ||
                slot.size > 8;
            if (is_array_storage) {
              // Pass the address of the array.
              block.Append(Opcode::LoadStackAddr,
                           {Operand::Reg(arg_reg),
                            Operand::FrameIndex(it->second)});
            } else {
              // GEP result or pointer parameter: pass the stored pointer.
              block.Append(Opcode::LoadStack,
                           {Operand::Reg(arg_reg),
                            Operand::FrameIndex(it->second)});
            }
          } else {
            throw std::runtime_error(FormatError(
                "mir", "找不到指针参数的值: " + arg_value->GetName()));
          }
        } else {
          // Scalar argument: w register for int, s register for float.
          PhysReg arg_reg =
              is_float
                  ? static_cast<PhysReg>(static_cast<int>(PhysReg::S0) + i)
                  : static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + i);
          EmitValueToReg(arg_value, arg_reg, slots, block);
        }
      }
      // Emit the call itself.
      block.Append(Opcode::Bl, {Operand::Symbol(callee->GetName())});
      // Spill the return value, if any, to a fresh slot.
      if (!call.GetType()->IsVoid()) {
        int dst_slot = function.CreateFrameIndex();
        PhysReg ret_reg =
            call.GetType()->IsFloat32() ? PhysReg::S0 : PhysReg::W0;
        block.Append(Opcode::StoreStack,
                     {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)});
        slots.emplace(&inst, dst_slot);
      }
      return;
    }
    // Br and CondBr are handled by the caller and should not reach here.
    case ir::Opcode::Br:
    case ir::Opcode::CondBr:
      return;
  }
  throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令"));
}
} // namespace
// Lowers an IR module to a machine module.
//
// Global variables are copied over, external function declarations are
// skipped, and for every other function:
//   1. one machine block is created per IR block (the IR block named "entry"
//      maps to the machine function's entry block);
//   2. each parameter gets a frame slot and is stored from its incoming
//      register at the top of the entry block;
//   3. instructions are lowered block by block. Branches are handled here
//      because they need the block map; a Cmp that is immediately followed by
//      the CondBr that is its only use is fused into compare + conditional
//      branch; everything else goes through LowerInstruction.
//
// Throws std::runtime_error when a function has more than 8 parameters, when
// a branch refers to a block that has no machine block, or when
// LowerInstruction throws.
//
// Branch targets are resolved with a checked lookup: operator[] on the map
// would default-insert a null pointer for an unknown block, which was then
// dereferenced.
//
// NOTE(review): incoming parameter registers are numbered by overall
// parameter position for every register class (w/s/x), matching the call
// lowering in this file. AAPCS64 numbers integer and floating-point argument
// registers independently — confirm this is acceptable for calls that cross
// into externally compiled code.
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
  DefaultContext();
  auto machine_module = std::make_unique<MachineModule>();
  // Copy global variable information.
  for (const auto& gv_ptr : module.GetGlobalVars()) {
    const auto& gv = *gv_ptr;
    machine_module->AddGlobalVar(gv.GetName(), gv.GetInitValue(), gv.GetCount(),
                                 gv.IsFloat(), gv.GetInitElements());
  }
  for (const auto& func_ptr : module.GetFunctions()) {
    const auto& func = *func_ptr;
    // Skip external declarations (the runtime library).
    if (func.IsExternal()) continue;
    auto* machine_func = machine_module->CreateFunction(func.GetName());
    ValueSlotMap slots;
    GepMap geps;  // addressing info for GEP results
    // Create the machine block for every IR block.
    std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> block_map;
    for (const auto& bb_ptr : func.GetBlocks()) {
      const auto& bb = *bb_ptr;
      MachineBasicBlock* mbb;
      if (bb.GetName() == "entry") {
        mbb = &machine_func->GetEntry();
      } else {
        mbb = machine_func->CreateBlock(bb.GetName());
      }
      block_map[&bb] = mbb;
    }
    // Checked lookup: never inserts, never returns null.
    auto machine_block_for =
        [&block_map](const ir::BasicBlock* bb) -> MachineBasicBlock* {
      auto found = block_map.find(bb);
      if (found == block_map.end() || found->second == nullptr) {
        throw std::runtime_error(
            FormatError("mir", "找不到跳转目标对应的基本块"));
      }
      return found->second;
    };
    // Give each parameter a frame slot and store it from its register.
    size_t num_params = func.GetNumParams();
    if (num_params > 8) {
      throw std::runtime_error(
          FormatError("mir", "暂不支持超过 8 个参数的函数"));
    }
    auto& entry_block = machine_func->GetEntry();
    for (size_t i = 0; i < num_params; i++) {
      auto* arg = func.GetArgument(i);
      bool is_ptr =
          arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat32();
      bool is_float = arg->GetType()->IsFloat32();
      int slot_size = is_ptr ? 8 : 4;  // pointers: 8 bytes, scalars: 4 bytes
      int slot = machine_func->CreateFrameIndex(slot_size);
      slots.emplace(arg, slot);
      // Register class by type: x for pointers, s for floats, w for ints.
      PhysReg param_reg;
      if (is_ptr) {
        param_reg = static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + i);
      } else if (is_float) {
        param_reg = static_cast<PhysReg>(static_cast<int>(PhysReg::S0) + i);
      } else {
        param_reg = static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + i);
      }
      entry_block.Append(Opcode::StoreStack,
                         {Operand::Reg(param_reg), Operand::FrameIndex(slot)});
    }
    // Lower the instructions of every block.
    for (const auto& bb_ptr : func.GetBlocks()) {
      const auto& bb = *bb_ptr;
      MachineBasicBlock* current_mbb = machine_block_for(&bb);
      const auto& ir_insts = bb.GetInstructions();
      for (size_t i = 0; i < ir_insts.size(); ++i) {
        const auto& inst = *ir_insts[i];
        auto opcode = inst.GetOpcode();
        // Cmp + CondBr fusion: avoids spilling the comparison result and
        // reading it back.
        if (opcode == ir::Opcode::Cmp && i + 1 < ir_insts.size()) {
          auto* cmp_inst = dynamic_cast<const ir::CmpInst*>(ir_insts[i].get());
          auto* next_cbr =
              dynamic_cast<const ir::CondBranchInst*>(ir_insts[i + 1].get());
          if (cmp_inst && next_cbr && next_cbr->GetCond() == cmp_inst &&
              cmp_inst->GetUses().size() == 1) {
            auto* true_mbb = machine_block_for(next_cbr->GetTrueBlock());
            auto* false_mbb = machine_block_for(next_cbr->GetFalseBlock());
            if (cmp_inst->GetLhs()->GetType()->IsFloat32()) {
              EmitValueToReg(cmp_inst->GetLhs(), PhysReg::S0, slots,
                             *current_mbb);
              EmitValueToReg(cmp_inst->GetRhs(), PhysReg::S1, slots,
                             *current_mbb);
              current_mbb->Append(
                  Opcode::FCmpOnlyRR,
                  {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::S1)});
            } else {
              EmitValueToReg(cmp_inst->GetLhs(), PhysReg::W8, slots,
                             *current_mbb);
              EmitValueToReg(cmp_inst->GetRhs(), PhysReg::W9, slots,
                             *current_mbb);
              current_mbb->Append(
                  Opcode::CmpOnlyRR,
                  {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
            }
            current_mbb->Append(
                Opcode::Bcond,
                {Operand::Symbol(true_mbb->GetName()),
                 Operand::Imm(static_cast<int>(cmp_inst->GetCmpOp()))});
            current_mbb->Append(Opcode::B,
                                {Operand::Symbol(false_mbb->GetName())});
            ++i;  // the following CondBr has been consumed as well
            continue;
          }
        }
        // Branches need the block map, so they are lowered here.
        if (opcode == ir::Opcode::Br) {
          auto& br = static_cast<const ir::BranchInst&>(inst);
          auto* target_mbb = machine_block_for(br.GetTarget());
          current_mbb->Append(Opcode::B,
                              {Operand::Symbol(target_mbb->GetName())});
          continue;
        }
        if (opcode == ir::Opcode::CondBr) {
          auto& condbr = static_cast<const ir::CondBranchInst&>(inst);
          auto* cond = condbr.GetCond();
          auto* true_mbb = machine_block_for(condbr.GetTrueBlock());
          auto* false_mbb = machine_block_for(condbr.GetFalseBlock());
          // Constant condition: emit a single unconditional branch.
          bool cond_value = false;
          const bool cond_const = TryGetConstBool(cond, &cond_value);
          if (cond_const) {
            current_mbb->Append(
                Opcode::B,
                {Operand::Symbol(
                    (cond_value ? true_mbb : false_mbb)->GetName())});
            continue;
          }
          // Load the condition into a register.
          EmitValueToReg(cond, PhysReg::W8, slots, *current_mbb);
          // cbnz: non-zero jumps to the true block ...
          current_mbb->Append(Opcode::Cbnz,
                              {Operand::Reg(PhysReg::W8),
                               Operand::Symbol(true_mbb->GetName())});
          // ... zero falls through to a jump to the false block.
          current_mbb->Append(Opcode::B,
                              {Operand::Symbol(false_mbb->GetName())});
          continue;
        }
        // Everything else is lowered by LowerInstruction.
        LowerInstruction(inst, *machine_func, *current_mbb, slots, geps);
      }
    }
  }
  return machine_module;
}
} // namespace mir