lab3代码实现-续

lc 2 months ago
parent 3dda941176
commit 5c6804f1d6

@ -1,5 +1,6 @@
#include "mir/MIR.h"
#include <cstring>
#include <stdexcept>
#include <unordered_map>
@ -11,113 +12,474 @@ namespace {
using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
// Returns the 64-bit X register that shares its index with the given 32-bit
// W register (W0..W15). Registers outside that range are returned unchanged.
PhysReg ToXReg(PhysReg reg) {
  const int raw = static_cast<int>(reg);
  const int first_w = static_cast<int>(PhysReg::W0);
  const int last_w = static_cast<int>(PhysReg::W15);
  const bool is_w_reg = raw >= first_w && raw <= last_w;
  if (!is_w_reg) {
    return reg;
  }
  // The offset arithmetic assumes W0..W15 and X0..X15 are declared as
  // consecutive enumerators in index order.
  const int index = raw - first_w;
  return static_cast<PhysReg>(index + static_cast<int>(PhysReg::X0));
}
// Returns the single-precision S register that shares its index with the
// given 32-bit W register (W0..W15). Any other register is passed through.
PhysReg ToSReg(PhysReg reg) {
  const int raw = static_cast<int>(reg);
  const int w_base = static_cast<int>(PhysReg::W0);
  const bool below_range = raw < w_base;
  const bool above_range = raw > static_cast<int>(PhysReg::W15);
  // The offset arithmetic assumes W0..W15 and S0..S15 are declared as
  // consecutive enumerators in index order.
  return (below_range || above_range)
             ? reg
             : static_cast<PhysReg>(raw - w_base + static_cast<int>(PhysReg::S0));
}
// Appends to `block` the instructions that place `value` into a register.
//
// `target` is given in its W form; it is widened to the matching X register
// for pointer-typed values and switched to the matching S register for float
// values before any instruction is emitted.
//
// Sources handled, in order:
//   * ir::ConstantInt    -> MovImm
//   * ir::ConstantFloat  -> bit pattern moved through W10, then MovRR
//                           (W10 is clobbered)
//   * ir::GlobalVariable -> LoadGlobal (loads the global's value, not address)
//   * ir::Argument       -> MovRR from the register numbered by the argument
//                           position (only the first 8 positions)
//   * anything else      -> LoadStack from the slot recorded in `slots`
//
// Throws std::runtime_error for an argument at position >= 8 and for a value
// that has no recorded stack slot.
//
// NOTE(review): argument registers are read at the point of use, so they must
// still hold the incoming values there — confirm nothing clobbers them earlier.
void EmitValueToReg(const ir::Value* value, PhysReg target,
                    const ValueSlotMap& slots, MachineBasicBlock& block) {
  const bool is_ptr = value->GetType()->IsPointer() ||
                      value->GetType()->IsPtrInt32() ||
                      value->GetType()->IsPtrFloat();
  const bool is_float = value->GetType()->IsFloat();
  if (is_ptr) {
    target = ToXReg(target);
  } else if (is_float) {
    target = ToSReg(target);
  }

  if (const auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
    block.Append(Opcode::MovImm,
                 {Operand::Reg(target), Operand::Imm(constant->GetValue())});
    return;
  }

  if (const auto* cf = dynamic_cast<const ir::ConstantFloat*>(value)) {
    // Reinterpret the float's bits as an integer immediate:
    //   mov w10, #bits; fmov target, w10
    const float f = cf->GetValue();
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10),
                                  Operand::Imm(static_cast<int>(bits))});
    block.Append(Opcode::MovRR,
                 {Operand::Reg(target), Operand::Reg(PhysReg::W10)});
    return;
  }

  if (const auto* gv = dynamic_cast<const ir::GlobalVariable*>(value)) {
    // This loads the VALUE of the global, not its address.
    block.Append(Opcode::LoadGlobal,
                 {Operand::Reg(target), Operand::Global(gv->GetName())});
    return;
  }

  if (const auto* arg = dynamic_cast<const ir::Argument*>(value)) {
    if (!(arg->GetArgNo() < 8)) {
      throw std::runtime_error(FormatError("mir", "暂不支持超过 8 个参数"));
    }
    // Pick the register bank that matches the (already adjusted) target.
    PhysReg bank_base = PhysReg::W0;
    if (is_ptr) {
      bank_base = PhysReg::X0;
    } else if (is_float) {
      bank_base = PhysReg::S0;
    }
    const PhysReg src = static_cast<PhysReg>(static_cast<int>(bank_base) +
                                             static_cast<int>(arg->GetArgNo()));
    block.Append(Opcode::MovRR, {Operand::Reg(target), Operand::Reg(src)});
    return;
  }

  const auto it = slots.find(value);
  if (it == slots.end()) {
    throw std::runtime_error(
        FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
  }
  // A single reload is enough; the previous text emitted this load twice.
  block.Append(Opcode::LoadStack,
               {Operand::Reg(target), Operand::FrameIndex(it->second)});
}
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
ValueSlotMap& slots) {
auto& block = function.GetEntry();
void EmitAddrToReg(const ir::Value* value, PhysReg target,
const MachineFunction& function,
const ValueSlotMap& slots, MachineBasicBlock& block) {
if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(value)) {
// adrp x10, gv; add x10, x10, :lo12:gv
block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Global(gv->GetName())}); // Special case for address
return;
}
if (auto* arg = dynamic_cast<const ir::Argument*>(value)) {
// Argument is already an address (pointer)
EmitValueToReg(arg, target, slots, block);
return;
}
auto it = slots.find(value);
if (it != slots.end()) {
// Check if it's an alloca (frame index) or a stored address
// For alloca, we want the address: add x10, x29, #offset
// For stored address, we want to load it: ldr x10, [x29, #offset]
// In our simple lowering, alloca's value in 'slots' is the frame index.
// If 'value' is an AllocaInst, we compute its address.
if (dynamic_cast<const ir::AllocaInst*>(value)) {
block.Append(Opcode::AddrStack, {Operand::Reg(target), Operand::FrameIndex(it->second)});
return;
}
// Otherwise it's a stored address (from a GEP)
block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(it->second)});
return;
}
throw std::runtime_error(FormatError("mir", "无法获取地址: " + value->GetName()));
}
// Returns the storage size in bytes of `ty`: 4 for int32/float, 8 for the
// pointer kinds, element count times element size for arrays (recursively),
// and 0 for every other type.
size_t GetTypeSize(const ir::Type& ty) {
  constexpr size_t kScalarBytes = 4;
  constexpr size_t kPointerBytes = 8;

  const bool is_scalar = ty.IsInt32() || ty.IsFloat();
  if (is_scalar) {
    return kScalarBytes;
  }

  const bool is_pointer = ty.IsPointer() || ty.IsPtrInt32() || ty.IsPtrFloat();
  if (is_pointer) {
    return kPointerBytes;
  }

  if (!ty.IsArray()) {
    return 0;
  }
  const size_t element_bytes = GetTypeSize(*ty.GetElementType());
  return ty.GetNumElements() * element_bytes;
}
// Lowers one IR instruction into machine instructions appended to `block`.
//
// Every value-producing instruction spills its result to a fresh frame slot
// created on `function` and records that slot in `slots`, so later uses reload
// it through EmitValueToReg / EmitAddrToReg.
//
// Scratch registers used here: W8/W9 (and their X/S views) for operands,
// X10/W10 for addresses and the remainder temporary, S0 for float results.
//
// Throws std::runtime_error for unsupported opcodes, calls with more than 8
// arguments, float modulo, and GEP indices that step into a non-array type.
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
                      MachineBasicBlock& block, ValueSlotMap& slots) {
  const auto is_pointer_like = [](const ir::Type& ty) {
    return ty.IsPointer() || ty.IsPtrInt32() || ty.IsPtrFloat();
  };

  // Selects the X / S / W view of `w_reg` that matches a value of type `ty`.
  const auto reg_for_type = [&](const ir::Type& ty, PhysReg w_reg) {
    if (is_pointer_like(ty)) return ToXReg(w_reg);
    if (ty.IsFloat()) return ToSReg(w_reg);
    return w_reg;
  };

  // Emits X10 += index * element_size. Sizes 4 and 8 use a shifted add that
  // takes the 32-bit index directly; other sizes sign-extend and multiply.
  const auto add_scaled_index = [&](const ir::Value* index,
                                    size_t element_size) {
    EmitValueToReg(index, PhysReg::W8, slots, block);
    if (element_size == 4 || element_size == 8) {
      const int shift = (element_size == 4) ? 2 : 3;
      block.Append(Opcode::AddRRR_LSL,
                   {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10),
                    Operand::Reg(PhysReg::W8), Operand::Imm(shift)});
      return;
    }
    block.Append(Opcode::Sxtw,
                 {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::W8)});
    block.Append(Opcode::MovImm,
                 {Operand::Reg(PhysReg::X9),
                  Operand::Imm(static_cast<int>(element_size))});
    block.Append(Opcode::MulRR,
                 {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8),
                  Operand::Reg(PhysReg::X9)});
    block.Append(Opcode::AddRR,
                 {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10),
                  Operand::Reg(PhysReg::X8)});
  };

  switch (inst.GetOpcode()) {
    case ir::Opcode::Alloca: {
      const auto& alloca_inst = static_cast<const ir::AllocaInst&>(inst);
      // AllocaInst's type is a pointer type; reserve room for the pointee.
      const size_t size =
          GetTypeSize(*alloca_inst.GetType()->GetPointedType());
      slots.emplace(&inst, function.CreateFrameIndex(static_cast<int>(size)));
      return;
    }

    case ir::Opcode::Store: {
      const auto& store = static_cast<const ir::StoreInst&>(inst);
      EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
      const PhysReg val_reg =
          reg_for_type(*store.GetValue()->GetType(), PhysReg::W8);

      if (const auto* gv =
              dynamic_cast<const ir::GlobalVariable*>(store.GetPtr())) {
        block.Append(Opcode::StoreGlobal,
                     {Operand::Reg(val_reg), Operand::Global(gv->GetName())});
      } else if (const auto* stack_var =
                     dynamic_cast<const ir::AllocaInst*>(store.GetPtr())) {
        const auto it = slots.find(stack_var);
        if (it == slots.end()) throw std::runtime_error("Alloca not found");
        block.Append(Opcode::StoreStack,
                     {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
      } else {
        // Any other pointer (e.g. a GEP result) is an address held in a slot.
        EmitAddrToReg(store.GetPtr(), PhysReg::X10, function, slots, block);
        block.Append(Opcode::StoreR,
                     {Operand::Reg(val_reg), Operand::Reg(PhysReg::X10)});
      }
      return;
    }

    case ir::Opcode::Load: {
      const auto& load = static_cast<const ir::LoadInst&>(inst);
      const int dst_slot = function.CreateFrameIndex(
          static_cast<int>(GetTypeSize(*load.GetType())));
      const PhysReg dst_reg = reg_for_type(*load.GetType(), PhysReg::W8);

      if (const auto* gv =
              dynamic_cast<const ir::GlobalVariable*>(load.GetPtr())) {
        block.Append(Opcode::LoadGlobal,
                     {Operand::Reg(dst_reg), Operand::Global(gv->GetName())});
      } else if (const auto* stack_var =
                     dynamic_cast<const ir::AllocaInst*>(load.GetPtr())) {
        const auto it = slots.find(stack_var);
        if (it == slots.end()) throw std::runtime_error("Alloca not found");
        block.Append(Opcode::LoadStack,
                     {Operand::Reg(dst_reg), Operand::FrameIndex(it->second)});
      } else {
        // Any other pointer (e.g. a GEP result) is an address held in a slot.
        EmitAddrToReg(load.GetPtr(), PhysReg::X10, function, slots, block);
        block.Append(Opcode::LoadR,
                     {Operand::Reg(dst_reg), Operand::Reg(PhysReg::X10)});
      }
      block.Append(Opcode::StoreStack,
                   {Operand::Reg(dst_reg), Operand::FrameIndex(dst_slot)});
      slots.emplace(&inst, dst_slot);
      return;
    }

    case ir::Opcode::GEP: {
      const auto& gep = static_cast<const ir::GEPInst&>(inst);
      const int dst_slot = function.CreateFrameIndex(8);  // Address is 8 bytes
      EmitAddrToReg(gep.GetPtr(), PhysReg::X10, function, slots, block);

      // Start from the pointee type of the base pointer.
      std::shared_ptr<ir::Type> cur_ty =
          gep.GetPtr()->GetType()->GetPointedType();
      for (size_t i = 0; i < gep.GetIndices().size(); ++i) {
        ir::Value* index_val = gep.GetIndices()[i];
        if (i == 0) {
          // The first index steps over whole pointees and does not descend
          // into the type; a constant 0 contributes nothing and is skipped.
          const auto* ci = dynamic_cast<ir::ConstantInt*>(index_val);
          if (ci == nullptr || ci->GetValue() != 0) {
            add_scaled_index(index_val, GetTypeSize(*cur_ty));
          }
          continue;
        }
        if (!cur_ty->IsArray()) {
          throw std::runtime_error(
              FormatError("mir", "GEP 索引超出范围或类型不是数组"));
        }
        add_scaled_index(index_val, GetTypeSize(*cur_ty->GetElementType()));
        cur_ty = cur_ty->GetElementType();
      }
      block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::X10),
                                        Operand::FrameIndex(dst_slot)});
      slots.emplace(&inst, dst_slot);
      return;
    }

    case ir::Opcode::Call: {
      const auto& call = static_cast<const ir::CallInst&>(inst);
      const auto& args = call.GetArgs();
      if (args.size() > 8) {
        throw std::runtime_error("Only up to 8 arguments supported for now");
      }
      // Argument i goes to register number i of the bank matching its type
      // (EmitValueToReg switches W -> S for floats).
      // NOTE(review): one shared position counter for integer and float
      // arguments differs from AAPCS64, which numbers the two banks
      // separately — confirm before calling externally compiled functions.
      for (size_t i = 0; i < args.size(); ++i) {
        const int arg_index = static_cast<int>(i);
        const PhysReg target =
            is_pointer_like(*args[i]->GetType())
                ? static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + arg_index)
                : static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + arg_index);
        EmitValueToReg(args[i], target, slots, block);
      }
      block.Append(Opcode::Call, {Operand::Label(call.GetFunc()->GetName())});

      if (!call.GetType()->IsVoid()) {
        const int dst_slot = function.CreateFrameIndex(
            static_cast<int>(GetTypeSize(*call.GetType())));
        const PhysReg ret_reg = reg_for_type(*call.GetType(), PhysReg::W0);
        block.Append(Opcode::StoreStack,
                     {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)});
        slots.emplace(&inst, dst_slot);
      }
      return;
    }

    case ir::Opcode::Add:
    case ir::Opcode::Sub:
    case ir::Opcode::Mul:
    case ir::Opcode::Div:
    case ir::Opcode::Mod: {
      const auto& bin = static_cast<const ir::BinaryInst&>(inst);
      const int dst_slot = function.CreateFrameIndex();

      if (bin.GetType()->IsFloat()) {
        Opcode op;
        switch (inst.GetOpcode()) {
          case ir::Opcode::Add: op = Opcode::FAdd; break;
          case ir::Opcode::Sub: op = Opcode::FSub; break;
          case ir::Opcode::Mul: op = Opcode::FMUL; break;
          case ir::Opcode::Div: op = Opcode::FDiv; break;
          default: throw std::runtime_error("Float mod not supported");
        }
        // Float operands are materialized into the S views of W8/W9.
        EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
        EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
        block.Append(op, {Operand::Reg(PhysReg::S0),
                          Operand::Reg(ToSReg(PhysReg::W8)),
                          Operand::Reg(ToSReg(PhysReg::W9))});
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0),
                                          Operand::FrameIndex(dst_slot)});
      } else {
        EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
        EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
        const auto emit_w8_op_w9 = [&](Opcode opcode) {
          block.Append(opcode,
                       {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8),
                        Operand::Reg(PhysReg::W9)});
        };
        switch (inst.GetOpcode()) {
          case ir::Opcode::Add: emit_w8_op_w9(Opcode::AddRR); break;
          case ir::Opcode::Sub: emit_w8_op_w9(Opcode::SubRR); break;
          case ir::Opcode::Mul: emit_w8_op_w9(Opcode::MulRR); break;
          case ir::Opcode::Div: emit_w8_op_w9(Opcode::SDivRR); break;
          case ir::Opcode::Mod:
            // Remainder: sdiv w10, w8, w9; msub w8, w10, w9, w8
            block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W10),
                                          Operand::Reg(PhysReg::W8),
                                          Operand::Reg(PhysReg::W9)});
            block.Append(Opcode::MSubRRR, {Operand::Reg(PhysReg::W8),
                                           Operand::Reg(PhysReg::W10),
                                           Operand::Reg(PhysReg::W9),
                                           Operand::Reg(PhysReg::W8)});
            break;
          default:
            break;
        }
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8),
                                          Operand::FrameIndex(dst_slot)});
      }
      slots.emplace(&inst, dst_slot);
      return;
    }

    case ir::Opcode::SIToFP: {
      const auto& fcvt = static_cast<const ir::UnaryInst&>(inst);
      const int dst_slot = function.CreateFrameIndex();
      EmitValueToReg(fcvt.GetUnaryOperand(), PhysReg::W8, slots, block);
      block.Append(Opcode::FCvtSI2FP,
                   {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::W8)});
      block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0),
                                        Operand::FrameIndex(dst_slot)});
      slots.emplace(&inst, dst_slot);
      return;
    }

    case ir::Opcode::FPToSI: {
      const auto& fcvt = static_cast<const ir::UnaryInst&>(inst);
      const int dst_slot = function.CreateFrameIndex();
      // The operand is a float, so EmitValueToReg places it in S8.
      EmitValueToReg(fcvt.GetUnaryOperand(), PhysReg::W8, slots, block);
      block.Append(Opcode::FCvtFP2SI,
                   {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S8)});
      block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8),
                                        Operand::FrameIndex(dst_slot)});
      slots.emplace(&inst, dst_slot);
      return;
    }

    case ir::Opcode::Cmp:
    case ir::Opcode::FCmp: {
      const int dst_slot = function.CreateFrameIndex();
      ir::CmpOp ir_cc;
      if (inst.GetOpcode() == ir::Opcode::Cmp) {
        const auto& cmp = static_cast<const ir::CmpInst&>(inst);
        EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block);
        EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block);
        block.Append(Opcode::CmpRR,
                     {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
        ir_cc = cmp.GetCmpOp();
      } else {
        const auto& cmp = static_cast<const ir::FCmpInst&>(inst);
        // Float operands end up in S8/S9.
        EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block);
        EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block);
        block.Append(Opcode::FCmp,
                     {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
        ir_cc = cmp.GetCmpOp();
      }

      // NOTE(review): after an AArch64 FCMP the LT and LE conditions also
      // hold for unordered (NaN) operands; MI / LS are the ordered forms.
      // Confirm whether CondCode offers them and whether NaN inputs matter.
      CondCode cc = CondCode::EQ;
      switch (ir_cc) {
        case ir::CmpOp::Eq: cc = CondCode::EQ; break;
        case ir::CmpOp::Ne: cc = CondCode::NE; break;
        case ir::CmpOp::Lt: cc = CondCode::LT; break;
        case ir::CmpOp::Le: cc = CondCode::LE; break;
        case ir::CmpOp::Gt: cc = CondCode::GT; break;
        case ir::CmpOp::Ge: cc = CondCode::GE; break;
      }
      block.Append(Opcode::CSet,
                   {Operand::Reg(PhysReg::W8), Operand::Cond(cc)});
      block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8),
                                        Operand::FrameIndex(dst_slot)});
      slots.emplace(&inst, dst_slot);
      return;
    }

    case ir::Opcode::Zext: {
      const auto& zext = static_cast<const ir::ZextInst&>(inst);
      const int dst_slot = function.CreateFrameIndex();
      // The source is copied as-is into a new slot; no widening is emitted.
      EmitValueToReg(zext.GetValue(), PhysReg::W8, slots, block);
      block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8),
                                        Operand::FrameIndex(dst_slot)});
      slots.emplace(&inst, dst_slot);
      return;
    }

    case ir::Opcode::Neg: {
      const auto& unary = static_cast<const ir::UnaryInst&>(inst);
      const int dst_slot = function.CreateFrameIndex();
      EmitValueToReg(unary.GetUnaryOperand(), PhysReg::W8, slots, block);
      if (unary.GetType()->IsFloat()) {
        block.Append(Opcode::FNeg,
                     {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::S8)});
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0),
                                          Operand::FrameIndex(dst_slot)});
      } else {
        block.Append(Opcode::NegR,
                     {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8)});
        block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8),
                                          Operand::FrameIndex(dst_slot)});
      }
      slots.emplace(&inst, dst_slot);
      return;
    }

    case ir::Opcode::Br: {
      const auto& br = static_cast<const ir::BranchInst&>(inst);
      block.Append(Opcode::B, {Operand::Label(br.GetDest()->GetName())});
      return;
    }

    case ir::Opcode::CondBr: {
      const auto& cbr = static_cast<const ir::CondBranchInst&>(inst);
      EmitValueToReg(cbr.GetCond(), PhysReg::W8, slots, block);
      // SysY IR CondBr uses i1. In MIR, we compare with 0: branch to the true
      // block when W8 is non-zero, otherwise fall to the explicit B.
      block.Append(Opcode::BCond,
                   {Operand::Cond(CondCode::NE), Operand::Reg(PhysReg::W8),
                    Operand::Label(cbr.GetTrueBlock()->GetName())});
      block.Append(Opcode::B,
                   {Operand::Label(cbr.GetFalseBlock()->GetName())});
      return;
    }

    case ir::Opcode::Ret: {
      const auto& ret = static_cast<const ir::ReturnInst&>(inst);
      // A return without a value carries a null value; only load when present.
      if (auto* val = ret.GetValue()) {
        EmitValueToReg(val, PhysReg::W0, slots, block);
      }
      block.Append(Opcode::Ret);
      return;
    }

    default:
      throw std::runtime_error(FormatError(
          "mir", "暂不支持该 IR 指令: " +
                     std::to_string(static_cast<int>(inst.GetOpcode()))));
  }
}
} // namespace
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module) {
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
DefaultContext();
auto machine_module = std::make_unique<MachineModule>();
if (module.GetFunctions().size() != 1) {
throw std::runtime_error(FormatError("mir", "暂不支持多个函数"));
// Lower global variables
for (const auto& gv : module.GetGlobalVariables()) {
GlobalVariable mir_gv;
mir_gv.name = gv->GetName();
mir_gv.size = GetTypeSize(*gv->GetType()->GetPointedType());
if (auto* init = gv->GetInitializer()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(init)) {
mir_gv.init_value = ci->GetValue();
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(init)) {
float f = cf->GetValue();
uint32_t bits;
std::memcpy(&bits, &f, 4);
mir_gv.init_value = static_cast<int>(bits);
}
}
machine_module->GetGlobals().push_back(mir_gv);
}
const auto& func = *module.GetFunctions().front();
if (func.GetName() != "main") {
throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数"));
}
// Lower functions
for (const auto& ir_func : module.GetFunctions()) {
if (ir_func->GetBlocks().empty()) continue; // Skip declarations
auto machine_func = std::make_unique<MachineFunction>(ir_func->GetName());
ValueSlotMap slots;
auto machine_func = std::make_unique<MachineFunction>(func.GetName());
ValueSlotMap slots;
const auto* entry = func.GetEntry();
if (!entry) {
throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块"));
}
// Create all blocks first to handle forward references in branches
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> block_map;
for (const auto& ir_bb : ir_func->GetBlocks()) {
block_map[ir_bb.get()] = &machine_func->CreateBlock(ir_bb->GetName());
}
// Lower instructions in each block
for (const auto& ir_bb : ir_func->GetBlocks()) {
auto& machine_bb = *block_map.at(ir_bb.get());
for (const auto& inst : ir_bb->GetInstructions()) {
LowerInstruction(*inst, *machine_func, machine_bb, slots);
}
}
for (const auto& inst : entry->GetInstructions()) {
LowerInstruction(*inst, *machine_func, slots);
machine_module->GetFunctions().push_back(std::move(machine_func));
}
return machine_func;
return machine_module;
}
} // namespace mir

