You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

628 lines
24 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "mir/MIR.h"
#include <iostream>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include <cstring>
#include "ir/IR.h"
#include "utils/Log.h"
// Descriptors of the module's global variables. LowerToMIR clears and refills
// this list on every call. It is defined at global scope (outside namespace
// mir), presumably so later stages can read it — TODO confirm the consumers.
std::vector<mir::GlobalVarInfo> g_globalVars;
namespace mir {
namespace {
// Maps an IR value to the frame index (stack slot) that holds it.
using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
// Cache from IR basic block to the machine basic block created for it.
// Filled by GetOrCreateBlock and cleared at the start of LowerFunctionToMIR,
// so it only describes the function currently being lowered.
static std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> block_map;
// Returns the machine basic block that corresponds to `ir_block`, creating it
// in `function` on first request and caching the mapping in block_map.
// An IR block without a name gets the label "block_<n>", where <n> is the
// number of blocks cached so far.
MachineBasicBlock* GetOrCreateBlock(const ir::BasicBlock* ir_block,
                                    MachineFunction& function) {
  const auto cached = block_map.find(ir_block);
  if (cached != block_map.end()) {
    return cached->second;
  }

  std::string label = ir_block->GetName();
  if (label.empty()) {
    // The counter is read before the new entry is inserted.
    label = "block_" + std::to_string(block_map.size());
  }

  auto* const created = function.CreateBlock(label);
  block_map.emplace(ir_block, created);
  return created;
}
// Emits the instructions that bring `value` into the physical register
// `target`, appending them to `block`.
//
// The value kinds are tried in this order:
//   1. function argument with a slot  -> load from its frame slot
//   2. integer constant               -> MovImm
//   3. float constant                 -> MovImm of the raw bits (via T0 plus
//                                        FMovWX when `target` is an FP register)
//   4. GEP                            -> base address + (index << 2); uses T1
//   5. alloca with a slot             -> address of the frame slot
//   6. global variable                -> its address, followed by a load of
//                                        the contents unless `for_address`
//   7. any other value with a slot    -> load from its frame slot
//
// Throws std::runtime_error (after printing diagnostics to stderr) when none
// of the above applies.
void EmitValueToReg(const ir::Value* value, PhysReg target,
                    const ValueSlotMap& slots, MachineBasicBlock& block,
                    bool for_address=false) {
  // Loads `value` from a frame slot with the opcode matching its type.
  const auto load_from_slot = [&](int frame_index) {
    const Opcode load_op =
        value->GetType()->IsFloat32() ? Opcode::LoadFloat : Opcode::Load;
    block.Append(load_op,
                 {Operand::Reg(target), Operand::FrameIndex(frame_index)});
  };

  if (const auto* param = dynamic_cast<const ir::Argument*>(value)) {
    const auto found = slots.find(param);
    if (found != slots.end()) {
      load_from_slot(found->second);
      return;
    }
  }

  if (const auto* int_const = dynamic_cast<const ir::ConstantInt*>(value)) {
    const int64_t wide = int_const->GetValue();
    block.Append(Opcode::MovImm,
                 {Operand::Reg(target), Operand::Imm(static_cast<int>(wide))});
    return;
  }

  // Float constants are materialised through their raw 32-bit pattern.
  if (const auto* fp_const = dynamic_cast<const ir::ConstantFloat*>(value)) {
    const float number = fp_const->GetValue();
    uint32_t raw = 0;
    std::memcpy(&raw, &number, sizeof(number));

    // Is the destination one of the floating-point registers?
    const PhysReg fp_regs[] = {
        PhysReg::FT0, PhysReg::FT1, PhysReg::FT2, PhysReg::FT3,
        PhysReg::FT4, PhysReg::FT5, PhysReg::FT6, PhysReg::FT7,
        PhysReg::FA0, PhysReg::FA1, PhysReg::FA2, PhysReg::FA3,
        PhysReg::FA4, PhysReg::FA5, PhysReg::FA6, PhysReg::FA7};
    bool wants_fp = false;
    for (const PhysReg candidate : fp_regs) {
      if (candidate == target) {
        wants_fp = true;
        break;
      }
    }

    // For an FP destination the bits go through T0 first; an integer
    // destination receives them directly.
    const PhysReg bits_reg = wants_fp ? PhysReg::T0 : target;
    block.Append(Opcode::MovImm,
                 {Operand::Reg(bits_reg), Operand::Imm(static_cast<int>(raw))});
    if (wants_fp) {
      block.Append(Opcode::FMovWX,
                   {Operand::Reg(target), Operand::Reg(PhysReg::T0)});
    }
    return;
  }

  if (const auto* element_ptr = dynamic_cast<const ir::GepInst*>(value)) {
    // address = base + index * 4
    EmitValueToReg(element_ptr->GetBasePtr(), target, slots, block, true);
    EmitValueToReg(element_ptr->GetIndex(), PhysReg::T1, slots, block);
    block.Append(Opcode::Slli, {Operand::Reg(PhysReg::T1),
                                Operand::Reg(PhysReg::T1), Operand::Imm(2)});
    block.Append(Opcode::Add, {Operand::Reg(target), Operand::Reg(target),
                               Operand::Reg(PhysReg::T1)});
    return;
  }

  if (const auto* stack_obj = dynamic_cast<const ir::AllocaInst*>(value)) {
    const auto found = slots.find(stack_obj);
    if (found != slots.end()) {
      block.Append(Opcode::LoadAddr,
                   {Operand::Reg(target), Operand::FrameIndex(found->second)});
      return;
    }
  }

  if (const auto* global_var = dynamic_cast<const ir::GlobalVariable*>(value)) {
    block.Append(Opcode::LoadGlobalAddr,
                 {Operand::Reg(target), Operand::Global(global_var->GetName())});
    if (for_address) {
      return;
    }
    const Opcode deref_op =
        global_var->IsFloat() ? Opcode::LoadFloat : Opcode::LoadGlobal;
    block.Append(deref_op, {Operand::Reg(target), Operand::Reg(target)});
    return;
  }

  // Key step: look the value up in `slots` and load it according to its type.
  const auto spilled = slots.find(value);
  if (spilled != slots.end()) {
    load_from_slot(spilled->second);
    return;
  }

  // Nothing matched: dump what is known about the value, then fail.
  const bool is_const_int = dynamic_cast<const ir::ConstantInt*>(value) != nullptr;
  const bool is_const_float = dynamic_cast<const ir::ConstantFloat*>(value) != nullptr;
  const bool is_instruction = dynamic_cast<const ir::Instruction*>(value) != nullptr;
  std::cerr << "未找到的值: " << value << std::endl
            << " 名称: " << value->GetName() << std::endl
            << " 类型: " << (value->GetType()->IsFloat32() ? "float" : "int") << std::endl
            << " 是否是 ConstantInt: " << is_const_int << std::endl
            << " 是否是 ConstantFloat: " << is_const_float << std::endl
            << " 是否是 Instruction: " << is_instruction << std::endl;
  throw std::runtime_error(
      FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
}
// Appends a store of register `reg` to frame slot `slot` to `block`.
// `isFloat` selects Opcode::StoreFloat instead of Opcode::Store.
void StoreRegToSlot(PhysReg reg, int slot, MachineBasicBlock& block, bool isFloat = false) {
  const Opcode store_op = isFloat ? Opcode::StoreFloat : Opcode::Store;
  block.Append(store_op, {Operand::Reg(reg), Operand::FrameIndex(slot)});
}
// Lowers one IR instruction into machine instructions appended to `block`.
// (Formerly LowerInstruction; the MachineBasicBlock parameter was added so
// the caller chooses the destination block.)
//
// Every value-producing instruction receives its own frame slot, recorded in
// `slots`, so later instructions can reload the result via EmitValueToReg.
// T0/T1/T2 and FT0/FT1 serve as scratch registers.
//
// Parameters:
//   inst     - IR instruction to lower.
//   function - machine function that owns the frame slots and blocks.
//   slots    - IR value -> frame index map; updated with the result slot.
//   block    - machine block that receives the emitted instructions.
//
// Throws std::runtime_error when the pointer of a store/load cannot be
// resolved or a conversion/return source has no slot. Opcodes without a case
// below are ignored.
void LowerInstructionToBlock(const ir::Instruction& inst, MachineFunction& function,
                             ValueSlotMap& slots, MachineBasicBlock& block) {
  switch (inst.GetOpcode()) {
    case ir::Opcode::Alloca: {
      // Reserve 4 bytes per element (a single element when not an array).
      auto& alloca = static_cast<const ir::AllocaInst&>(inst);
      int size = 4;
      if (alloca.GetNumElements() > 1) {
        size = alloca.GetNumElements() * 4;
      }
      slots.emplace(&inst, function.CreateFrameIndex(size));
      return;
    }
    case ir::Opcode::Store: {
      auto& store = static_cast<const ir::StoreInst&>(inst);
      // Store through a computed element address. The value goes to T2
      // because the address computation uses T0 and T1.
      if (dynamic_cast<const ir::GepInst*>(store.GetPtr())) {
        EmitValueToReg(store.GetValue(), PhysReg::T2, slots, block);
        EmitValueToReg(store.GetPtr(), PhysReg::T0, slots, block, true);
        block.Append(Opcode::StoreIndirect,
                     {Operand::Reg(PhysReg::T2), Operand::Reg(PhysReg::T0)});
        return;
      }
      // Store to a global variable, addressed by name.
      if (auto* global = dynamic_cast<const ir::GlobalVariable*>(store.GetPtr())) {
        EmitValueToReg(store.GetValue(), PhysReg::T0, slots, block);
        std::string global_name = global->GetName();
        block.Append(Opcode::StoreGlobal,
                     {Operand::Reg(PhysReg::T0), Operand::Global(global_name)});
        return;
      }
      // Store to a stack slot.
      // NOTE(review): this always uses the integer Store opcode, whatever the
      // value's type — confirm float values round-trip correctly.
      auto dst = slots.find(store.GetPtr());
      if (dst != slots.end()) {
        EmitValueToReg(store.GetValue(), PhysReg::T0, slots, block);
        StoreRegToSlot(PhysReg::T0, dst->second, block);
        return;
      }
      throw std::runtime_error(FormatError("mir", "Store: 无法处理的指针类型"));
    }
    case ir::Opcode::Load: {
      auto& load = static_cast<const ir::LoadInst&>(inst);
      // Load through a computed element address.
      if (dynamic_cast<const ir::GepInst*>(load.GetPtr())) {
        EmitValueToReg(load.GetPtr(), PhysReg::T0, slots, block, true);
        block.Append(Opcode::LoadIndirect,
                     {Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
        int dst_slot = function.CreateFrameIndex(4);
        StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
        slots.emplace(&inst, dst_slot);
        return;
      }
      // Load from a global variable: take its address, then dereference.
      if (auto* global = dynamic_cast<const ir::GlobalVariable*>(load.GetPtr())) {
        int dst_slot = function.CreateFrameIndex(4);
        std::string global_name = global->GetName();
        block.Append(Opcode::LoadGlobalAddr,
                     {Operand::Reg(PhysReg::T0), Operand::Global(global_name)});
        if (global->IsFloat()) {
          block.Append(Opcode::LoadFloat,
                       {Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
          StoreRegToSlot(PhysReg::T0, dst_slot, block, true);
        } else {
          block.Append(Opcode::LoadGlobal,
                       {Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
          StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
        }
        slots.emplace(&inst, dst_slot);
        return;
      }
      // Load from a stack slot: copy it into a fresh slot for the result.
      auto src = slots.find(load.GetPtr());
      if (src != slots.end()) {
        int dst_slot = function.CreateFrameIndex(4);
        if (load.GetType()->IsFloat32()) {
          block.Append(Opcode::LoadFloat,
                       {Operand::Reg(PhysReg::FT0), Operand::FrameIndex(src->second)});
          StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
        } else {
          block.Append(Opcode::Load,
                       {Operand::Reg(PhysReg::T0), Operand::FrameIndex(src->second)});
          StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
        }
        slots.emplace(&inst, dst_slot);
        return;
      }
      throw std::runtime_error(FormatError("mir", "Load: 无法处理的指针类型"));
    }
    case ir::Opcode::Add:
    case ir::Opcode::Sub:
    case ir::Opcode::Mul:
    case ir::Opcode::Div:
    case ir::Opcode::Mod: {
      // Integer arithmetic: T0 = T0 <op> T1, result spilled to a new slot.
      auto& bin = static_cast<const ir::BinaryInst&>(inst);
      int dst_slot = function.CreateFrameIndex();
      EmitValueToReg(bin.GetLhs(), PhysReg::T0, slots, block);
      EmitValueToReg(bin.GetRhs(), PhysReg::T1, slots, block);
      Opcode op;
      switch (inst.GetOpcode()) {
        case ir::Opcode::Add: op = Opcode::Add; break;
        case ir::Opcode::Sub: op = Opcode::Sub; break;
        case ir::Opcode::Mul: op = Opcode::Mul; break;
        case ir::Opcode::Div: op = Opcode::Div; break;
        case ir::Opcode::Mod: op = Opcode::Rem; break;
        default: op = Opcode::Add; break;
      }
      block.Append(op, {Operand::Reg(PhysReg::T0),
                        Operand::Reg(PhysReg::T0),
                        Operand::Reg(PhysReg::T1)});
      StoreRegToSlot(PhysReg::T0, dst_slot, block);
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::Gep: {
      // Only a slot is reserved here; no code is emitted. EmitValueToReg
      // recomputes the element address each time the GEP is used.
      int dst_slot = function.CreateFrameIndex();
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::Call: {
      auto& call = static_cast<const ir::CallInst&>(inst);
      // Arguments are placed in A0, A0+1, ... in order. Only the first 8 are
      // emitted; further arguments are not passed.
      for (size_t i = 0; i < call.GetNumArgs() && i < 8; i++) {
        PhysReg argReg = static_cast<PhysReg>(static_cast<int>(PhysReg::A0) + i);
        EmitValueToReg(call.GetArg(i), argReg, slots, block);
      }
      std::string func_name = call.GetCalleeName();
      block.Append(Opcode::Call, {Operand::Func(func_name)});
      // A non-void result is taken from A0 and spilled.
      if (!call.GetType()->IsVoid()) {
        int dst_slot = function.CreateFrameIndex();
        StoreRegToSlot(PhysReg::A0, dst_slot, block);
        slots.emplace(&inst, dst_slot);
      }
      return;
    }
    case ir::Opcode::ICmp: {
      // Integer comparison: lhs in T0, rhs in T1, 0/1 result left in T0.
      auto& icmp = static_cast<const ir::ICmpInst&>(inst);
      int dst_slot = function.CreateFrameIndex();
      EmitValueToReg(icmp.GetLhs(), PhysReg::T0, slots, block);
      EmitValueToReg(icmp.GetRhs(), PhysReg::T1, slots, block);
      ir::ICmpPredicate pred = icmp.GetPredicate();
      switch (pred) {
        case ir::ICmpPredicate::EQ:
          // lhs == rhs  <=>  (lhs - rhs) == 0.
          // `sltu zero, diff` yields diff != 0, which is then inverted.
          // A signed `slti diff, 1` must not be used here: it would also be
          // true for every negative difference, i.e. whenever lhs < rhs.
          block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0),
                                     Operand::Reg(PhysReg::T0),
                                     Operand::Reg(PhysReg::T1)});
          block.Append(Opcode::Sltu, {Operand::Reg(PhysReg::T0),
                                      Operand::Reg(PhysReg::ZERO),
                                      Operand::Reg(PhysReg::T0)});
          block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
                                      Operand::Reg(PhysReg::T0),
                                      Operand::Imm(1)});
          break;
        case ir::ICmpPredicate::NE:
          // lhs != rhs  <=>  0 <u (lhs - rhs).
          block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0),
                                     Operand::Reg(PhysReg::T0),
                                     Operand::Reg(PhysReg::T1)});
          block.Append(Opcode::Sltu, {Operand::Reg(PhysReg::T0),
                                      Operand::Reg(PhysReg::ZERO),
                                      Operand::Reg(PhysReg::T0)});
          break;
        case ir::ICmpPredicate::SLT:
          block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
                                     Operand::Reg(PhysReg::T0),
                                     Operand::Reg(PhysReg::T1)});
          break;
        case ir::ICmpPredicate::SLE:
          // lhs <= rhs  <=>  !(rhs < lhs).
          block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T1),
                                     Operand::Reg(PhysReg::T1),
                                     Operand::Reg(PhysReg::T0)});
          block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
                                      Operand::Reg(PhysReg::T1),
                                      Operand::Imm(1)});
          break;
        case ir::ICmpPredicate::SGT:
          // lhs > rhs  <=>  rhs < lhs.
          block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
                                     Operand::Reg(PhysReg::T1),
                                     Operand::Reg(PhysReg::T0)});
          break;
        case ir::ICmpPredicate::SGE:
          // lhs >= rhs  <=>  !(lhs < rhs). This must not reuse the SLE
          // sequence, which would compute lhs <= rhs instead.
          block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
                                     Operand::Reg(PhysReg::T0),
                                     Operand::Reg(PhysReg::T1)});
          block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
                                      Operand::Reg(PhysReg::T0),
                                      Operand::Imm(1)});
          break;
      }
      StoreRegToSlot(PhysReg::T0, dst_slot, block);
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::ZExt: {
      auto& zext = static_cast<const ir::ZExtInst&>(inst);
      int dst_slot = function.CreateFrameIndex(4);  // an i32 is 4 bytes
      // Fetch the source operand ...
      EmitValueToReg(zext.GetSrc(), PhysReg::T0, slots, block);
      // ... and copy it unchanged into the new slot.
      StoreRegToSlot(PhysReg::T0, dst_slot, block);
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::FAdd:
    case ir::Opcode::FSub:
    case ir::Opcode::FMul:
    case ir::Opcode::FDiv: {
      // Float arithmetic: FT0 = FT0 <op> FT1, result spilled to a new slot.
      auto& bin = static_cast<const ir::BinaryInst&>(inst);
      int dst_slot = function.CreateFrameIndex(4);
      EmitValueToReg(bin.GetLhs(), PhysReg::FT0, slots, block);
      EmitValueToReg(bin.GetRhs(), PhysReg::FT1, slots, block);
      Opcode op;
      switch (inst.GetOpcode()) {
        case ir::Opcode::FAdd: op = Opcode::FAdd; break;
        case ir::Opcode::FSub: op = Opcode::FSub; break;
        case ir::Opcode::FMul: op = Opcode::FMul; break;
        case ir::Opcode::FDiv: op = Opcode::FDiv; break;
        default: op = Opcode::FAdd; break;
      }
      block.Append(op, {Operand::Reg(PhysReg::FT0),
                        Operand::Reg(PhysReg::FT0),
                        Operand::Reg(PhysReg::FT1)});
      // The result is in an FP register, so it needs the float store, as in
      // the float Load and SIToFP paths.
      StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::FCmp: {
      // Float comparison: operands in FT0/FT1, result moved to T0 and spilled
      // as an integer.
      auto& fcmp = static_cast<const ir::FCmpInst&>(inst);
      int dst_slot = function.CreateFrameIndex(4);
      EmitValueToReg(fcmp.GetLhs(), PhysReg::FT0, slots, block);
      EmitValueToReg(fcmp.GetRhs(), PhysReg::FT1, slots, block);
      ir::FCmpPredicate pred = fcmp.GetPredicate();
      switch (pred) {
        case ir::FCmpPredicate::OEQ:
          block.Append(Opcode::FEq, {Operand::Reg(PhysReg::FT0),
                                     Operand::Reg(PhysReg::FT0),
                                     Operand::Reg(PhysReg::FT1)});
          break;
        case ir::FCmpPredicate::OLT:
          block.Append(Opcode::FLt, {Operand::Reg(PhysReg::FT0),
                                     Operand::Reg(PhysReg::FT0),
                                     Operand::Reg(PhysReg::FT1)});
          break;
        case ir::FCmpPredicate::OLE:
          block.Append(Opcode::FLe, {Operand::Reg(PhysReg::FT0),
                                     Operand::Reg(PhysReg::FT0),
                                     Operand::Reg(PhysReg::FT1)});
          break;
        default:
          // NOTE(review): every other predicate is lowered as an equality
          // test — confirm whether the IR can produce other predicates.
          block.Append(Opcode::FEq, {Operand::Reg(PhysReg::FT0),
                                     Operand::Reg(PhysReg::FT0),
                                     Operand::Reg(PhysReg::FT1)});
          break;
      }
      block.Append(Opcode::FMov, {Operand::Reg(PhysReg::T0),
                                  Operand::Reg(PhysReg::FT0)});
      StoreRegToSlot(PhysReg::T0, dst_slot, block);
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::SIToFP: {
      // int -> float. The source must already live in a frame slot.
      auto& conv = static_cast<const ir::SIToFPInst&>(inst);
      int dst_slot = function.CreateFrameIndex(4);
      auto src_it = slots.find(conv.GetSrc());
      if (src_it == slots.end()) {
        throw std::runtime_error(FormatError("mir", "SIToFP: 找不到源操作数的栈槽"));
      }
      block.Append(Opcode::Load,
                   {Operand::Reg(PhysReg::T0), Operand::FrameIndex(src_it->second)});
      block.Append(Opcode::SIToFP, {Operand::Reg(PhysReg::FT0),
                                    Operand::Reg(PhysReg::T0)});
      StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::FPToSI: {
      // float -> int. The source must already live in a frame slot.
      auto& conv = static_cast<const ir::FPToSIInst&>(inst);
      int dst_slot = function.CreateFrameIndex(4);
      auto src_it = slots.find(conv.GetSrc());
      if (src_it == slots.end()) {
        throw std::runtime_error(FormatError("mir", "FPToSI: 找不到源操作数的栈槽"));
      }
      block.Append(Opcode::LoadFloat,
                   {Operand::Reg(PhysReg::FT0), Operand::FrameIndex(src_it->second)});
      block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::T0),
                                    Operand::Reg(PhysReg::FT0)});
      StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
      slots.emplace(&inst, dst_slot);
      return;
    }
    case ir::Opcode::Br: {
      // The branch target is encoded as the address of the machine block.
      auto& br = static_cast<const ir::BrInst&>(inst);
      auto* target = br.GetTarget();
      MachineBasicBlock* target_block = GetOrCreateBlock(target, function);
      block.Append(Opcode::Br, {Operand::Imm64(reinterpret_cast<intptr_t>(target_block))});
      return;
    }
    case ir::Opcode::CondBr: {
      auto& condbr = static_cast<const ir::CondBrInst&>(inst);
      auto* true_bb = condbr.GetTrueBB();
      auto* false_bb = condbr.GetFalseBB();
      EmitValueToReg(condbr.GetCond(), PhysReg::T0, slots, block);
      // Normalise the condition to 0/1 in T1 (T1 = cond != 0).
      block.Append(Opcode::Sltu, {Operand::Reg(PhysReg::T1),
                                  Operand::Reg(PhysReg::ZERO),
                                  Operand::Reg(PhysReg::T0)});
      MachineBasicBlock* true_block = GetOrCreateBlock(true_bb, function);
      MachineBasicBlock* false_block = GetOrCreateBlock(false_bb, function);
      block.Append(Opcode::CondBr, {Operand::Reg(PhysReg::T1),
                                    Operand::Imm64(reinterpret_cast<intptr_t>(true_block)),
                                    Operand::Imm64(reinterpret_cast<intptr_t>(false_block))});
      return;
    }
    case ir::Opcode::Ret: {
      auto& ret = static_cast<const ir::ReturnInst&>(inst);
      if (ret.GetValue()) {
        auto val = ret.GetValue();
        if (val->GetType()->IsFloat32()) {
          // NOTE(review): a float return value is converted to an integer in
          // A0 — confirm this matches what callers expect.
          auto it = slots.find(val);
          if (it != slots.end()) {
            block.Append(Opcode::LoadFloat,
                         {Operand::Reg(PhysReg::FT0), Operand::FrameIndex(it->second)});
            block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::A0),
                                          Operand::Reg(PhysReg::FT0)});
          } else {
            throw std::runtime_error(FormatError("mir", "Ret: 找不到浮点返回值的栈槽"));
          }
        } else {
          EmitValueToReg(val, PhysReg::A0, slots, block);
        }
      } else {
        // No return value: A0 is set to 0.
        block.Append(Opcode::MovImm,
                     {Operand::Reg(PhysReg::A0), Operand::Imm(0)});
      }
      block.Append(Opcode::Ret);
      return;
    }
    default: {
      // Opcodes without a lowering are ignored.
      break;
    }
  }
}
} // namespace
// Lowers one IR function to a MachineFunction.
//
// Steps:
//   1. Reset the per-function IR-block -> machine-block cache.
//   2. Give every argument a 4-byte frame slot and spill the incoming
//      register (A0 + argument index) into it from the entry block.
//   3. Create a machine block for every IR block.
//   4. Lower the instructions of every IR block into its machine block.
std::unique_ptr<MachineFunction> LowerFunctionToMIR(const ir::Function& func) {
  block_map.clear();
  auto machine_func = std::make_unique<MachineFunction>(func.GetName());
  ValueSlotMap slots;

  // Spill each incoming argument to its own stack slot.
  for (size_t index = 0; index < func.GetNumArgs(); ++index) {
    ir::Argument* param = func.GetArgument(index);
    // Ints and pointers are both treated as 4 bytes wide.
    const int frame_slot = machine_func->CreateFrameIndex(4);
    const PhysReg incoming =
        static_cast<PhysReg>(static_cast<int>(PhysReg::A0) + index);
    MachineBasicBlock* entry_block = machine_func->GetEntry();

    const auto& param_type = param->GetType();
    const bool integer_like = param_type->IsPtrInt32() ||
                              param_type->IsPtrFloat32() ||
                              param_type->IsInt32();
    if (integer_like) {
      // Pointers and integers use the plain store.
      entry_block->Append(Opcode::Store, {Operand::Reg(incoming),
                                          Operand::FrameIndex(frame_slot)});
    } else if (param_type->IsFloat32()) {
      // Floats use the float store.
      entry_block->Append(Opcode::StoreFloat, {Operand::Reg(incoming),
                                               Operand::FrameIndex(frame_slot)});
    }
    // The slot is recorded even when no store was emitted for the type.
    slots[param] = frame_slot;
  }

  // Pass 1: make sure every IR block has a machine block.
  for (const auto& source_block : func.GetBlocks()) {
    GetOrCreateBlock(source_block.get(), *machine_func);
  }

  // Pass 2: lower the instructions block by block.
  for (const auto& source_block : func.GetBlocks()) {
    MachineBasicBlock* dest = GetOrCreateBlock(source_block.get(), *machine_func);
    for (const auto& instruction : source_block->GetInstructions()) {
      LowerInstructionToBlock(*instruction, *machine_func, slots, *dest);
    }
  }

  return machine_func;
}
// Lowers a whole IR module.
//
// First rebuilds g_globalVars from the module's global variables (name,
// const/array/float flags, element count and initial values), then lowers
// every function with LowerFunctionToMIR.
//
// Returns one MachineFunction per IR function, in module order.
// Throws std::runtime_error when the module contains no functions; the
// global-variable list has already been rebuilt at that point.
std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module) {
  DefaultContext();

  // Collect the global variables (done once per module).
  g_globalVars.clear();
  for (const auto& var : module.GetGlobalVariables()) {
    GlobalVarInfo record;
    record.name = var->GetName();
    record.isConst = var->IsConst();
    record.isArray = var->IsArray();
    record.arraySize = var->GetNumElements();
    record.isFloat = var->IsFloat();
    record.value = 0;
    record.valueF = 0.0f;

    if (!record.isArray) {
      // Scalar: keep the single initial value in the matching field.
      if (record.isFloat) {
        record.valueF = var->GetInitValF();
      } else {
        record.value = var->GetInitVal();
      }
    } else if (record.isFloat) {
      // Float array: copy the initialiser list.
      for (float element : var->GetInitValsF()) {
        record.arrayValuesF.push_back(element);
      }
    } else if (var->HasInitVals()) {
      // Int array: copy the initialiser list only when one is present.
      for (int element : var->GetInitVals()) {
        record.arrayValues.push_back(element);
      }
    }

    g_globalVars.push_back(std::move(record));
  }

  const auto& functions = module.GetFunctions();
  if (functions.empty()) {
    throw std::runtime_error(FormatError("mir", "模块中没有函数"));
  }

  // Produce one MachineFunction per IR function.
  std::vector<std::unique_ptr<MachineFunction>> lowered;
  for (const auto& fn : functions) {
    lowered.push_back(LowerFunctionToMIR(*fn));
  }
  return lowered;
}
} // namespace mir