diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 6753a77..9382220 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -1,5 +1,6 @@ #include "mir/MIR.h" +#include #include #include @@ -11,113 +12,474 @@ namespace { using ValueSlotMap = std::unordered_map; +PhysReg ToXReg(PhysReg reg) { + if ((int)reg >= (int)PhysReg::W0 && (int)reg <= (int)PhysReg::W15) { + return static_cast((int)reg - (int)PhysReg::W0 + (int)PhysReg::X0); + } + return reg; +} + +PhysReg ToSReg(PhysReg reg) { + if ((int)reg >= (int)PhysReg::W0 && (int)reg <= (int)PhysReg::W15) { + return static_cast((int)reg - (int)PhysReg::W0 + (int)PhysReg::S0); + } + return reg; +} + void EmitValueToReg(const ir::Value* value, PhysReg target, const ValueSlotMap& slots, MachineBasicBlock& block) { + bool is_ptr = value->GetType()->IsPointer() || value->GetType()->IsPtrInt32() || value->GetType()->IsPtrFloat(); + bool is_float = value->GetType()->IsFloat(); + + if (is_ptr) { + target = ToXReg(target); + } else if (is_float) { + target = ToSReg(target); + } + if (auto* constant = dynamic_cast(value)) { block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Imm(constant->GetValue())}); return; } + if (auto* cf = dynamic_cast(value)) { + float f = cf->GetValue(); + uint32_t bits; + std::memcpy(&bits, &f, 4); + // mov w10, #bits; fmov target, w10 + block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm((int)bits)}); + block.Append(Opcode::MovRR, {Operand::Reg(target), Operand::Reg(PhysReg::W10)}); + return; + } + + if (auto* gv = dynamic_cast(value)) { + // This loads the VALUE of the global, not its address + block.Append(Opcode::LoadGlobal, + {Operand::Reg(target), Operand::Global(gv->GetName())}); + return; + } + + if (auto* arg = dynamic_cast(value)) { + if (arg->GetArgNo() < 8) { + PhysReg src; + if (is_ptr) { + src = static_cast((int)PhysReg::X0 + arg->GetArgNo()); + } else if (is_float) { + src = static_cast((int)PhysReg::S0 + arg->GetArgNo()); + } else { + src = static_cast((int)PhysReg::W0 + arg->GetArgNo()); + } + block.Append(Opcode::MovRR, {Operand::Reg(target), Operand::Reg(src)}); + } else { + throw std::runtime_error(FormatError("mir", "暂不支持超过 8 个参数")); + } + return; + } + auto it = slots.find(value); if (it == slots.end()) { throw std::runtime_error( FormatError("mir", "找不到值对应的栈槽: " + value->GetName())); } - block.Append(Opcode::LoadStack, - {Operand::Reg(target), Operand::FrameIndex(it->second)}); + block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(it->second)}); } -void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, - ValueSlotMap& slots) { - auto& block = function.GetEntry(); +void EmitAddrToReg(const ir::Value* value, PhysReg target, + const MachineFunction& function, + const ValueSlotMap& slots, MachineBasicBlock& block) { + if (auto* gv = dynamic_cast(value)) { + // adrp x10, gv; add x10, x10, :lo12:gv + block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Global(gv->GetName())}); // Special case for address + return; + } + + if (auto* arg = dynamic_cast(value)) { + // Argument is already an address (pointer) + EmitValueToReg(arg, target, slots, block); + return; + } + + auto it = slots.find(value); + if (it != slots.end()) { + // Check if it's an alloca (frame index) or a stored address + // For alloca, we want the address: add x10, x29, #offset + // For stored address, we want to load it: ldr x10, [x29, #offset] + + // In our simple lowering, alloca's value in 'slots' is the frame index. + // If 'value' is an AllocaInst, we compute its address. + if (dynamic_cast(value)) { + block.Append(Opcode::AddrStack, {Operand::Reg(target), Operand::FrameIndex(it->second)}); + return; + } + + // Otherwise it's a stored address (from a GEP) + block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(it->second)}); + return; + } + + throw std::runtime_error(FormatError("mir", "无法获取地址: " + value->GetName())); +} +size_t GetTypeSize(const ir::Type& ty) { + if (ty.IsInt32() || ty.IsFloat()) return 4; + if (ty.IsPointer() || ty.IsPtrInt32() || ty.IsPtrFloat()) return 8; + if (ty.IsArray()) { + return ty.GetNumElements() * GetTypeSize(*ty.GetElementType()); + } + return 0; +} + +void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, + MachineBasicBlock& block, ValueSlotMap& slots) { switch (inst.GetOpcode()) { case ir::Opcode::Alloca: { - slots.emplace(&inst, function.CreateFrameIndex()); + auto& alloca = static_cast(inst); + // AllocaInst's type is PointerType. We want the size of the pointed type. + size_t size = GetTypeSize(*alloca.GetType()->GetPointedType()); + slots.emplace(&inst, function.CreateFrameIndex(static_cast(size))); return; } case ir::Opcode::Store: { auto& store = static_cast(inst); - auto dst = slots.find(store.GetPtr()); - if (dst == slots.end()) { - throw std::runtime_error( - FormatError("mir", "暂不支持对非栈变量地址进行写入")); + PhysReg val_reg = PhysReg::W8; + EmitValueToReg(store.GetValue(), val_reg, slots, block); + if (store.GetValue()->GetType()->IsPointer() || store.GetValue()->GetType()->IsPtrInt32() || store.GetValue()->GetType()->IsPtrFloat()) { + val_reg = ToXReg(val_reg); + } else if (store.GetValue()->GetType()->IsFloat()) { + val_reg = ToSReg(val_reg); + } + + // If ptr is a global or stored address (GEP result), we use LoadR/StoreR logic + if (auto* gv = dynamic_cast(store.GetPtr())) { + block.Append(Opcode::StoreGlobal, {Operand::Reg(val_reg), Operand::Global(gv->GetName())}); + } else if (auto* alloca = dynamic_cast(store.GetPtr())) { + auto it = slots.find(alloca); + if (it == slots.end()) throw std::runtime_error("Alloca not found"); + block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)}); + } else { + // Pointer is in a register (from GEP) + EmitAddrToReg(store.GetPtr(), PhysReg::X10, function, slots, block); + block.Append(Opcode::StoreR, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X10)}); } - EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)}); return; } case ir::Opcode::Load: { auto& load = static_cast(inst); - auto src = slots.find(load.GetPtr()); - if (src == slots.end()) { - throw std::runtime_error( - FormatError("mir", "暂不支持对非栈变量地址进行读取")); + int dst_slot = function.CreateFrameIndex(static_cast(GetTypeSize(*load.GetType()))); + PhysReg dst_reg = PhysReg::W8; + if (load.GetType()->IsPointer() || load.GetType()->IsPtrInt32() || load.GetType()->IsPtrFloat()) { + dst_reg = ToXReg(dst_reg); + } else if (load.GetType()->IsFloat()) { + dst_reg = ToSReg(dst_reg); } - int dst_slot = function.CreateFrameIndex(); - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + + if (auto* gv = dynamic_cast(load.GetPtr())) { + block.Append(Opcode::LoadGlobal, {Operand::Reg(dst_reg), Operand::Global(gv->GetName())}); + } else if (auto* alloca = dynamic_cast(load.GetPtr())) { + auto it = slots.find(alloca); + if (it == slots.end()) throw std::runtime_error("Alloca not found"); + block.Append(Opcode::LoadStack, {Operand::Reg(dst_reg), Operand::FrameIndex(it->second)}); + } else { + // Pointer is in a register (from GEP) + EmitAddrToReg(load.GetPtr(), PhysReg::X10, function, slots, block); + block.Append(Opcode::LoadR, {Operand::Reg(dst_reg), Operand::Reg(PhysReg::X10)}); + } + + block.Append(Opcode::StoreStack, {Operand::Reg(dst_reg), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); return; } - case ir::Opcode::Add: { + case ir::Opcode::GEP: { + auto& gep = static_cast(inst); + int dst_slot = function.CreateFrameIndex(8); // Address is 8 bytes + + EmitAddrToReg(gep.GetPtr(), PhysReg::X10, function, slots, block); + + // Initial type is the pointed type of the base pointer + std::shared_ptr cur_ty = gep.GetPtr()->GetType()->GetPointedType(); + + for (size_t i = 0; i < gep.GetIndices().size(); ++i) { + ir::Value* index_val = gep.GetIndices()[i]; + + // Skip index 0 if it's the first index and we're starting from a pointer + if (i == 0) { + if (auto* ci = dynamic_cast(index_val)) { + if (ci->GetValue() == 0) { + continue; + } + } + EmitValueToReg(index_val, PhysReg::W8, slots, block); + size_t element_size = GetTypeSize(*cur_ty); + // Use X8 for 64-bit multiplication if element_size is large, + // but for simple cases we can use AddRRR_LSL with W8 for auto sxtw + if (element_size == 4) { + block.Append(Opcode::AddRRR_LSL, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::W8), Operand::Imm(2)}); + } else if (element_size == 8) { + block.Append(Opcode::AddRRR_LSL, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::W8), Operand::Imm(3)}); + } else { + block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::W8)}); + block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X9), Operand::Imm(static_cast(element_size))}); + block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X9)}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X8)}); + } + continue; + } + + if (cur_ty->IsArray()) { + size_t element_size = GetTypeSize(*cur_ty->GetElementType()); + EmitValueToReg(index_val, PhysReg::W8, slots, block); + if (element_size == 4) { + block.Append(Opcode::AddRRR_LSL, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::W8), Operand::Imm(2)}); + } else if (element_size == 8) { + block.Append(Opcode::AddRRR_LSL, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::W8), Operand::Imm(3)}); + } else { + block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::W8)}); + block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X9), Operand::Imm(static_cast(element_size))}); + block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X9)}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X8)}); + } + cur_ty = cur_ty->GetElementType(); + } else { + throw std::runtime_error(FormatError("mir", "GEP 索引超出范围或类型不是数组")); + } + } + + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::X10), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Call: { + auto& call = static_cast(inst); + const auto& args = call.GetArgs(); + for (size_t i = 0; i < args.size(); ++i) { + if (i < 8) { + // Determine if arg is a pointer + bool is_ptr = args[i]->GetType()->IsPointer() || args[i]->GetType()->IsPtrInt32() || args[i]->GetType()->IsPtrFloat(); + PhysReg target = is_ptr ? static_cast((int)PhysReg::X0 + i) + : static_cast((int)PhysReg::W0 + i); + EmitValueToReg(args[i], target, slots, block); + } else { + throw std::runtime_error("Only up to 8 arguments supported for now"); + } + } + block.Append(Opcode::Call, {Operand::Label(call.GetFunc()->GetName())}); + + if (!call.GetType()->IsVoid()) { + int dst_slot = function.CreateFrameIndex(static_cast(GetTypeSize(*call.GetType()))); + PhysReg ret_reg = PhysReg::W0; + if (call.GetType()->IsFloat()) { + ret_reg = ToSReg(ret_reg); + } else if (call.GetType()->IsPointer() || call.GetType()->IsPtrInt32() || call.GetType()->IsPtrFloat()) { + ret_reg = ToXReg(ret_reg); + } + block.Append(Opcode::StoreStack, {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + } + return; + } + case ir::Opcode::Add: + case ir::Opcode::Sub: + case ir::Opcode::Mul: + case ir::Opcode::Div: + case ir::Opcode::Mod: { auto& bin = static_cast(inst); int dst_slot = function.CreateFrameIndex(); - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); + + if (bin.GetType()->IsFloat()) { + PhysReg lhs_reg = PhysReg::W8; + PhysReg rhs_reg = PhysReg::W9; + EmitValueToReg(bin.GetLhs(), lhs_reg, slots, block); + EmitValueToReg(bin.GetRhs(), rhs_reg, slots, block); + lhs_reg = ToSReg(lhs_reg); + rhs_reg = ToSReg(rhs_reg); + + Opcode op; + if (inst.GetOpcode() == ir::Opcode::Add) op = Opcode::FAdd; + else if (inst.GetOpcode() == ir::Opcode::Sub) op = Opcode::FSub; + else if (inst.GetOpcode() == ir::Opcode::Mul) op = Opcode::FMUL; + else if (inst.GetOpcode() == ir::Opcode::Div) op = Opcode::FDiv; + else throw std::runtime_error("Float mod not supported"); + + block.Append(op, {Operand::Reg(PhysReg::S0), Operand::Reg(lhs_reg), Operand::Reg(rhs_reg)}); + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); + + if (inst.GetOpcode() == ir::Opcode::Add) { + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + } else if (inst.GetOpcode() == ir::Opcode::Sub) { + block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + } else if (inst.GetOpcode() == ir::Opcode::Mul) { + block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + } else if (inst.GetOpcode() == ir::Opcode::Div) { + block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + } else if (inst.GetOpcode() == ir::Opcode::Mod) { + // srem w10, w8, w9 => sdiv w10, w8, w9; msub w8, w10, w9, w8 + block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + block.Append(Opcode::MSubRRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W8)}); + } + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } + + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::SIToFP: { + auto& fcvt = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + EmitValueToReg(fcvt.GetUnaryOperand(), PhysReg::W8, slots, block); + block.Append(Opcode::FCvtSI2FP, {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::W8)}); + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::FPToSI: { + auto& fcvt = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + EmitValueToReg(fcvt.GetUnaryOperand(), PhysReg::W8, slots, block); + block.Append(Opcode::FCvtFP2SI, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S8)}); + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Cmp: + case ir::Opcode::FCmp: { + int dst_slot = function.CreateFrameIndex(); + ir::CmpOp ir_cc; + if (inst.GetOpcode() == ir::Opcode::Cmp) { + auto& cmp = static_cast(inst); + EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + ir_cc = cmp.GetCmpOp(); + } else { + auto& cmp = static_cast(inst); + EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::FCmp, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)}); + ir_cc = cmp.GetCmpOp(); + } + + CondCode cc = CondCode::EQ; + switch (ir_cc) { + case ir::CmpOp::Eq: cc = CondCode::EQ; break; + case ir::CmpOp::Ne: cc = CondCode::NE; break; + case ir::CmpOp::Lt: cc = CondCode::LT; break; + case ir::CmpOp::Le: cc = CondCode::LE; break; + case ir::CmpOp::Gt: cc = CondCode::GT; break; + case ir::CmpOp::Ge: cc = CondCode::GE; break; + } + + block.Append(Opcode::CSet, {Operand::Reg(PhysReg::W8), Operand::Cond(cc)}); block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); return; } + case ir::Opcode::Zext: { + auto& zext = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + EmitValueToReg(zext.GetValue(), PhysReg::W8, slots, block); + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Neg: { + auto& unary = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + if (unary.GetType()->IsFloat()) { + EmitValueToReg(unary.GetUnaryOperand(), PhysReg::W8, slots, block); + block.Append(Opcode::FNeg, {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::S8)}); + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + EmitValueToReg(unary.GetUnaryOperand(), PhysReg::W8, slots, block); + block.Append(Opcode::NegR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Br: { + auto& br = static_cast(inst); + block.Append(Opcode::B, {Operand::Label(br.GetDest()->GetName())}); + return; + } + case ir::Opcode::CondBr: { + auto& cbr = static_cast(inst); + EmitValueToReg(cbr.GetCond(), PhysReg::W8, slots, block); + // SysY IR CondBr uses i1. In MIR, we compare with 0. + block.Append(Opcode::BCond, {Operand::Cond(CondCode::NE), + Operand::Reg(PhysReg::W8), + Operand::Label(cbr.GetTrueBlock()->GetName())}); + block.Append(Opcode::B, {Operand::Label(cbr.GetFalseBlock()->GetName())}); + return; + } case ir::Opcode::Ret: { auto& ret = static_cast(inst); - EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block); + if (auto* val = ret.GetValue()) { + EmitValueToReg(val, PhysReg::W0, slots, block); + } block.Append(Opcode::Ret); return; } - case ir::Opcode::Sub: - case ir::Opcode::Mul: - throw std::runtime_error(FormatError("mir", "暂不支持该二元运算")); default: - throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令")); + throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令: " + std::to_string((int)inst.GetOpcode()))); } } } // namespace -std::unique_ptr LowerToMIR(const ir::Module& module) { +std::unique_ptr LowerToMIR(const ir::Module& module) { DefaultContext(); + auto machine_module = std::make_unique(); - if (module.GetFunctions().size() != 1) { - throw std::runtime_error(FormatError("mir", "暂不支持多个函数")); + // Lower global variables + for (const auto& gv : module.GetGlobalVariables()) { + GlobalVariable mir_gv; + mir_gv.name = gv->GetName(); + mir_gv.size = GetTypeSize(*gv->GetType()->GetPointedType()); + if (auto* init = gv->GetInitializer()) { + if (auto* ci = dynamic_cast(init)) { + mir_gv.init_value = ci->GetValue(); + } else if (auto* cf = dynamic_cast(init)) { + float f = cf->GetValue(); + uint32_t bits; + std::memcpy(&bits, &f, 4); + mir_gv.init_value = static_cast(bits); + } + } + machine_module->GetGlobals().push_back(mir_gv); } - const auto& func = *module.GetFunctions().front(); - if (func.GetName() != "main") { - throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数")); - } + // Lower functions + for (const auto& ir_func : module.GetFunctions()) { + if (ir_func->GetBlocks().empty()) continue; // Skip declarations + + auto machine_func = std::make_unique(ir_func->GetName()); + ValueSlotMap slots; - auto machine_func = std::make_unique(func.GetName()); - ValueSlotMap slots; - const auto* entry = func.GetEntry(); - if (!entry) { - throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块")); - } + // Create all blocks first to handle forward references in branches + std::unordered_map block_map; + for (const auto& ir_bb : ir_func->GetBlocks()) { + block_map[ir_bb.get()] = &machine_func->CreateBlock(ir_bb->GetName()); + } + + // Lower instructions in each block + for (const auto& ir_bb : ir_func->GetBlocks()) { + auto& machine_bb = *block_map.at(ir_bb.get()); + for (const auto& inst : ir_bb->GetInstructions()) { + LowerInstruction(*inst, *machine_func, machine_bb, slots); + } + } - for (const auto& inst : entry->GetInstructions()) { - LowerInstruction(*inst, *machine_func, slots); + machine_module->GetFunctions().push_back(std::move(machine_func)); } - return machine_func; + return machine_module; } } // namespace mir diff --git a/src/mir/MIRFunction.cpp b/src/mir/MIRFunction.cpp index 334f8cc..9798e0a 100644 --- a/src/mir/MIRFunction.cpp +++ b/src/mir/MIRFunction.cpp @@ -8,7 +8,12 @@ namespace mir { MachineFunction::MachineFunction(std::string name) - : name_(std::move(name)), entry_("entry") {} + : name_(std::move(name)) {} + +MachineBasicBlock& MachineFunction::CreateBlock(const std::string& name) { + blocks_.push_back(std::make_unique(name)); + return *blocks_.back(); +} int MachineFunction::CreateFrameIndex(int size) { int index = static_cast(frame_slots_.size()); diff --git a/src/mir/MIRInstr.cpp b/src/mir/MIRInstr.cpp index 0a21a03..966e9f0 100644 --- a/src/mir/MIRInstr.cpp +++ b/src/mir/MIRInstr.cpp @@ -4,17 +4,29 @@ namespace mir { -Operand::Operand(Kind kind, PhysReg reg, int imm) - : kind_(kind), reg_(reg), imm_(imm) {} +Operand::Operand(Kind kind, PhysReg reg, int imm, std::string label) + : kind_(kind), reg_(reg), imm_(imm), label_(std::move(label)) {} Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); } Operand Operand::Imm(int value) { - return Operand(Kind::Imm, PhysReg::W0, value); + return Operand(Kind::Imm, PhysReg::WZR, value); } Operand Operand::FrameIndex(int index) { - return Operand(Kind::FrameIndex, PhysReg::W0, index); + return Operand(Kind::FrameIndex, PhysReg::WZR, index); +} + +Operand Operand::Label(const std::string& name) { + return Operand(Kind::Label, PhysReg::WZR, 0, name); +} + +Operand Operand::Global(const std::string& name) { + return Operand(Kind::Global, PhysReg::WZR, 0, name); +} + +Operand Operand::Cond(CondCode cc) { + return Operand(Kind::Cond, PhysReg::WZR, static_cast(cc)); } MachineInstr::MachineInstr(Opcode opcode, std::vector operands) diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index 5dc5d2b..d888714 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -8,26 +8,19 @@ namespace mir { namespace { bool IsAllowedReg(PhysReg reg) { - switch (reg) { - case PhysReg::W0: - case PhysReg::W8: - case PhysReg::W9: - case PhysReg::X29: - case PhysReg::X30: - case PhysReg::SP: - return true; - } - return false; + return true; // All registers are allowed for now as we are not doing allocation } } // namespace void RunRegAlloc(MachineFunction& function) { - for (const auto& inst : function.GetEntry().GetInstructions()) { - for (const auto& operand : inst.GetOperands()) { - if (operand.GetKind() == Operand::Kind::Reg && - !IsAllowedReg(operand.GetReg())) { - throw std::runtime_error(FormatError("mir", "寄存器分配失败")); + for (auto& block : function.GetBlocks()) { + for (const auto& inst : block->GetInstructions()) { + for (const auto& operand : inst.GetOperands()) { + if (operand.GetKind() == Operand::Kind::Reg && + !IsAllowedReg(operand.GetReg())) { + throw std::runtime_error(FormatError("mir", "寄存器分配失败")); + } } } } diff --git a/src/mir/Register.cpp b/src/mir/Register.cpp index 7530470..d04d42c 100644 --- a/src/mir/Register.cpp +++ b/src/mir/Register.cpp @@ -8,18 +8,61 @@ namespace mir { const char* PhysRegName(PhysReg reg) { switch (reg) { - case PhysReg::W0: - return "w0"; - case PhysReg::W8: - return "w8"; - case PhysReg::W9: - return "w9"; - case PhysReg::X29: - return "x29"; - case PhysReg::X30: - return "x30"; - case PhysReg::SP: - return "sp"; + case PhysReg::W0: return "w0"; + case PhysReg::W1: return "w1"; + case PhysReg::W2: return "w2"; + case PhysReg::W3: return "w3"; + case PhysReg::W4: return "w4"; + case PhysReg::W5: return "w5"; + case PhysReg::W6: return "w6"; + case PhysReg::W7: return "w7"; + case PhysReg::W8: return "w8"; + case PhysReg::W9: return "w9"; + case PhysReg::W10: return "w10"; + case PhysReg::W11: return "w11"; + case PhysReg::W12: return "w12"; + case PhysReg::W13: return "w13"; + case PhysReg::W14: return "w14"; + case PhysReg::W15: return "w15"; + case PhysReg::X0: return "x0"; + case PhysReg::X1: return "x1"; + case PhysReg::X2: return "x2"; + case PhysReg::X3: return "x3"; + case PhysReg::X4: return "x4"; + case PhysReg::X5: return "x5"; + case PhysReg::X6: return "x6"; + case PhysReg::X7: return "x7"; + case PhysReg::X8: return "x8"; + case PhysReg::X9: return "x9"; + case PhysReg::X10: return "x10"; + case PhysReg::X11: return "x11"; + case PhysReg::X12: return "x12"; + case PhysReg::X13: return "x13"; + case PhysReg::X14: return "x14"; + case PhysReg::X15: return "x15"; + case PhysReg::X16: return "x16"; + case PhysReg::X17: return "x17"; + case PhysReg::S0: return "s0"; + case PhysReg::S1: return "s1"; + case PhysReg::S2: return "s2"; + case PhysReg::S3: return "s3"; + case PhysReg::S4: return "s4"; + case PhysReg::S5: return "s5"; + case PhysReg::S6: return "s6"; + case PhysReg::S7: return "s7"; + case PhysReg::S8: return "s8"; + case PhysReg::S9: return "s9"; + case PhysReg::S10: return "s10"; + case PhysReg::S11: return "s11"; + case PhysReg::S12: return "s12"; + case PhysReg::S13: return "s13"; + case PhysReg::S14: return "s14"; + case PhysReg::S15: return "s15"; + case PhysReg::X29: return "x29"; + case PhysReg::X30: return "x30"; + case PhysReg::SP: return "sp"; + case PhysReg::WZR: return "wzr"; + case PhysReg::XZR: return "xzr"; } throw std::runtime_error(FormatError("mir", "未知物理寄存器")); }