@ -8,7 +8,12 @@
namespace mir {
// Creates a machine function called `name`. It starts with no blocks; they
// are added one at a time through CreateBlock().
MachineFunction::MachineFunction(std::string name)
    : name_(std::move(name)) {}
// Appends a new, empty basic block called `name` to this function and returns
// a reference to it. The block is owned by the function's block list.
MachineBasicBlock& MachineFunction::CreateBlock(const std::string& name) {
  auto owned_block = std::make_unique<MachineBasicBlock>(name);
  // Take the reference before ownership moves into the list; the heap object
  // itself does not move.
  MachineBasicBlock& created = *owned_block;
  blocks_.push_back(std::move(owned_block));
  return created;
}
int MachineFunction::CreateFrameIndex(int size) {
int index = static_cast<int>(frame_slots_.size());

@ -4,17 +4,29 @@
namespace mir {
// Builds an operand from its kind, register and integer payload; label_ is
// not listed and is therefore default-constructed.
// NOTE(review): this three-argument overload may be the pre-change form of
// the constructor below — confirm against the class declaration whether both
// overloads are meant to exist.
Operand::Operand(Kind kind, PhysReg reg, int imm)
: kind_(kind), reg_(reg), imm_(imm) {}
// Builds an operand that additionally carries a symbolic name (used by the
// Label and Global factories); `label` is taken by value and moved into place.
Operand::Operand(Kind kind, PhysReg reg, int imm, std::string label)
: kind_(kind), reg_(reg), imm_(imm), label_(std::move(label)) {}
// Factory for a register operand; the integer payload is unused and set to 0.
Operand Operand::Reg(PhysReg reg) {
  constexpr int kUnusedImmediate = 0;
  return Operand(Kind::Reg, reg, kUnusedImmediate);
}
// Factory for an immediate operand holding `value`. The register field is
// not meaningful for this kind and is filled with the WZR placeholder.
Operand Operand::Imm(int value) {
  return Operand(Kind::Imm, PhysReg::WZR, value);
}
// Factory for a frame-index operand referring to stack slot `index`. The
// register field is not meaningful for this kind and is filled with the WZR
// placeholder.
Operand Operand::FrameIndex(int index) {
  return Operand(Kind::FrameIndex, PhysReg::WZR, index);
}
// Factory for a label operand naming a branch or call target. Register and
// integer fields carry the WZR / 0 placeholders.
Operand Operand::Label(const std::string& name) {
  constexpr int kUnusedImmediate = 0;
  return Operand(Kind::Label, PhysReg::WZR, kUnusedImmediate, name);
}
// Factory for an operand that refers to the global symbol `name`. Register
// and integer fields carry the WZR / 0 placeholders.
Operand Operand::Global(const std::string& name) {
  constexpr int kUnusedImmediate = 0;
  return Operand(Kind::Global, PhysReg::WZR, kUnusedImmediate, name);
}
// Factory for a condition-code operand; the code is stored in the integer
// payload as its underlying enumerator value.
Operand Operand::Cond(CondCode cc) {
  const int encoded_condition = static_cast<int>(cc);
  return Operand(Kind::Cond, PhysReg::WZR, encoded_condition);
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)

@ -8,26 +8,19 @@ namespace mir {
namespace {
// Reports whether `reg` may appear in lowered code. No register allocation is
// performed yet, so every physical register is accepted.
bool IsAllowedReg(PhysReg reg) {
  (void)reg;  // Intentionally unused until a real allocation policy exists.
  return true;
}
} // namespace
void RunRegAlloc(MachineFunction& function) {
for (const auto& inst : function.GetEntry().GetInstructions()) {
for (const auto& operand : inst.GetOperands()) {
if (operand.GetKind() == Operand::Kind::Reg &&
!IsAllowedReg(operand.GetReg())) {
throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
for (auto& block : function.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
for (const auto& operand : inst.GetOperands()) {
if (operand.GetKind() == Operand::Kind::Reg &&
!IsAllowedReg(operand.GetReg())) {
throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
}
}
}
}

@ -8,18 +8,61 @@ namespace mir {
// Returns the assembly spelling of a physical register (e.g. "w8", "x29",
// "s0", "sp", "wzr"). Throws std::runtime_error for a value that matches none
// of the enumerators listed below.
const char* PhysRegName(PhysReg reg) {
  switch (reg) {
    // 32-bit general-purpose views.
    case PhysReg::W0: return "w0";
    case PhysReg::W1: return "w1";
    case PhysReg::W2: return "w2";
    case PhysReg::W3: return "w3";
    case PhysReg::W4: return "w4";
    case PhysReg::W5: return "w5";
    case PhysReg::W6: return "w6";
    case PhysReg::W7: return "w7";
    case PhysReg::W8: return "w8";
    case PhysReg::W9: return "w9";
    case PhysReg::W10: return "w10";
    case PhysReg::W11: return "w11";
    case PhysReg::W12: return "w12";
    case PhysReg::W13: return "w13";
    case PhysReg::W14: return "w14";
    case PhysReg::W15: return "w15";
    // 64-bit general-purpose views.
    case PhysReg::X0: return "x0";
    case PhysReg::X1: return "x1";
    case PhysReg::X2: return "x2";
    case PhysReg::X3: return "x3";
    case PhysReg::X4: return "x4";
    case PhysReg::X5: return "x5";
    case PhysReg::X6: return "x6";
    case PhysReg::X7: return "x7";
    case PhysReg::X8: return "x8";
    case PhysReg::X9: return "x9";
    case PhysReg::X10: return "x10";
    case PhysReg::X11: return "x11";
    case PhysReg::X12: return "x12";
    case PhysReg::X13: return "x13";
    case PhysReg::X14: return "x14";
    case PhysReg::X15: return "x15";
    case PhysReg::X16: return "x16";
    case PhysReg::X17: return "x17";
    // Single-precision floating-point registers.
    case PhysReg::S0: return "s0";
    case PhysReg::S1: return "s1";
    case PhysReg::S2: return "s2";
    case PhysReg::S3: return "s3";
    case PhysReg::S4: return "s4";
    case PhysReg::S5: return "s5";
    case PhysReg::S6: return "s6";
    case PhysReg::S7: return "s7";
    case PhysReg::S8: return "s8";
    case PhysReg::S9: return "s9";
    case PhysReg::S10: return "s10";
    case PhysReg::S11: return "s11";
    case PhysReg::S12: return "s12";
    case PhysReg::S13: return "s13";
    case PhysReg::S14: return "s14";
    case PhysReg::S15: return "s15";
    // Frame pointer, link register, stack pointer and zero registers.
    case PhysReg::X29: return "x29";
    case PhysReg::X30: return "x30";
    case PhysReg::SP: return "sp";
    case PhysReg::WZR: return "wzr";
    case PhysReg::XZR: return "xzr";
  }
  throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
}

Loading…
Cancel
Save