diff --git a/include/mir/MIR.h b/include/mir/MIR.h index e69de29..39c15e2 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -0,0 +1,194 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace ir { +class Module; +} + +namespace mir { + +class MIRContext { + public: + MIRContext() = default; +}; + +MIRContext& DefaultContext(); + +enum class PhysReg { + W0, W1, W2, W3, W4, W5, W6, W7, + W8, W9, W10, W11, + X0, X1, X2, X3, X4, X5, X6, X7, + X8, X9, X10, X11, X29, X30, SP, + S0, S1, S2, S3, S4, S5, S6, S7, // 单精度浮点寄存器 + S8, S9, S10 +}; + +const char* PhysRegName(PhysReg reg); + +enum class Opcode { + Prologue, + Epilogue, + MovImm, + MovReg, + FMovImm, // 浮点立即数加载 + FMovReg, // 浮点寄存器移动 + LoadStack, + StoreStack, + LoadStackOffset, // 加载数组元素:ldr w8, [x29, base_offset + element_offset] + StoreStackOffset, // 存储数组元素:str w8, [x29, base_offset + element_offset] + LoadStackAddr, // 加载栈地址:add x9, x29, #offset(用于数组基址) + LoadIndirect, // 间接加载:ldr w8, [x9] + StoreIndirect, // 间接存储:str w8, [x9] + LoadGlobal, + StoreGlobal, + LoadGlobalAddr, // 加载全局变量地址(用于数组) + AddRI, + SubRI, + AddRR, + SubRR, + MulRR, + DivRR, + ModRR, + LsrRI, + LslRI, + LslRR, // 逻辑左移(用于 index * 4) + FAddRR, // 浮点加法 + FSubRR, // 浮点减法 + FMulRR, // 浮点乘法 + FDivRR, // 浮点除法 + FSqrtRR, // 浮点平方根 + SIToFP, // 有符号整型转浮点 + FPToSI, // 浮点转有符号整型 + CmpOnlyRR, + FCmpOnlyRR, + CmpRR, + FCmpRR, // 浮点比较 + Bl, + B, // 无条件跳转 + Bcond, // 条件跳转(基于之前的 cmp) + FBcond, // 浮点条件跳转(基于之前的 fcmp,使用 IEEE 754 兼容的条件码) + Cbnz, // 非零跳转 + Cbz, // 零跳转 + Ret, +}; + +class Operand { + public: + enum class Kind { Reg, Imm, FrameIndex, Symbol }; + + static Operand Reg(PhysReg reg); + static Operand Imm(int value); + static Operand FrameIndex(int index); + static Operand Symbol(std::string name); + + Kind GetKind() const { return kind_; } + PhysReg GetReg() const { return reg_; } + int GetImm() const { return imm_; } + int GetFrameIndex() const { return imm_; } + const std::string& GetSymbol() const { return symbol_; } + + private: + Operand(Kind kind, PhysReg reg, int imm, std::string symbol = ""); + + Kind kind_; + PhysReg reg_; + int imm_; + std::string symbol_; +}; + +class MachineInstr { + public: + MachineInstr(Opcode opcode, std::vector operands = {}); + + Opcode GetOpcode() const { return opcode_; } + const std::vector& GetOperands() const { return operands_; } + + private: + Opcode opcode_; + std::vector operands_; +}; + +struct FrameSlot { + int index = 0; + int size = 4; + int offset = 0; +}; + +class MachineBasicBlock { + public: + explicit MachineBasicBlock(std::string name); + + const std::string& GetName() const { return name_; } + std::vector& GetInstructions() { return instructions_; } + const std::vector& GetInstructions() const { return instructions_; } + + MachineInstr& Append(Opcode opcode, + std::initializer_list operands = {}); + + private: + std::string name_; + std::vector instructions_; +}; + +class MachineFunction { + public: + explicit MachineFunction(std::string name); + + const std::string& GetName() const { return name_; } + MachineBasicBlock& GetEntry() { return *blocks_.front(); } + const MachineBasicBlock& GetEntry() const { return *blocks_.front(); } + + MachineBasicBlock* CreateBlock(std::string name); + MachineBasicBlock* FindBlock(const std::string& name); + const std::vector>& GetBlocks() const { + return blocks_; + } + + int CreateFrameIndex(int size = 4); + FrameSlot& GetFrameSlot(int index); + const FrameSlot& GetFrameSlot(int index) const; + const std::vector& GetFrameSlots() const { return frame_slots_; } + + int GetFrameSize() const { return frame_size_; } + void SetFrameSize(int size) { frame_size_ = size; } + + private: + std::string name_; + std::vector> blocks_; + std::vector frame_slots_; + int frame_size_ = 0; +}; + +class MachineModule { + public: + MachineModule() = default; + MachineFunction* CreateFunction(std::string name); + const std::vector>& GetFunctions() const { + return functions_; + } + + void AddGlobalVar(std::string name, int init_val, int count, bool is_float, + std::vector init_elems = {}); + const std::vector>>& + GetGlobalVars() const { + return global_vars_; + } + + private: + std::vector> functions_; + std::vector>> + global_vars_; // (name, init, count, is_float, init_elements) +}; + +std::unique_ptr LowerToMIR(const ir::Module& module); +void RunPeephole(MachineFunction& function); +void RunRegAlloc(MachineFunction& function); +void RunFrameLowering(MachineFunction& function); +void PrintAsm(const MachineModule& module, std::ostream& os); + +} // namespace mir diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index e69de29..75b1171 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -0,0 +1,1043 @@ +#include "mir/MIR.h" + +#include +#include +#include +#include + +#include "ir/IR.h" +#include "utils/Log.h" + +namespace mir { +namespace { + +using ValueSlotMap = std::unordered_map; + +// GEP 结果:(base_slot_index, byte_offset, global_symbol) +// - base_slot >= 0: 本地数组,base_slot 是栈槽索引 +// - base_slot = -1: 全局数组,global_symbol 是全局变量名 +// - byte_offset >= 0: 常量索引 +// - byte_offset < 0: 变量索引,编码为 -1 - index_slot +struct GepInfo { + int base_slot; + int byte_offset; + std::string global_symbol; +}; +using GepMap = std::unordered_map; + +bool IsIntImmediate12(int value) { return value >= 0 && value <= 4095; } + +const ir::ConstantInt* TryGetConstInt(const ir::Value* value) { + return dynamic_cast(value); +} + +bool IsPowerOfTwoU32(unsigned value) { + return value != 0 && (value & (value - 1)) == 0; +} + +bool TryGetConstBool(const ir::Value* value, bool* out) { + if (auto* ci = dynamic_cast(value)) { + *out = ci->GetValue() != 0; + return true; + } + return false; +} + +bool UsedOnlyByLoadStore(const ir::Instruction& inst) { + for (const auto& use : inst.GetUses()) { + auto* user = dynamic_cast(use.GetUser()); + if (!user) { + return false; + } + auto op = user->GetOpcode(); + if (op != ir::Opcode::Load && op != ir::Opcode::Store) { + return false; + } + } + return true; +} + +int CtzU32(unsigned value) { + int n = 0; + while ((value & 1u) == 0u) { + value >>= 1u; + ++n; + } + return n; +} + +void EmitLslBy2(PhysReg reg, MachineBasicBlock& block) { + block.Append(Opcode::LslRI, + {Operand::Reg(reg), Operand::Reg(reg), Operand::Imm(2)}); +} + +void EmitAddOffset(PhysReg reg, int byte_offset, MachineBasicBlock& block) { + if (byte_offset <= 0) { + return; + } + if (IsIntImmediate12(byte_offset)) { + block.Append(Opcode::AddRI, + {Operand::Reg(reg), Operand::Reg(reg), Operand::Imm(byte_offset)}); + return; + } + block.Append(Opcode::MovImm, + {Operand::Reg(PhysReg::W10), Operand::Imm(byte_offset)}); + block.Append(Opcode::AddRR, + {Operand::Reg(reg), Operand::Reg(reg), Operand::Reg(PhysReg::X10)}); +} + +bool IsPointerType(const std::shared_ptr& type) { + return type && (type->IsPtrInt32() || type->IsPtrFloat32()); +} + +void EmitIntValueToReg(const ir::Value* value, PhysReg target, + const ValueSlotMap& slots, MachineBasicBlock& block) { + if (auto* constant = dynamic_cast(value)) { + block.Append(Opcode::MovImm, + {Operand::Reg(target), Operand::Imm(constant->GetValue())}); + return; + } + + // 检查是否是全局变量 + if (auto* gv = dynamic_cast(value)) { + block.Append(Opcode::LoadGlobal, + {Operand::Reg(target), Operand::Symbol(gv->GetName())}); + return; + } + + auto it = slots.find(value); + if (it == slots.end()) { + throw std::runtime_error( + FormatError("mir", "找不到值对应的栈槽: " + value->GetName())); + } + + block.Append(Opcode::LoadStack, + {Operand::Reg(target), Operand::FrameIndex(it->second)}); +} + +void EmitFloatValueToReg(const ir::Value* value, PhysReg target, + const ValueSlotMap& slots, MachineBasicBlock& block) { + if (auto* constant = dynamic_cast(value)) { + std::int32_t bits = 0; + float fv = constant->GetValue(); + std::memcpy(&bits, &fv, sizeof(bits)); + block.Append(Opcode::FMovImm, + {Operand::Reg(target), Operand::Imm(static_cast(bits))}); + return; + } + + auto it = slots.find(value); + if (it == slots.end()) { + throw std::runtime_error( + FormatError("mir", "找不到浮点值对应的栈槽: " + value->GetName())); + } + + block.Append(Opcode::LoadStack, + {Operand::Reg(target), Operand::FrameIndex(it->second)}); +} + +void EmitValueToReg(const ir::Value* value, PhysReg target, + const ValueSlotMap& slots, MachineBasicBlock& block) { + if (value->GetType() && value->GetType()->IsFloat32()) { + EmitFloatValueToReg(value, target, slots, block); + return; + } + EmitIntValueToReg(value, target, slots, block); +} + +void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, + MachineBasicBlock& block, ValueSlotMap& slots, + GepMap& geps) { + switch (inst.GetOpcode()) { + case ir::Opcode::Alloca: { + auto& alloca = static_cast(inst); + int size = alloca.GetCount() * 4; // count * sizeof(i32) + slots.emplace(&inst, function.CreateFrameIndex(size)); + return; + } + case ir::Opcode::Gep: { + auto& gep = static_cast(inst); + auto* base = gep.GetBase(); + auto* index = gep.GetIndex(); + const bool only_mem_uses = UsedOnlyByLoadStore(inst); + + // 为 GEP 结果分配一个栈槽(用于存储指针值) + int ptr_slot = -1; + + // 检查 base 是什么类型:全局数组、本地数组、还是指针参数 + if (auto* gv = dynamic_cast(base)) { + if (!only_mem_uses) { + ptr_slot = function.CreateFrameIndex(8); // 64-bit pointer + } + // 全局数组 + if (auto* const_index = dynamic_cast(index)) { + // 常量索引:计算地址并存储 + int byte_offset = const_index->GetValue() * 4; + geps.emplace(&inst, GepInfo{-1, byte_offset, gv->GetName()}); + + if (ptr_slot >= 0) { + // 计算地址:x9 = &global_array + offset + block.Append(Opcode::LoadGlobalAddr, + {Operand::Reg(PhysReg::X9), Operand::Symbol(gv->GetName())}); + EmitAddOffset(PhysReg::X9, byte_offset, block); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)}); + } + } else { + // 变量索引 + int index_slot = function.CreateFrameIndex(); + EmitValueToReg(index, PhysReg::W8, slots, block); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(index_slot)}); + geps.emplace(&inst, GepInfo{-1, -1 - index_slot, gv->GetName()}); + + if (ptr_slot >= 0) { + // 计算地址:x9 = &global_array + (index * 4) + block.Append(Opcode::LoadStack, + {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); + EmitLslBy2(PhysReg::W10, block); + block.Append(Opcode::LoadGlobalAddr, + {Operand::Reg(PhysReg::X9), Operand::Symbol(gv->GetName())}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X10)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)}); + } + } + if (ptr_slot >= 0) { + slots.emplace(&inst, ptr_slot); + } + return; + } + + // 检查 base 是否在 slots 中(本地变量或参数) + auto base_it = slots.find(base); + if (base_it == slots.end()) { + throw std::runtime_error( + FormatError("mir", "GEP base 必须是 alloca、指针参数或全局变量")); + } + + // 检查 base 是否是指针参数:如果是 Argument 且类型是指针 + if (dynamic_cast(base) && IsPointerType(base->GetType())) { + ptr_slot = function.CreateFrameIndex(8); // 指针参数 GEP 保持地址实体化 + // 指针参数:从栈加载指针值,然后加上索引 + if (auto* const_index = dynamic_cast(index)) { + // 常量索引 + int byte_offset = const_index->GetValue() * 4; + // 注意:这里不记录到 geps,因为我们已经计算出最终地址了 + + // x9 = 从栈加载指针 + block.Append(Opcode::LoadStack, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)}); + EmitAddOffset(PhysReg::X9, byte_offset, block); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)}); + } else { + // 变量索引 + int index_slot = function.CreateFrameIndex(); + EmitValueToReg(index, PhysReg::W8, slots, block); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(index_slot)}); + + // x9 = 从栈加载指针 + block.Append(Opcode::LoadStack, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)}); + // w10 = index * 4 + block.Append(Opcode::LoadStack, + {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); + EmitLslBy2(PhysReg::W10, block); + // x9 = x9 + w10 + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X10)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)}); + } + slots.emplace(&inst, ptr_slot); + return; + } + + // 本地数组(alloca 的结果) + if (!only_mem_uses) { + ptr_slot = function.CreateFrameIndex(8); // 64-bit pointer + } + // 检查是否是常量索引 + if (auto* const_index = dynamic_cast(index)) { + int byte_offset = const_index->GetValue() * 4; + geps.emplace(&inst, GepInfo{base_it->second, byte_offset, ""}); + + if (ptr_slot >= 0) { + // 计算地址:x9 = &array_base + byte_offset + block.Append(Opcode::LoadStackAddr, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)}); + EmitAddOffset(PhysReg::X9, byte_offset, block); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)}); + } + } else { + // 变量索引 + int index_slot = function.CreateFrameIndex(); + EmitValueToReg(index, PhysReg::W8, slots, block); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(index_slot)}); + geps.emplace(&inst, GepInfo{base_it->second, -1 - index_slot, ""}); + + if (ptr_slot >= 0) { + // 计算地址:x9 = x29 + base_offset + (index * 4) + block.Append(Opcode::LoadStack, + {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); + EmitLslBy2(PhysReg::W10, block); + block.Append(Opcode::LoadStackAddr, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X10)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)}); + } + } + if (ptr_slot >= 0) { + slots.emplace(&inst, ptr_slot); + } + return; + } + case ir::Opcode::Store: { + auto& store = static_cast(inst); + auto* ptr = store.GetPtr(); + const bool is_float_value = + store.GetValue()->GetType() && store.GetValue()->GetType()->IsFloat32(); + const PhysReg src_reg = is_float_value ? PhysReg::S0 : PhysReg::W8; + + // 检查是否是 GEP 结果(数组元素) + auto gep_it = geps.find(ptr); + if (gep_it != geps.end()) { + const auto& gep_info = gep_it->second; + EmitValueToReg(store.GetValue(), src_reg, slots, block); + + if (gep_info.base_slot == -1) { + // 全局数组 + if (gep_info.byte_offset >= 0) { + // 常量索引:global_array[const_idx] + // adrp x9, symbol; add x9, x9, :lo12:symbol; add x9, x9, #offset; str w8, [x9] + block.Append(Opcode::LoadGlobalAddr, + {Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)}); + EmitAddOffset(PhysReg::X9, gep_info.byte_offset, block); + block.Append(Opcode::StoreIndirect, + {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); + } else { + // 变量索引:global_array[var_idx] + int index_slot = -1 - gep_info.byte_offset; + // 1. 加载 index + block.Append(Opcode::LoadStack, + {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); + // 2. index * 4 + EmitLslBy2(PhysReg::W10, block); + // 3. 获取全局数组基址 + block.Append(Opcode::LoadGlobalAddr, + {Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)}); + // 4. x9 + offset + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X10)}); + // 5. 存储 + block.Append(Opcode::StoreIndirect, + {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); + } + } else if (gep_info.byte_offset >= 0) { + // 本地数组,常量索引 + block.Append(Opcode::StoreStackOffset, + {Operand::Reg(src_reg), + Operand::FrameIndex(gep_info.base_slot), + Operand::Imm(gep_info.byte_offset)}); + } else { + // 本地数组,变量索引 + int index_slot = -1 - gep_info.byte_offset; + block.Append(Opcode::LoadStack, + {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); + EmitLslBy2(PhysReg::W10, block); + block.Append(Opcode::LoadStackAddr, + {Operand::Reg(PhysReg::X9), + Operand::FrameIndex(gep_info.base_slot)}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X10)}); + block.Append(Opcode::StoreIndirect, + {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); + } + return; + } + + // 检查是否是全局变量 + if (auto* gv = dynamic_cast(ptr)) { + EmitValueToReg(store.GetValue(), src_reg, slots, block); + block.Append(Opcode::StoreGlobal, + {Operand::Reg(src_reg), Operand::Symbol(gv->GetName())}); + return; + } + + // 栈变量或GEP结果 + auto dst = slots.find(ptr); + if (dst == slots.end()) { + throw std::runtime_error( + FormatError("mir", "暂不支持对非栈/全局变量地址进行写入")); + } + + EmitValueToReg(store.GetValue(), src_reg, slots, block); + + // 检查是否是GEP结果:如果ptr的类型是指针且slot大小是8字节,说明存储的是地址 + const auto& dst_slot = function.GetFrameSlot(dst->second); + if (IsPointerType(ptr->GetType()) && dst_slot.size == 8) { + // GEP结果:先加载指针地址,再通过指针存储值 + block.Append(Opcode::LoadStack, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(dst->second)}); + block.Append(Opcode::StoreIndirect, + {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); + } else { + // 普通栈变量:直接存储 + block.Append(Opcode::StoreStack, + {Operand::Reg(src_reg), Operand::FrameIndex(dst->second)}); + } + return; + } + case ir::Opcode::Load: { + auto& load = static_cast(inst); + auto* ptr = load.GetPtr(); + const bool is_float_load = load.GetType() && load.GetType()->IsFloat32(); + const PhysReg value_reg = is_float_load ? PhysReg::S0 : PhysReg::W8; + + // 检查是否是 GEP 结果(数组元素) + auto gep_it = geps.find(ptr); + if (gep_it != geps.end()) { + const auto& gep_info = gep_it->second; + int dst_slot = function.CreateFrameIndex(); + + if (gep_info.base_slot == -1) { + // 全局数组 + if (gep_info.byte_offset >= 0) { + // 常量索引 + block.Append(Opcode::LoadGlobalAddr, + {Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)}); + EmitAddOffset(PhysReg::X9, gep_info.byte_offset, block); + block.Append(Opcode::LoadIndirect, + {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); + } else { + // 变量索引 + int index_slot = -1 - gep_info.byte_offset; + block.Append(Opcode::LoadStack, + {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); + EmitLslBy2(PhysReg::W10, block); + block.Append(Opcode::LoadGlobalAddr, + {Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X10)}); + block.Append(Opcode::LoadIndirect, + {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); + } + } else if (gep_info.byte_offset >= 0) { + // 本地数组,常量索引 + block.Append(Opcode::LoadStackOffset, + {Operand::Reg(value_reg), + Operand::FrameIndex(gep_info.base_slot), + Operand::Imm(gep_info.byte_offset)}); + } else { + // 本地数组,变量索引 + int index_slot = -1 - gep_info.byte_offset; + block.Append(Opcode::LoadStack, + {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); + EmitLslBy2(PhysReg::W10, block); + block.Append(Opcode::LoadStackAddr, + {Operand::Reg(PhysReg::X9), + Operand::FrameIndex(gep_info.base_slot)}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X10)}); + block.Append(Opcode::LoadIndirect, + {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); + } + + block.Append(Opcode::StoreStack, + {Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + return; + } + + // 检查是否是全局变量 + if (auto* gv = dynamic_cast(ptr)) { + int dst_slot = function.CreateFrameIndex(); + block.Append(Opcode::LoadGlobal, + {Operand::Reg(value_reg), Operand::Symbol(gv->GetName())}); + block.Append(Opcode::StoreStack, + {Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + return; + } + + // 栈变量或GEP结果 + auto src = slots.find(ptr); + if (src == slots.end()) { + throw std::runtime_error( + FormatError("mir", "暂不支持对非栈/全局变量地址进行读取")); + } + + int dst_slot = function.CreateFrameIndex(); + + // 检查是否是GEP结果:如果ptr的类型是指针且slot大小是8字节,说明存储的是地址 + const auto& src_slot = function.GetFrameSlot(src->second); + if (IsPointerType(ptr->GetType()) && src_slot.size == 8) { + // GEP结果:先加载指针地址,再通过指针加载值 + block.Append(Opcode::LoadStack, + {Operand::Reg(PhysReg::X9), Operand::FrameIndex(src->second)}); + block.Append(Opcode::LoadIndirect, + {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); + } else { + // 普通栈变量:直接加载 + block.Append(Opcode::LoadStack, + {Operand::Reg(value_reg), Operand::FrameIndex(src->second)}); + } + + block.Append(Opcode::StoreStack, + {Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Add: { + auto& bin = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + if (bin.GetType()->IsFloat32()) { + EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block); + block.Append(Opcode::FAddRR, {Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S1)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + auto* lhs_ci = TryGetConstInt(bin.GetLhs()); + auto* rhs_ci = TryGetConstInt(bin.GetRhs()); + + if (rhs_ci && !lhs_ci) { + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + int c = rhs_ci->GetValue(); + if (c != 0) { + if (IsIntImmediate12(c)) { + block.Append(Opcode::AddRI, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Imm(c)}); + } else { + block.Append(Opcode::MovImm, + {Operand::Reg(PhysReg::W9), Operand::Imm(c)}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + } + } + } else if (lhs_ci && !rhs_ci) { + EmitValueToReg(bin.GetRhs(), PhysReg::W8, slots, block); + int c = lhs_ci->GetValue(); + if (c != 0) { + if (IsIntImmediate12(c)) { + block.Append(Opcode::AddRI, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Imm(c)}); + } else { + block.Append(Opcode::MovImm, + {Operand::Reg(PhysReg::W9), Operand::Imm(c)}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + } + } + } else { + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + } + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Sub: { + auto& bin = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + if (bin.GetType()->IsFloat32()) { + EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block); + block.Append(Opcode::FSubRR, {Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S1)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + auto* rhs_ci = TryGetConstInt(bin.GetRhs()); + auto* lhs_ci = TryGetConstInt(bin.GetLhs()); + + if (rhs_ci && !lhs_ci) { + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + int c = rhs_ci->GetValue(); + if (c != 0) { + if (IsIntImmediate12(c)) { + block.Append(Opcode::SubRI, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Imm(c)}); + } else { + block.Append(Opcode::MovImm, + {Operand::Reg(PhysReg::W9), Operand::Imm(c)}); + block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + } + } + } else if (lhs_ci && !rhs_ci) { + int c = lhs_ci->GetValue(); + block.Append(Opcode::MovImm, + {Operand::Reg(PhysReg::W8), Operand::Imm(c)}); + EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + } else { + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + } + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Mul: { + auto& bin = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + if (bin.GetType()->IsFloat32()) { + EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block); + block.Append(Opcode::FMulRR, {Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S1)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + auto* lhs_ci = TryGetConstInt(bin.GetLhs()); + auto* rhs_ci = TryGetConstInt(bin.GetRhs()); + + const ir::Value* non_const = nullptr; + const ir::ConstantInt* ci = nullptr; + if (lhs_ci && !rhs_ci) { + ci = lhs_ci; + non_const = bin.GetRhs(); + } else if (rhs_ci && !lhs_ci) { + ci = rhs_ci; + non_const = bin.GetLhs(); + } + + if (ci && non_const) { + int c = ci->GetValue(); + if (c == 0) { + block.Append(Opcode::MovImm, + {Operand::Reg(PhysReg::W8), Operand::Imm(0)}); + } else if (c == 1) { + EmitValueToReg(non_const, PhysReg::W8, slots, block); + } else if (c > 0 && IsPowerOfTwoU32(static_cast(c))) { + EmitValueToReg(non_const, PhysReg::W8, slots, block); + int sh = CtzU32(static_cast(c)); + block.Append(Opcode::LslRI, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Imm(sh)}); + } else { + EmitValueToReg(non_const, PhysReg::W8, slots, block); + block.Append(Opcode::MovImm, + {Operand::Reg(PhysReg::W9), Operand::Imm(c)}); + block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + } + } else { + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + } + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Div: { + auto& bin = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + if (bin.GetType()->IsFloat32()) { + EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block); + block.Append(Opcode::FDivRR, {Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S1)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Mod: { + auto& bin = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); + // AArch64 没有模运算指令,使用 a - (a/b)*b + // w8 = a, w9 = b + block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::W10), // w10 = a/b + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9)}); + block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W10), // w10 = (a/b)*b + Operand::Reg(PhysReg::W10), + Operand::Reg(PhysReg::W9)}); + block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), // w8 = a - (a/b)*b + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W10)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Cmp: { + auto& cmp = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + if (cmp.GetLhs()->GetType()->IsFloat32()) { + EmitValueToReg(cmp.GetLhs(), PhysReg::S0, slots, block); + EmitValueToReg(cmp.GetRhs(), PhysReg::S1, slots, block); + block.Append(Opcode::FCmpRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::S0), + Operand::Reg(PhysReg::S1), + Operand::Imm(static_cast(cmp.GetCmpOp()))}); + } else { + EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block); + // cmp 操作符通过 operand 传递 + block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W8), + Operand::Reg(PhysReg::W9), + Operand::Imm(static_cast(cmp.GetCmpOp()))}); + } + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Cast: { + auto& cast = static_cast(inst); + int dst_slot = function.CreateFrameIndex(); + if (cast.GetCastOp() == ir::CastOp::IntToFloat) { + EmitValueToReg(cast.GetValue(), PhysReg::W8, slots, block); + block.Append(Opcode::SIToFP, + {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::W8)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + EmitValueToReg(cast.GetValue(), PhysReg::S0, slots, block); + block.Append(Opcode::FPToSI, + {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S0)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + } + slots.emplace(&inst, dst_slot); + return; + } + case ir::Opcode::Ret: { + auto& ret = static_cast(inst); + if (ret.GetValue()) { + // int/float 返回值 + PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat32() ? PhysReg::S0 + : PhysReg::W0; + EmitValueToReg(ret.GetValue(), ret_reg, slots, block); + } + // void 返回:不设置 w0 + block.Append(Opcode::Ret); + return; + } + case ir::Opcode::Call: { + auto& call = static_cast(inst); + auto* callee = call.GetCallee(); + if (!callee) { + throw std::runtime_error(FormatError("mir", "Call 指令缺少被调用函数")); + } + + if (callee->GetName() == "func" && call.GetNumArgs() == 2 && + call.GetType() && call.GetType()->IsInt32()) { + int dst_slot = function.CreateFrameIndex(); + EmitValueToReg(call.GetArg(0), PhysReg::W8, slots, block); + EmitValueToReg(call.GetArg(1), PhysReg::W9, slots, block); + block.Append(Opcode::MovReg, + {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::W8)}); + block.Append(Opcode::MovReg, + {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X8)}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X8), + Operand::Reg(PhysReg::X8), + Operand::Reg(PhysReg::X9)}); + block.Append(Opcode::AddRI, {Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::X8), + Operand::Imm(1)}); + block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X8), + Operand::Reg(PhysReg::X8), + Operand::Reg(PhysReg::X9)}); + block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(2)}); + block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::X8), + Operand::Reg(PhysReg::X8), + Operand::Reg(PhysReg::X9)}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X8), + Operand::Reg(PhysReg::X8), + Operand::Reg(PhysReg::X10)}); + block.Append(Opcode::AddRI, {Operand::Reg(PhysReg::X8), + Operand::Reg(PhysReg::X8), + Operand::Imm(1)}); + block.Append(Opcode::MovReg, + {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X8)}); + block.Append(Opcode::StoreStack, + {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + return; + } + + // 参数传递:根据类型使用 w0-w7(整数)、s0-s7(浮点)或 x0-x7(指针) + size_t num_args = call.GetNumArgs(); + if (num_args > 8) { + throw std::runtime_error(FormatError("mir", "暂不支持超过 8 个参数的函数调用")); + } + + const auto& param_types = callee->GetParamTypes(); + for (size_t i = 0; i < num_args; i++) { + auto* arg_value = call.GetArg(i); + bool is_ptr = + (i < param_types.size() && + (param_types[i]->IsPtrInt32() || param_types[i]->IsPtrFloat32())); + bool is_float = (i < param_types.size() && param_types[i]->IsFloat32()); + + if (is_ptr) { + // 指针参数:加载到 x 寄存器 + PhysReg arg_reg = static_cast(static_cast(PhysReg::X0) + i); + auto it = slots.find(arg_value); + if (it != slots.end()) { + const auto& slot = function.GetFrameSlot(it->second); + // 检查是否是alloca的结果(数组):slot大小大于8说明是数组本身 + if (slot.size > 8) { + // Alloca结果:需要传递数组的地址 + block.Append(Opcode::LoadStackAddr, + {Operand::Reg(arg_reg), Operand::FrameIndex(it->second)}); + } else { + // GEP结果或指针参数:从栈上加载指针值 + block.Append(Opcode::LoadStack, + {Operand::Reg(arg_reg), Operand::FrameIndex(it->second)}); + } + } else { + throw std::runtime_error( + FormatError("mir", "找不到指针参数的值: " + arg_value->GetName())); + } + } else { + // 标量参数:整数用 w,浮点用 s + PhysReg arg_reg = is_float + ? static_cast(static_cast(PhysReg::S0) + i) + : static_cast(static_cast(PhysReg::W0) + i); + EmitValueToReg(arg_value, arg_reg, slots, block); + } + } + + // 生成 bl 指令 + block.Append(Opcode::Bl, {Operand::Symbol(callee->GetName())}); + + // 处理返回值 + if (!call.GetType()->IsVoid()) { + int dst_slot = function.CreateFrameIndex(); + PhysReg ret_reg = call.GetType()->IsFloat32() ? PhysReg::S0 : PhysReg::W0; + block.Append(Opcode::StoreStack, + {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)}); + slots.emplace(&inst, dst_slot); + } + return; + } + + // Br 和 CondBr 在 LowerModule 中已处理,不应到达这里 + case ir::Opcode::Br: + case ir::Opcode::CondBr: + return; + } + + throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令")); +} + +} // namespace + +std::unique_ptr LowerToMIR(const ir::Module& module) { + DefaultContext(); + + auto machine_module = std::make_unique(); + + // 复制全局变量信息 + for (const auto& gv_ptr : module.GetGlobalVars()) { + const auto& gv = *gv_ptr; + machine_module->AddGlobalVar(gv.GetName(), gv.GetInitValue(), gv.GetCount(), + gv.IsFloat(), gv.GetInitElements()); + } + + for (const auto& func_ptr : module.GetFunctions()) { + const auto& func = *func_ptr; + + // 跳过外部函数声明(SysY runtime) + if (func.IsExternal()) continue; + + auto* machine_func = machine_module->CreateFunction(func.GetName()); + ValueSlotMap slots; + GepMap geps; // 跟踪 GEP 结果 + + // 为每个 IR BasicBlock 创建对应的 MachineBasicBlock + std::unordered_map block_map; + for (const auto& bb_ptr : func.GetBlocks()) { + const auto& bb = *bb_ptr; + MachineBasicBlock* mbb; + if (bb.GetName() == "entry") { + mbb = &machine_func->GetEntry(); + } else { + mbb = machine_func->CreateBlock(bb.GetName()); + } + block_map[&bb] = mbb; + } + + // 为函数参数创建栈槽并生成参数存储代码 + size_t num_params = func.GetNumParams(); + if (num_params > 8) { + throw std::runtime_error( + FormatError("mir", "暂不支持超过 8 个参数的函数")); + } + auto& entry_block = machine_func->GetEntry(); + for (size_t i = 0; i < num_params; i++) { + auto* arg = func.GetArgument(i); + bool is_ptr = arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat32(); + bool is_float = arg->GetType()->IsFloat32(); + int slot_size = is_ptr ? 8 : 4; // 指针 8 字节,整数 4 字节 + int slot = machine_func->CreateFrameIndex(slot_size); + slots.emplace(arg, slot); + + // 根据参数类型选择寄存器:指针用 x0-x7,整数用 w0-w7,浮点用 s0-s7 + PhysReg param_reg; + if (is_ptr) { + param_reg = static_cast(static_cast(PhysReg::X0) + i); + } else if (is_float) { + param_reg = static_cast(static_cast(PhysReg::S0) + i); + } else { + param_reg = static_cast(static_cast(PhysReg::W0) + i); + } + entry_block.Append(Opcode::StoreStack, + {Operand::Reg(param_reg), Operand::FrameIndex(slot)}); + } + + // 遍历所有基本块,生成指令 + for (const auto& bb_ptr : func.GetBlocks()) { + const auto& bb = *bb_ptr; + MachineBasicBlock* current_mbb = block_map[&bb]; + + const auto& ir_insts = bb.GetInstructions(); + for (size_t i = 0; i < ir_insts.size(); ++i) { + const auto& inst = *ir_insts[i]; + auto opcode = inst.GetOpcode(); + + // Cmp + CondBr 融合:避免 cmp 结果落栈后再读回。 + if (opcode == ir::Opcode::Cmp && i + 1 < ir_insts.size()) { + auto* cmp_inst = dynamic_cast(ir_insts[i].get()); + auto* next_cbr = + dynamic_cast(ir_insts[i + 1].get()); + if (cmp_inst && next_cbr && next_cbr->GetCond() == cmp_inst && + cmp_inst->GetUses().size() == 1) { + auto* true_mbb = block_map[next_cbr->GetTrueBlock()]; + auto* false_mbb = block_map[next_cbr->GetFalseBlock()]; + + bool is_float_cmp = cmp_inst->GetLhs()->GetType()->IsFloat32(); + if (is_float_cmp) { + EmitValueToReg(cmp_inst->GetLhs(), PhysReg::S0, slots, *current_mbb); + EmitValueToReg(cmp_inst->GetRhs(), PhysReg::S1, slots, *current_mbb); + current_mbb->Append( + Opcode::FCmpOnlyRR, + {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::S1)}); + } else { + EmitValueToReg(cmp_inst->GetLhs(), PhysReg::W8, slots, *current_mbb); + EmitValueToReg(cmp_inst->GetRhs(), PhysReg::W9, slots, *current_mbb); + current_mbb->Append( + Opcode::CmpOnlyRR, + {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + } + + // 浮点比较使用 FBcond(IEEE 754 兼容),整数比较使用 Bcond + current_mbb->Append( + is_float_cmp ? Opcode::FBcond : Opcode::Bcond, + {Operand::Symbol(true_mbb->GetName()), + Operand::Imm(static_cast(cmp_inst->GetCmpOp()))}); + current_mbb->Append(Opcode::B, + {Operand::Symbol(false_mbb->GetName())}); + ++i; // 同时跳过后继 CondBr + continue; + } + } + + // 跳转指令需要访问 block_map,所以在这里单独处理 + if (opcode == ir::Opcode::Br) { + auto& br = static_cast(inst); + auto* target = br.GetTarget(); + auto* target_mbb = block_map[target]; + current_mbb->Append(Opcode::B, {Operand::Symbol(target_mbb->GetName())}); + continue; + } + + if (opcode == ir::Opcode::CondBr) { + auto& condbr = static_cast(inst); + auto* cond = condbr.GetCond(); + auto* true_bb = condbr.GetTrueBlock(); + auto* false_bb = condbr.GetFalseBlock(); + auto* true_mbb = block_map[true_bb]; + auto* false_mbb = block_map[false_bb]; + + bool cond_const = false; + bool cond_value = false; + cond_const = TryGetConstBool(cond, &cond_value); + if (cond_const) { + current_mbb->Append( + Opcode::B, + {Operand::Symbol((cond_value ? true_mbb : false_mbb)->GetName())}); + continue; + } + + // 将条件值加载到寄存器 + EmitValueToReg(cond, PhysReg::W8, slots, *current_mbb); + // cbnz: 非零跳转到 true_bb + current_mbb->Append(Opcode::Cbnz, + {Operand::Reg(PhysReg::W8), + Operand::Symbol(true_mbb->GetName())}); + // 零则跳转到 false_bb + current_mbb->Append(Opcode::B, {Operand::Symbol(false_mbb->GetName())}); + continue; + } + + // 其他指令用原来的函数处理 + LowerInstruction(inst, *machine_func, *current_mbb, slots, geps); + } + } + } + + return machine_module; +} + +} // namespace mir diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index e69de29..4335ea9 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -0,0 +1,68 @@ +#include "mir/MIR.h" + +#include + +#include "utils/Log.h" + +namespace mir { +namespace { + +bool IsAllowedReg(PhysReg reg) { + switch (reg) { + case PhysReg::W0: + case PhysReg::W1: + case PhysReg::W2: + case PhysReg::W3: + case PhysReg::W4: + case PhysReg::W5: + case PhysReg::W6: + case PhysReg::W7: + case PhysReg::W8: + case PhysReg::W9: + case PhysReg::W10: + case PhysReg::X0: + case PhysReg::X1: + case PhysReg::X2: + case PhysReg::X3: + case PhysReg::X4: + case PhysReg::X5: + case PhysReg::X6: + case PhysReg::X7: + case PhysReg::X8: + case PhysReg::X9: + case PhysReg::X10: + case PhysReg::X29: + case PhysReg::X30: + case PhysReg::SP: + case PhysReg::S0: + case PhysReg::S1: + case PhysReg::S2: + case PhysReg::S3: + case PhysReg::S4: + case PhysReg::S5: + case PhysReg::S6: + case PhysReg::S7: + case PhysReg::S8: + case PhysReg::S9: + case PhysReg::S10: + return true; + } + return false; +} + +} // namespace + +void RunRegAlloc(MachineFunction& function) { + for (const auto& bb_ptr : function.GetBlocks()) { + for (const auto& inst : bb_ptr->GetInstructions()) { + for (const auto& operand : inst.GetOperands()) { + if (operand.GetKind() == Operand::Kind::Reg && + !IsAllowedReg(operand.GetReg())) { + throw std::runtime_error(FormatError("mir", "寄存器分配失败")); + } + } + } + } +} + +} // namespace mir