From ee3b42ac40a861d2854fa6bcdb01b56ac39ad96f Mon Sep 17 00:00:00 2001 From: lzkk <956449176@qq.com> Date: Thu, 28 May 2026 01:12:43 +0800 Subject: [PATCH] =?UTF-8?q?feat(opt):=20=E5=88=87=E6=8D=A2=E8=87=B3?= =?UTF-8?q?=E9=98=9F=E5=8F=8B=E4=BB=A3=E7=A0=81=E5=9F=BA=E7=BA=BF=E2=80=94?= =?UTF-8?q?=E2=80=94100%=E5=8A=9F=E8=83=BD=E6=AD=A3=E7=A1=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Chaitin-Briggs 图着色寄存器分配,K=16无需spill。 IRGen starttime/stoptime 修复(去掉 _sysy_ 前缀和 lineno 参数)。 此提交为后续优化工作的安全起点。 --- src/include/mir/MIR.h | 156 +- src/include/mir/analysis/CFGAnalysis.h | 29 + src/ir/passes/Inline.cpp | 621 ++++++++ src/irgen/IRGenExp.cpp | 11 +- src/irgen/IRGenFunc.cpp | 6 +- src/main.cpp | 31 +- src/mir/AsmPrinter.cpp | 305 +--- src/mir/CMakeLists.txt | 5 +- src/mir/FrameLowering.cpp | 4 +- src/mir/Lowering.cpp | 332 +++-- src/mir/Lowering.cpp.orig | 1811 ++++++++++++++++++++++++ src/mir/Lowering.cpp.rej | 26 + src/mir/MIRBasicBlock.cpp | 18 - src/mir/RegAlloc.cpp | 1001 ++----------- src/mir/analysis/CFGAnalysis.cpp | 177 +++ 15 files changed, 3083 insertions(+), 1450 deletions(-) create mode 100644 src/include/mir/analysis/CFGAnalysis.h create mode 100644 src/ir/passes/Inline.cpp create mode 100644 src/mir/Lowering.cpp.orig create mode 100644 src/mir/Lowering.cpp.rej create mode 100644 src/mir/analysis/CFGAnalysis.cpp diff --git a/src/include/mir/MIR.h b/src/include/mir/MIR.h index 3022c33f..5691fb10 100644 --- a/src/include/mir/MIR.h +++ b/src/include/mir/MIR.h @@ -147,8 +147,6 @@ namespace mir StoreMem, AddRR, SubRR, - AddImm, - SubImm, MulRR, DivRR, ModRR, @@ -166,6 +164,7 @@ namespace mir FCmpRR, CSet, Csel, + Csneg, Smull, Msub, NegRR, @@ -184,18 +183,6 @@ namespace mir MovReg, }; - // ---- 寄存器类别 ---- - enum class RegClass { GPR32, GPR64, FPR32, FPR64, Unknown }; - - inline RegClass ToRegClass(VRegClass vc) { - switch (vc) { - case VRegClass::Int: return RegClass::GPR32; - case VRegClass::Float: return RegClass::FPR32; - case VRegClass::Ptr: return RegClass::GPR64; - default: return RegClass::Unknown; - } - } - enum class CondCode { EQ, @@ -257,23 +244,16 @@ namespace mir bool IsRematerializable() const { return is_rematerializable_; } MachineInstr &SetRematerializable(bool val) - { - is_rematerializable_ = val; - return *this; - } - + { is_rematerializable_ = val; return *this; } int GetRematImm() const { return remat_imm_; } MachineInstr &SetRematImm(int val) - { - remat_imm_ = val; - return *this; - } + { remat_imm_ = val; return *this; } private: - Opcode opcode_; - std::vector operands_; bool is_rematerializable_ = false; int remat_imm_ = 0; + Opcode opcode_; + std::vector operands_; }; struct FrameSlot @@ -300,9 +280,6 @@ namespace mir MachineInstr &Append(Opcode opcode, std::initializer_list operands = {}); - - void InsertInst(int local_idx, MachineInstr inst); - void ReplaceVReg(int local_idx, int old_vreg, int new_vreg); private: std::string name_; int label_id_ = -1; @@ -348,9 +325,6 @@ namespace mir int GetFrameSize() const { return frame_size_; } void SetFrameSize(int size) { frame_size_ = size; } - bool HasCall() const { return has_call_; } - void SetHasCall(bool v = true) { has_call_ = v; } - int CreateVReg(VRegClass vreg_class); VRegClass GetVRegClass(int vreg_id) const; int GetNumVRegs() const { return static_cast(vreg_classes_.size()); } @@ -365,7 +339,6 @@ namespace mir std::vector frame_slots_; int frame_size_ = 0; - bool has_call_ = false; int next_label_id_ = 0; std::vector vreg_classes_; @@ -436,9 +409,8 @@ namespace mir std::unique_ptr LowerModuleToMIR(const ir::Module &module); std::unique_ptr LowerToMIR(const ir::Module &module); - // ---- 贪婪寄存器分配器入口 ---- - void RunGreedyRegAlloc(MachineFunction &function); - void RunGreedyRegAlloc(MachineModule &module); + void RunRegAlloc(MachineFunction &function); + void RunRegAlloc(MachineModule &module); void RunFrameLowering(MachineFunction &function); void RunFrameLowering(MachineModule &module); @@ -446,120 +418,10 @@ namespace mir void RunPeephole(MachineFunction &function); void RunPeephole(MachineModule &module); - void VerifyMIR(MachineFunction &func); - void VerifyMIR(MachineModule &module); - - void VerifyRegAlloc(MachineFunction &func); - void VerifyRegAlloc(MachineModule &module); + void RunBlockLayout(MachineFunction &function); + void RunBlockLayout(MachineModule &module); void PrintAsm(const MachineFunction &function, std::ostream &os); void PrintAsm(const MachineModule &module, std::ostream &os); - struct VNInfo - { - int id = -1; - int def_pos = -1; - Opcode def_opcode = Opcode::Ret; - - bool IsRematable() const - { - return def_opcode == Opcode::MovImm || - def_opcode == Opcode::LoadStackAddr; - } - }; - - struct UsePosition - { - int pos = -1; - bool is_def = false; - int vn_id = -1; - Opcode opcode = Opcode::Ret; - }; - - struct Segment - { - int start = -1; - int end = -1; - int vn_id = -1; - bool crosses_call = false; - - bool Contains(int pos) const { return start <= pos && pos <= end; } - bool Overlaps(const Segment &o) const - { - return !(end < o.start || o.end < start); - } - }; - - struct LiveInterval - { - int vreg = -1; - RegClass reg_class = RegClass::Unknown; - - std::vector valnos; - std::vector segments; - std::vector uses; - - int assigned_reg = -1; - float spill_weight = 0.0f; - int hint_reg = -1; - int generation = 0; - int deferred_count = 0; // LLVM: RS_New→RS_Deferred→RS_Split stage tracking - - // 保留旧字段以兼容 ComputeInstLiveness - int start = -1; - int end = -1; - VRegClass vreg_class = VRegClass::Int; - bool spilled = false; - int spill_slot = -1; - - bool IsSpilled() const { return assigned_reg == -2; } - bool IsSplit() const { return assigned_reg == -3; } - bool IsAllocated() const { return assigned_reg >= 0; } - - int FirstUsePos() const - { - if (!uses.empty()) return uses.front().pos; - return start; - } - int LastUsePos() const - { - if (!segments.empty()) return segments.back().end; - return end; - } - bool SegmentCrossesCall() const - { - for (auto &seg : segments) - if (seg.crosses_call) return true; - return false; - } - float Length() const - { - int total = 0; - for (auto &seg : segments) - total += seg.end - seg.start + 1; - return total > 0 ? (float)total : 1.0f; - } - }; - - class LiveRegMatrix - { - std::vector> reg_assignments_; - - public: - void Init(int num_regs); - bool Assign(LiveInterval *li, int phys_reg); - void ForceAssign(LiveInterval *li, int phys_reg); - void Unassign(LiveInterval *li); - bool CheckInterference(const LiveInterval &li, int phys_reg) const; - LiveInterval *GetConflict(const LiveInterval &li, int phys_reg) const; - bool CheckInterferenceRange(int start, int end, int phys_reg) const; - }; - - // ---- 增强活跃分析 ---- - std::vector EnhanceIntervals( - const std::vector &raw, - MachineFunction &function); - - std::vector ComputeInstLiveness(MachineFunction &func); - } // namespace mir diff --git a/src/include/mir/analysis/CFGAnalysis.h b/src/include/mir/analysis/CFGAnalysis.h new file mode 100644 index 00000000..ea2af5d5 --- /dev/null +++ b/src/include/mir/analysis/CFGAnalysis.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +namespace mir +{ + + class MachineBasicBlock; + class MachineFunction; + + struct CFGEdge + { + MachineBasicBlock *src = nullptr; + MachineBasicBlock *dst = nullptr; + double weight = 0.0; + }; + + struct CFGAnalysisResult + { + std::map> successors; + std::map> predecessors; + std::map block_freq; + std::vector edges; + }; + + CFGAnalysisResult AnalyzeCFG(MachineFunction &function); + +} // namespace mir diff --git a/src/ir/passes/Inline.cpp b/src/ir/passes/Inline.cpp new file mode 100644 index 00000000..f396ec09 --- /dev/null +++ b/src/ir/passes/Inline.cpp @@ -0,0 +1,621 @@ +#include "ir/IR.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ir { + +namespace { + +constexpr bool kDebugInline = false; +constexpr int kMaxInlineSize = 200; +constexpr int kMaxMultiBlockInlineSize = 50; + +bool IsRecursive(Function* func) { + if (!func || func->IsExternal()) return true; + for (auto& bb : func->GetBlocks()) { + for (auto& inst : bb->GetInstructions()) { + if (auto* call = dynamic_cast(inst.get())) { + if (call->GetCallee() == func) return true; + } + } + } + return false; +} + +int CountInstructions(Function* func) { + int count = 0; + for (auto& bb : func->GetBlocks()) { + count += bb->GetInstructions().size(); + } + return count; +} + +Value* MapValue(Value* v, const std::unordered_map& value_map) { + auto it = value_map.find(v); + if (it != value_map.end()) return it->second; + return v; +} + +void CloneInstruction(Instruction* inst, + const std::unordered_map& value_map, + std::vector>& out) { + std::unique_ptr cloned; + + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: + case Opcode::Eq: + case Opcode::Ne: + case Opcode::Lt: + case Opcode::Le: + case Opcode::Gt: + case Opcode::Ge: { + auto* bin = static_cast(inst); + Value* lhs = MapValue(bin->GetLhs(), value_map); + Value* rhs = MapValue(bin->GetRhs(), value_map); + cloned = std::make_unique( + inst->GetOpcode(), inst->GetType(), lhs, rhs, + inst->GetName() + ".inl"); + break; + } + case Opcode::SIToFP: + case Opcode::FPToSI: + case Opcode::ZExt: { + auto* cast = static_cast(inst); + Value* operand = MapValue(cast->GetOperandValue(), value_map); + cloned = std::make_unique( + inst->GetOpcode(), inst->GetType(), operand, + inst->GetName() + ".inl"); + break; + } + case Opcode::Load: { + auto* load = static_cast(inst); + Value* ptr = MapValue(load->GetPtr(), value_map); + cloned = std::make_unique( + load->GetType(), ptr, inst->GetName() + ".inl"); + break; + } + case Opcode::Store: { + auto* store = static_cast(inst); + Value* val = MapValue(store->GetValue(), value_map); + Value* ptr = MapValue(store->GetPtr(), value_map); + cloned = std::make_unique(Type::GetVoidType(), val, ptr); + break; + } + case Opcode::GEP: { + auto* gep = static_cast(inst); + Value* base = MapValue(gep->GetBasePtr(), value_map); + Value* index = MapValue(gep->GetIndex(), value_map); + cloned = std::make_unique( + gep->GetType(), base, index, inst->GetName() + ".inl"); + break; + } + case Opcode::Call: { + auto* orig_call = static_cast(inst); + Function* callee_func = orig_call->GetCallee(); + std::vector args; + for (size_t i = 0; i < orig_call->GetNumArgs(); ++i) { + args.push_back(MapValue(orig_call->GetArg(i), value_map)); + } + cloned = std::make_unique( + orig_call->GetType(), callee_func, args, + inst->GetName() + ".inl"); + break; + } + case Opcode::Alloca: { + auto* alloca_inst = static_cast(inst); + if (alloca_inst->IsArrayAlloca()) { + Value* count = MapValue(alloca_inst->GetCount(), value_map); + cloned = std::make_unique( + alloca_inst->GetElementType(), + alloca_inst->GetName() + ".inl", count); + } else { + cloned = std::make_unique( + alloca_inst->GetElementType(), + alloca_inst->GetName() + ".inl"); + } + break; + } + default: + break; + } + + if (cloned) { + out.push_back(std::move(cloned)); + } +} + +bool InlineCall(CallInst* call, Function* callee, Function* caller, + BasicBlock* call_bb, Module* module) { + if (kDebugInline) { + std::cerr << "[Inline] Inlining " << callee->GetName() + << " (" << callee->GetBlocks().size() << " blocks)" + << " into " << caller->GetName() << std::endl; + } + + bool is_single_block = (callee->GetBlocks().size() == 1); + + std::unordered_map value_map; + + for (auto& gv : module->GetGlobals()) { + value_map[gv.get()] = gv.get(); + } + for (auto& other_func : module->GetFunctions()) { + value_map[other_func.get()] = other_func.get(); + } + for (auto& arg : caller->GetParams()) { + value_map[arg.get()] = arg.get(); + } + { + auto& blocks = caller->GetBlocks(); + for (size_t bi = 0; bi < blocks.size(); ++bi) { + auto& insts = blocks[bi]->GetInstructions(); + for (size_t ii = 0; ii < insts.size(); ++ii) { + value_map[insts[ii].get()] = insts[ii].get(); + } + } + } + + for (size_t i = 0; i < callee->GetParams().size(); ++i) { + auto* formal_arg = callee->GetParams()[i].get(); + auto* actual_arg = call->GetArg(i); + value_map[formal_arg] = actual_arg; + } + + auto& call_bb_insts = const_cast>&>( + call_bb->GetInstructions()); + + size_t call_idx = 0; + for (size_t i = 0; i < call_bb_insts.size(); ++i) { + if (call_bb_insts[i].get() == call) { + call_idx = i; + break; + } + } + + if (is_single_block) { + auto* callee_entry = callee->GetEntry(); + Value* return_value = nullptr; + + std::vector> cloned_insts; + std::vector> alloca_insts; + + for (auto& inst : callee_entry->GetInstructions()) { + if (inst->GetOpcode() == Opcode::Alloca) { + std::vector> tmp; + CloneInstruction(inst.get(), value_map, tmp); + if (!tmp.empty()) { + value_map[inst.get()] = tmp.back().get(); + alloca_insts.push_back(std::move(tmp.back())); + } + continue; + } + + if (inst->GetOpcode() == Opcode::Ret) { + auto* ret_inst = static_cast(inst.get()); + if (ret_inst->HasValue()) { + return_value = MapValue(ret_inst->GetValue(), value_map); + } + continue; + } + + std::vector> tmp; + CloneInstruction(inst.get(), value_map, tmp); + if (!tmp.empty()) { + value_map[inst.get()] = tmp.back().get(); + cloned_insts.push_back(std::move(tmp.back())); + } + } + + if (return_value) { + call->ReplaceAllUsesWith(return_value); + } else if (!call->GetType()->IsVoid()) { + call->ReplaceAllUsesWith(module->GetContext().GetConstInt(0)); + } + + auto* entry_bb = caller->GetEntry(); + auto& entry_insts = const_cast>&>( + entry_bb->GetInstructions()); + size_t alloca_insert_pos = 0; + for (size_t i = 0; i < entry_insts.size(); ++i) { + if (entry_insts[i]->GetOpcode() == Opcode::Alloca) { + alloca_insert_pos = i + 1; + } else { + break; + } + } + for (auto& alloca : alloca_insts) { + alloca->SetParent(entry_bb); + entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca)); + alloca_insert_pos++; + } + + size_t insert_pos = call_idx; + for (auto& cloned : cloned_insts) { + cloned->SetParent(call_bb); + call_bb_insts.insert(call_bb_insts.begin() + insert_pos, std::move(cloned)); + insert_pos++; + } + + for (size_t i = 0; i < call_bb_insts.size(); ++i) { + if (call_bb_insts[i].get() == call) { + for (size_t oi = 0; oi < call->GetNumOperands(); ++oi) { + auto* op = call->GetOperand(oi); + if (auto* op_inst = dynamic_cast(op)) { + op_inst->RemoveUse(call, oi); + } + } + call_bb_insts.erase(call_bb_insts.begin() + i); + break; + } + } + + return true; + } + + // === Multi-block inlining === + + // 1. Create after_bb: move instructions after call from call_bb to after_bb + BasicBlock* after_bb = caller->CreateBlock(call_bb->GetName() + ".after"); + + std::vector> after_insts; + for (size_t i = call_idx + 1; i < call_bb_insts.size(); ++i) { + after_insts.push_back(std::move(call_bb_insts[i])); + } + call_bb_insts.resize(call_idx + 1); + + for (auto& inst : after_insts) { + inst->SetParent(after_bb); + after_bb->GetMutablePredecessors(); + } + auto& after_bb_insts = const_cast>&>( + after_bb->GetInstructions()); + for (auto& inst : after_insts) { + after_bb_insts.push_back(std::move(inst)); + } + + // 1b. Fix phi nodes: any phi that had call_bb as predecessor should now use after_bb + for (auto& bb : caller->GetBlocks()) { + for (auto& inst : bb->GetInstructions()) { + if (inst->GetOpcode() != Opcode::Phi) break; + auto* phi = static_cast(inst.get()); + size_t num_ops = phi->GetNumOperands(); + for (size_t i = 0; i + 1 < num_ops; i += 2) { + auto* bb_ptr = dynamic_cast(phi->GetOperand(i + 1)); + if (bb_ptr == call_bb) { + phi->SetOperand(i + 1, after_bb); + } + } + } + } + + // 2. Create cloned blocks for callee + std::unordered_map bb_map; + std::vector cloned_bbs; + for (auto& bb : callee->GetBlocks()) { + BasicBlock* cloned_bb = caller->CreateBlock(bb->GetName() + ".inl"); + bb_map[bb.get()] = cloned_bb; + cloned_bbs.push_back(cloned_bb); + } + BasicBlock* cloned_entry = bb_map[callee->GetEntry()]; + + // 2b. Reorder blocks: move cloned blocks and after_bb right after call_bb + // IMPORTANT: after_bb must come AFTER all cloned blocks, because + // after_bb may use values defined in the cloned blocks (e.g., call results + // from nested inlines). The lowering processes blocks in order, so values + // must be defined before they are used. + { + auto& blocks = const_cast>&>(caller->GetBlocks()); + std::vector move_indices; + for (auto* cb : cloned_bbs) { + for (size_t i = 0; i < blocks.size(); ++i) { + if (blocks[i].get() == cb) { move_indices.push_back(i); break; } + } + } + for (size_t i = 0; i < blocks.size(); ++i) { + if (blocks[i].get() == after_bb) { move_indices.push_back(i); break; } + } + + size_t call_bb_idx = 0; + for (size_t i = 0; i < blocks.size(); ++i) { + if (blocks[i].get() == call_bb) { call_bb_idx = i; break; } + } + + std::vector> extracted; + for (auto idx : move_indices) { + extracted.push_back(std::move(blocks[idx])); + } + + size_t insert_pos = call_bb_idx + 1; + for (auto& b : extracted) { + blocks.insert(blocks.begin() + insert_pos, std::move(b)); + insert_pos++; + } + + blocks.erase(std::remove_if(blocks.begin(), blocks.end(), + [](const std::unique_ptr& b) { return b == nullptr; }), + blocks.end()); + } + + // 4. Create alloca for return value (if non-void) + AllocaInst* ret_alloca = nullptr; + bool has_return = !call->GetType()->IsVoid(); + if (has_return) { + auto* entry_bb = caller->GetEntry(); + auto& entry_insts = const_cast>&>( + entry_bb->GetInstructions()); + auto alloca = std::make_unique(call->GetType(), "__ret.inl"); + alloca->SetParent(entry_bb); + ret_alloca = static_cast(alloca.get()); + size_t alloca_insert_pos = 0; + for (size_t i = 0; i < entry_insts.size(); ++i) { + if (entry_insts[i]->GetOpcode() == Opcode::Alloca) { + alloca_insert_pos = i + 1; + } else { + break; + } + } + entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca)); + } + + // 5. Clone all instructions from callee blocks into cloned blocks + // Pass 1: Create cloned instructions with original operands, build value_map + std::vector> alloca_insts; + std::vector> remap_list; + + for (auto& bb : callee->GetBlocks()) { + BasicBlock* cloned_bb = bb_map[bb.get()]; + auto& cloned_insts = const_cast>&>( + cloned_bb->GetInstructions()); + + for (auto& inst : bb->GetInstructions()) { + if (inst->GetOpcode() == Opcode::Alloca) { + std::vector> tmp; + CloneInstruction(inst.get(), value_map, tmp); + if (!tmp.empty()) { + value_map[inst.get()] = tmp.back().get(); + alloca_insts.push_back(std::move(tmp.back())); + } + continue; + } + + if (inst->GetOpcode() == Opcode::Phi) { + auto* phi = static_cast(inst.get()); + auto new_phi = std::make_unique(phi->GetType(), phi->GetName() + ".inl"); + new_phi->SetParent(cloned_bb); + value_map[inst.get()] = new_phi.get(); + cloned_insts.push_back(std::move(new_phi)); + continue; + } + + if (inst->IsTerminator()) continue; + + std::vector> tmp; + CloneInstruction(inst.get(), value_map, tmp); + if (!tmp.empty()) { + tmp.back()->SetParent(cloned_bb); + value_map[inst.get()] = tmp.back().get(); + remap_list.push_back({inst.get(), tmp.back().get()}); + cloned_insts.push_back(std::move(tmp.back())); + } + } + } + + // Pass 1b: Remap operands of cloned instructions now that value_map is complete + for (auto& [orig, cloned] : remap_list) { + for (size_t i = 0; i < orig->GetNumOperands(); ++i) { + Value* orig_op = orig->GetOperand(i); + Value* mapped = MapValue(orig_op, value_map); + if (mapped != orig_op) { + cloned->SetOperand(i, mapped); + } + } + } + + // Pass 2: fill phi operands and handle terminators + for (auto& bb : callee->GetBlocks()) { + BasicBlock* cloned_bb = bb_map[bb.get()]; + auto& cloned_insts = const_cast>&>( + cloned_bb->GetInstructions()); + + for (auto& inst : bb->GetInstructions()) { + if (inst->GetOpcode() == Opcode::Phi) { + auto* orig_phi = static_cast(inst.get()); + auto* cloned_phi = static_cast(value_map[orig_phi]); + if (!cloned_phi) continue; + + for (size_t i = 0; i < orig_phi->GetNumOperands(); i += 2) { + Value* val = MapValue(orig_phi->GetOperand(i), value_map); + auto* orig_pred = static_cast(orig_phi->GetOperand(i + 1)); + auto pred_it = bb_map.find(orig_pred); + BasicBlock* pred = (pred_it != bb_map.end()) ? pred_it->second : orig_pred; + cloned_phi->AddOperand(val); + cloned_phi->AddOperand(pred); + } + continue; + } + + if (inst->GetOpcode() == Opcode::Ret) { + auto* ret_inst = static_cast(inst.get()); + if (ret_inst->HasValue() && has_return) { + Value* ret_val = MapValue(ret_inst->GetValue(), value_map); + auto store = std::make_unique( + Type::GetVoidType(), ret_val, ret_alloca); + store->SetParent(cloned_bb); + cloned_insts.push_back(std::move(store)); + } + auto br = std::make_unique(Type::GetVoidType(), after_bb); + br->SetParent(cloned_bb); + cloned_insts.push_back(std::move(br)); + continue; + } + + if (inst->GetOpcode() == Opcode::Br) { + auto* br = static_cast(inst.get()); + auto it = bb_map.find(br->GetTarget()); + BasicBlock* target = (it != bb_map.end()) ? it->second : br->GetTarget(); + auto new_br = std::make_unique(Type::GetVoidType(), target); + new_br->SetParent(cloned_bb); + cloned_insts.push_back(std::move(new_br)); + continue; + } + + if (inst->GetOpcode() == Opcode::CondBr) { + auto* cbr = static_cast(inst.get()); + Value* cond = MapValue(cbr->GetCond(), value_map); + auto true_it = bb_map.find(cbr->GetTrueTarget()); + BasicBlock* true_target = (true_it != bb_map.end()) ? true_it->second : cbr->GetTrueTarget(); + auto false_it = bb_map.find(cbr->GetFalseTarget()); + BasicBlock* false_target = (false_it != bb_map.end()) ? false_it->second : cbr->GetFalseTarget(); + auto new_cbr = std::make_unique( + Type::GetVoidType(), cond, true_target, false_target); + new_cbr->SetParent(cloned_bb); + cloned_insts.push_back(std::move(new_cbr)); + continue; + } + } + } + + // 7. Insert alloca_insts into caller entry + { + auto* entry_bb = caller->GetEntry(); + auto& entry_insts = const_cast>&>( + entry_bb->GetInstructions()); + size_t alloca_insert_pos = 0; + for (size_t i = 0; i < entry_insts.size(); ++i) { + if (entry_insts[i]->GetOpcode() == Opcode::Alloca) { + alloca_insert_pos = i + 1; + } else { + break; + } + } + for (auto& alloca : alloca_insts) { + alloca->SetParent(entry_bb); + entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca)); + alloca_insert_pos++; + } + } + + // 8-9. Handle return value and remove call + auto call_type = call->GetType(); + + if (has_return) { + auto load_ret = std::make_unique( + call_type, ret_alloca, "__ret.load.inl"); + load_ret->SetParent(after_bb); + Value* ret_val = load_ret.get(); + after_bb_insts.insert(after_bb_insts.begin(), std::move(load_ret)); + + call->ReplaceAllUsesWith(ret_val); + } else { + call->ReplaceAllUsesWith(module->GetContext().GetConstInt(0)); + } + + // Remove the call and add branch to cloned_entry + for (size_t i = 0; i < call_bb_insts.size(); ++i) { + if (call_bb_insts[i].get() == call) { + for (size_t oi = 0; oi < call->GetNumOperands(); ++oi) { + auto* op = call->GetOperand(oi); + if (auto* op_inst = dynamic_cast(op)) { + op_inst->RemoveUse(call, oi); + } + } + call_bb_insts.erase(call_bb_insts.begin() + i); + break; + } + } + + auto br_to_entry = std::make_unique(Type::GetVoidType(), cloned_entry); + br_to_entry->SetParent(call_bb); + call_bb_insts.push_back(std::move(br_to_entry)); + + if (kDebugInline) { + std::cerr << "[Inline] Done inlining " << callee->GetName() << std::endl; + } + + return true; +} + +} // namespace + +void RunInline(Module* module) { + if (!module) return; + + std::unordered_map func_sizes; + std::unordered_set recursive_funcs; + + for (auto& func : module->GetFunctions()) { + if (func->IsExternal()) continue; + func_sizes[func->GetName()] = CountInstructions(func.get()); + if (IsRecursive(func.get())) { + recursive_funcs.insert(func->GetName()); + } + } + + struct InlineSite { + CallInst* call; + Function* caller; + BasicBlock* call_bb; + }; + + std::vector inline_sites; + + for (auto& caller : module->GetFunctions()) { + if (caller->IsExternal()) continue; + + for (auto& bb : caller->GetBlocks()) { + for (auto& inst : bb->GetInstructions()) { + auto* call = dynamic_cast(inst.get()); + if (!call) continue; + + auto* callee = call->GetCallee(); + if (!callee) continue; + if (callee->IsExternal()) continue; + if (recursive_funcs.count(callee->GetName())) continue; + + auto size_it = func_sizes.find(callee->GetName()); + int callee_size = (size_it != func_sizes.end()) ? size_it->second : 9999; + + if (callee_size > kMaxInlineSize) continue; + if (callee == caller.get()) continue; + if (callee->GetBlocks().size() > 1 && callee_size > kMaxMultiBlockInlineSize) continue; + + inline_sites.push_back({call, caller.get(), bb.get()}); + } + } + } + + for (auto& site : inline_sites) { + auto* callee = site.call->GetCallee(); + if (!callee) continue; + + bool still_valid = false; + BasicBlock* actual_bb = nullptr; + for (auto& bb : site.caller->GetBlocks()) { + for (auto& inst : bb->GetInstructions()) { + if (inst.get() == site.call) { + still_valid = true; + actual_bb = bb.get(); + break; + } + } + if (still_valid) break; + } + if (!still_valid) continue; + + InlineCall(site.call, callee, site.caller, actual_bb, module); + } +} + +} // namespace ir diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index 2f94f941..7ff3ba50 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -656,16 +656,9 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { } } - if (callee_name == "starttime" || callee_name == "stoptime") { - int lineno = ctx->getStart()->getLine(); - args.push_back(static_cast(builder_.CreateConstInt(lineno))); - } - if (args.size() != func_it->second->GetParams().size()) { - if (callee_name != "starttime" && callee_name != "stoptime") { - throw std::runtime_error( - FormatError("irgen", "函数参数个数不匹配: " + callee_name)); - } + throw std::runtime_error( + FormatError("irgen", "函数参数个数不匹配: " + callee_name)); } for (size_t i = 0; i < args.size(); ++i) { args[i] = CastValueTo(args[i], func_it->second->GetParams()[i]->GetType()); diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 9c6c1342..21ca2631 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -119,11 +119,9 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { auto* putch = module_.CreateFunction("putch", ir::Type::GetVoidType(), true); putch->AddParam("%arg.x", ir::Type::GetInt32Type()); function_map_["putch"] = putch; - auto* sysy_starttime = module_.CreateFunction("_sysy_starttime", ir::Type::GetVoidType(), true); - sysy_starttime->AddParam("%arg.lineno", ir::Type::GetInt32Type()); + auto* sysy_starttime = module_.CreateFunction("starttime", ir::Type::GetVoidType(), true); function_map_["starttime"] = sysy_starttime; - auto* sysy_stoptime = module_.CreateFunction("_sysy_stoptime", ir::Type::GetVoidType(), true); - sysy_stoptime->AddParam("%arg.lineno", ir::Type::GetInt32Type()); + auto* sysy_stoptime = module_.CreateFunction("stoptime", ir::Type::GetVoidType(), true); function_map_["stoptime"] = sysy_stoptime; SysYParser::FuncDefContext* main_func = nullptr; diff --git a/src/main.cpp b/src/main.cpp index 9d42d592..643a987e 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -44,42 +44,17 @@ int main(int argc, char** argv) { // 执行优化(如果启用) if (opts.optimize) { - ir::PassManagerModule pass_manager(module.get()); - pass_manager.Run(); + ir::PassManager pass_manager; + pass_manager.RunScalarOptimizationPasses(module.get()); } - // Debug 模式:验证 IR 合法性 -#ifndef NDEBUG - ir::VerifyIR(*module); -#endif - // 汇编输出到文件或标准输出 if (opts.emit_asm) { auto machine_module = mir::LowerModuleToMIR(*module); - -#ifndef NDEBUG - mir::VerifyMIR(*machine_module); -#endif - - mir::RunGreedyRegAlloc(*machine_module); - -#ifndef NDEBUG - mir::VerifyRegAlloc(*machine_module); - mir::VerifyMIR(*machine_module); -#endif - + mir::RunRegAlloc(*machine_module); mir::RunFrameLowering(*machine_module); - -#ifndef NDEBUG - mir::VerifyMIR(*machine_module); -#endif - mir::RunPeephole(*machine_module); -#ifndef NDEBUG - mir::VerifyMIR(*machine_module); -#endif - std::ostringstream asm_ss; mir::PrintAsm(*machine_module, asm_ss); diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 713c03c6..dd4ba798 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -42,10 +42,8 @@ namespace mir case Opcode::StoreStack: return "stur"; case Opcode::AddRR: - case Opcode::AddImm: return "add"; case Opcode::SubRR: - case Opcode::SubImm: return "sub"; case Opcode::MulRR: return "mul"; @@ -79,6 +77,8 @@ namespace mir return "fcmp"; case Opcode::CSet: return "cset"; + case Opcode::Csneg: + return "csneg"; case Opcode::Scvtf: return "scvtf"; case Opcode::FCvtzs: @@ -223,11 +223,6 @@ namespace mir { continue; } - // 跳过前导零——直接用移位后的 movz,避免浪费 movz #0 - if (!emitted && part == 0) - { - continue; - } if (!emitted) { @@ -256,36 +251,8 @@ namespace mir } } - // ADRP 缓存——避免连续访问同一全局变量时重复发射 ADRP - std::string g_cached_adrp_symbol; - bool g_adrp_cache_valid = false; - - // 帧基址缓存——x13 持有 x29 + g_frame_base_offset,避免重复计算地址 - int g_frame_base_offset = 0; - bool g_frame_base_valid = false; - - void InvalidateFrameBase() - { - g_frame_base_valid = false; - } - - void InvalidateAdrpCache() - { - g_adrp_cache_valid = false; - } - void EmitStackAdjust(const char *op, int amount, std::ostream &os) { - if (amount > 12285) - { - InvalidateAdrpCache(); - InvalidateFrameBase(); - os << " movz x13, #" << (amount & 0xFFFF) << "\n"; - if ((amount >> 16) != 0) - os << " movk x13, #" << ((amount >> 16) & 0xFFFF) << ", lsl #16\n"; - os << " " << op << " sp, sp, x13\n"; - return; - } while (amount > 0) { const int chunk = amount > 4095 ? 4095 : amount; @@ -305,38 +272,6 @@ namespace mir void EmitAddressFromBase(PhysReg target_xreg, PhysReg base_reg, int offset, std::ostream &os) { - // 使用 x13 时,ADRP 和帧基址缓存同时失效 - if (target_xreg == PrinterScratchXReg()) - { - InvalidateAdrpCache(); - InvalidateFrameBase(); - } - - if (offset > 12285) - { - // 使用 x13 作为立即数暂存,必须失效帧基址和 ADRP 缓存 - InvalidateAdrpCache(); - InvalidateFrameBase(); - os << " movz x13, #" << (offset & 0xFFFF) << "\n"; - if ((offset >> 16) != 0) - os << " movk x13, #" << ((offset >> 16) & 0xFFFF) << ", lsl #16\n"; - os << " add " << PhysRegName(target_xreg) << ", " - << PhysRegName(base_reg) << ", x13\n"; - return; - } - if (offset < -12285) - { - int abs_off = -offset; - // 使用 x13 作为立即数暂存,必须失效帧基址和 ADRP 缓存 - InvalidateAdrpCache(); - InvalidateFrameBase(); - os << " movz x13, #" << (abs_off & 0xFFFF) << "\n"; - if ((abs_off >> 16) != 0) - os << " movk x13, #" << ((abs_off >> 16) & 0xFFFF) << ", lsl #16\n"; - os << " sub " << PhysRegName(target_xreg) << ", " - << PhysRegName(base_reg) << ", x13\n"; - return; - } os << " mov " << PhysRegName(target_xreg) << ", " << PhysRegName(base_reg) << "\n"; @@ -362,7 +297,6 @@ namespace mir const char *narrow_op = (opcode == Opcode::LoadStack) ? "ldur" : "stur"; const char *wide_op = (opcode == Opcode::LoadStack) ? "ldr" : "str"; - // x29 可达的窄范围直接用 ldur/stur if (offset >= -256 && offset <= 255) { os << " " << narrow_op << " "; @@ -372,41 +306,7 @@ namespace mir } const PhysReg scratch_xreg = PrinterScratchXReg(); - bool is_32bit = IsWReg(reg.GetReg()) || IsSReg(reg.GetReg()); - - // 尝试帧基址缓存——x13 已持有之前的地址 - if (g_frame_base_valid) - { - int diff = offset - g_frame_base_offset; - - // ldur/stur(范围 ±256) - if (diff >= -256 && diff <= 255) - { - os << " " << narrow_op << " "; - PrintOperand(reg, os); - os << ", [" << PhysRegName(scratch_xreg) << ", #" << diff << "]\n"; - return; - } - - // ldr/str 无符号立即数(正偏移) - if (diff >= 0) - { - int max_imm = is_32bit ? 16380 : 32760; - int align = is_32bit ? 4 : 8; - if (diff <= max_imm && diff % align == 0) - { - os << " " << wide_op << " "; - PrintOperand(reg, os); - os << ", [" << PhysRegName(scratch_xreg) << ", #" << diff << "]\n"; - return; - } - } - } - - // 缓存未命中——完整计算地址到 x13 EmitAddressFromBase(scratch_xreg, PhysReg::X29, offset, os); - g_frame_base_offset = offset; - g_frame_base_valid = true; os << " " << wide_op << " "; PrintOperand(reg, os); @@ -433,18 +333,7 @@ namespace mir const std::string asm_symbol = NormalizeAsmSymbol(symbol); const PhysReg scratch_xreg = PrinterScratchXReg(); - if (g_adrp_cache_valid && g_cached_adrp_symbol == asm_symbol) - { - // x13 已持有该全局变量的页面地址,跳过 ADRP - } - else - { - os << " adrp " << PhysRegName(scratch_xreg) << ", " << asm_symbol << "\n"; - g_cached_adrp_symbol = asm_symbol; - g_adrp_cache_valid = true; - InvalidateFrameBase(); - } - + os << " adrp " << PhysRegName(scratch_xreg) << ", " << asm_symbol << "\n"; os << " " << (opcode == Opcode::LoadGlobal ? "ldr " : "str "); PrintOperand(reg, os); os << ", [" << PhysRegName(scratch_xreg) << ", #:lo12:" << asm_symbol << "]\n"; @@ -513,67 +402,25 @@ namespace mir case Opcode::Prologue: { const auto &cs_regs = function.GetCalleeSavedRegs(); - const bool is_leaf = !function.HasCall(); - const bool no_frame = (function.GetFrameSize() == 0 && cs_regs.empty()); - - // 叶函数无帧且无 callee-saved 寄存器:完全跳过帧设置 - if (is_leaf && no_frame) - { - return; - } - - // 叶函数仅保存 x29(LR 不会被修改),非叶函数保存 x29+x30 - if (is_leaf) - { - os << " str x29, [sp, #-8]!\n"; - } - else - { - os << " stp x29, x30, [sp, #-16]!\n"; - } + os << " stp x29, x30, [sp, #-16]!\n"; os << " mov x29, sp\n"; - if (function.GetFrameSize() > 0) { EmitStackAdjust("sub", function.GetFrameSize(), os); } - - // X(64-bit) 和 S(32-bit) 分两组配对 stp - std::vector x_regs, s_regs; + int cs_offset = 0; for (auto r : cs_regs) { if (r >= PhysReg::X0 && r <= PhysReg::X30) - x_regs.push_back(r); + { + os << " str " << PhysRegName(r) << ", [sp, #" << cs_offset << "]\n"; + cs_offset += 8; + } else if (r >= PhysReg::S0 && r <= PhysReg::S31) - s_regs.push_back(r); - else - x_regs.push_back(r); // 兜底:非 X 非 S 按 X 处理 - } - int cs_offset = 0; - for (size_t i = 0; i + 1 < x_regs.size(); i += 2) - { - os << " stp " << PhysRegName(x_regs[i]) << ", " - << PhysRegName(x_regs[i + 1]) - << ", [sp, #" << cs_offset << "]\n"; - cs_offset += 16; - } - if (x_regs.size() % 2 == 1) - { - os << " str " << PhysRegName(x_regs.back()) - << ", [sp, #" << cs_offset << "]\n"; - cs_offset += 8; - } - for (size_t i = 0; i + 1 < s_regs.size(); i += 2) - { - os << " stp " << PhysRegName(s_regs[i]) << ", " - << PhysRegName(s_regs[i + 1]) - << ", [sp, #" << cs_offset << "]\n"; - cs_offset += 8; - } - if (s_regs.size() % 2 == 1) - { - os << " str " << PhysRegName(s_regs.back()) - << ", [sp, #" << cs_offset << "]\n"; + { + os << " str " << PhysRegName(r) << ", [sp, #" << cs_offset << "]\n"; + cs_offset += 4; + } } return; } @@ -581,67 +428,25 @@ namespace mir case Opcode::Epilogue: { const auto &cs_regs = function.GetCalleeSavedRegs(); - const bool is_leaf = !function.HasCall(); - const bool no_frame = (function.GetFrameSize() == 0 && cs_regs.empty()); - - // 叶函数无帧且无 callee-saved 寄存器——直接返回 - if (is_leaf && no_frame) - { - os << " ret\n"; - return; - } - - // 恢复 callee-saved 寄存器(叶函数也需要——它们属于调用者) - std::vector x_regs, s_regs; + int cs_offset = 0; for (auto r : cs_regs) { if (r >= PhysReg::X0 && r <= PhysReg::X30) - x_regs.push_back(r); + { + os << " ldr " << PhysRegName(r) << ", [sp, #" << cs_offset << "]\n"; + cs_offset += 8; + } else if (r >= PhysReg::S0 && r <= PhysReg::S31) - s_regs.push_back(r); - else - x_regs.push_back(r); - } - int cs_offset = 0; - for (size_t i = 0; i + 1 < x_regs.size(); i += 2) - { - os << " ldp " << PhysRegName(x_regs[i]) << ", " - << PhysRegName(x_regs[i + 1]) - << ", [sp, #" << cs_offset << "]\n"; - cs_offset += 16; - } - if (x_regs.size() % 2 == 1) - { - os << " ldr " << PhysRegName(x_regs.back()) - << ", [sp, #" << cs_offset << "]\n"; - cs_offset += 8; - } - for (size_t i = 0; i + 1 < s_regs.size(); i += 2) - { - os << " ldp " << PhysRegName(s_regs[i]) << ", " - << PhysRegName(s_regs[i + 1]) - << ", [sp, #" << cs_offset << "]\n"; - cs_offset += 8; - } - if (s_regs.size() % 2 == 1) - { - os << " ldr " << PhysRegName(s_regs.back()) - << ", [sp, #" << cs_offset << "]\n"; + { + os << " ldr " << PhysRegName(r) << ", [sp, #" << cs_offset << "]\n"; + cs_offset += 4; + } } - if (function.GetFrameSize() > 0) { EmitStackAdjust("add", function.GetFrameSize(), os); } - - if (is_leaf) - { - os << " ldr x29, [sp], #8\n"; - } - else - { - os << " ldp x29, x30, [sp], #16\n"; - } + os << " ldp x29, x30, [sp], #16\n"; os << " ret\n"; return; } @@ -854,6 +659,19 @@ namespace mir } return; + case Opcode::Csneg: + if (operands.size() >= 4) + { + os << " csneg "; + PrintOperand(operands[0], os); + os << ", "; + PrintOperand(operands[1], os); + os << ", "; + PrintOperand(operands[2], os); + os << ", " << CondCodeToAsm(static_cast(operands[3].GetImm())) << "\n"; + } + return; + case Opcode::Smull: if (operands.size() >= 3) { @@ -941,11 +759,6 @@ namespace mir } return; - case Opcode::Call: - InvalidateAdrpCache(); - InvalidateFrameBase(); // x13 是 caller-saved,被调用破坏 - // 不 break,落到 default 让泛型打印机输出 bl 指令 - default: break; } @@ -973,6 +786,47 @@ namespace mir for (const auto &global : module.GetGlobals()) { const std::string asm_name = NormalizeAsmSymbol(global.name); + + bool is_zero_init = false; + if (global.kind == MachineGlobal::Kind::I32Scalar && global.init_value == 0) + { + is_zero_init = true; + } + if (global.kind == MachineGlobal::Kind::I32Array) + { + bool all_zero = true; + for (auto v : global.init_values) + { + if (v != 0) + { + all_zero = false; + break; + } + } + if (all_zero) + { + is_zero_init = true; + } + } + + if (is_zero_init) + { + os << " .bss\n"; + os << " .globl " << asm_name << "\n"; + os << " .p2align 2\n"; + os << asm_name << ":\n"; + if (global.kind == MachineGlobal::Kind::I32Scalar) + { + os << " .space 4\n"; + } + else + { + os << " .space " << (global.array_size * 4) << "\n"; + } + os << " .data\n"; + continue; + } + os << " .globl " << asm_name << "\n"; os << " .p2align 2\n"; os << asm_name << ":\n"; @@ -1001,8 +855,6 @@ namespace mir void PrintAsm(const MachineFunction &function, std::ostream &os) { - g_adrp_cache_valid = false; - g_frame_base_valid = false; const std::string asm_name = NormalizeAsmSymbol(function.GetName()); os << " .text\n"; @@ -1018,9 +870,6 @@ namespace mir } const auto &block = *block_ptr; - // 每个基本块重置缓存——跨块时 x13 可能已被 call/clobber 破坏 - g_adrp_cache_valid = false; - g_frame_base_valid = false; PrintBlockLabelRef(function, block.GetLabelId(), os); os << ":\n"; diff --git a/src/mir/CMakeLists.txt b/src/mir/CMakeLists.txt index b6cec6d6..86362e15 100644 --- a/src/mir/CMakeLists.txt +++ b/src/mir/CMakeLists.txt @@ -8,10 +8,7 @@ add_library(mir_core STATIC RegAlloc.cpp FrameLowering.cpp AsmPrinter.cpp - MIRVerifier.cpp - RegAllocVerifier.cpp - InstLiveness.cpp - GreedyAlloc.cpp + analysis/CFGAnalysis.cpp ) target_link_libraries(mir_core PUBLIC diff --git a/src/mir/FrameLowering.cpp b/src/mir/FrameLowering.cpp index efcbfd14..34422b00 100644 --- a/src/mir/FrameLowering.cpp +++ b/src/mir/FrameLowering.cpp @@ -59,9 +59,7 @@ namespace mir { if (slot.is_callee_stack_arg) { - // 叶函数仅保存 x29(8字节),非叶函数保存 x29+x30(16字节) - // 栈参数偏移需根据实际情况调整 - slot.offset = (function.HasCall() ? 16 : 8) + slot.offset; + slot.offset = 16 + slot.offset; } } } diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 29bb154a..679eb75c 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -115,24 +115,6 @@ namespace mir } } - // 交换比较操作数时反转条件码(aa) - static CondCode SwapCondCode(CondCode cond) - { - switch (cond) - { - case CondCode::LT: - return CondCode::GT; - case CondCode::LE: - return CondCode::GE; - case CondCode::GT: - return CondCode::LT; - case CondCode::GE: - return CondCode::LE; - default: - return cond; // EQ/NE 对称 - } - } - static PhysReg GetArgWReg(size_t index) { static const PhysReg regs[] = { @@ -356,43 +338,16 @@ namespace mir { if (IsIntegerCompareOpcode(bin->GetOpcode())) { - // 常量折叠到 CmpImm,消除冗余 MovImm - int lhs_imm, rhs_imm; - bool lhs_const = TryGetConstantInt(bin->GetLhs(), lhs_imm); - bool rhs_const = TryGetConstantInt(bin->GetRhs(), rhs_imm); - auto imm_fits = [](int imm) { return imm >= 0 && imm <= 4095; }; - - CondCode cond = GetCondCodeForCompareOpcode(bin->GetOpcode()); - - if (rhs_const && imm_fits(rhs_imm)) - { - int lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, - scalar_slots, array_slots, block); - block.Append(Opcode::CmpImm, - {Operand::VReg(lhs, VRegClass::Int), Operand::Imm(rhs_imm)}); - } - else if (lhs_const && imm_fits(lhs_imm)) - { - int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs, - scalar_slots, array_slots, block); - block.Append(Opcode::CmpImm, - {Operand::VReg(rhs, VRegClass::Int), Operand::Imm(lhs_imm)}); - cond = SwapCondCode(cond); - } - else - { - int lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, - scalar_slots, array_slots, block); - int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs, - scalar_slots, array_slots, block); - block.Append(Opcode::CmpRR, - {Operand::VReg(lhs, VRegClass::Int), Operand::VReg(rhs, VRegClass::Int)}); - } - + int lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::CmpRR, + {Operand::VReg(lhs, VRegClass::Int), Operand::VReg(rhs, VRegClass::Int)}); int dst = function.CreateVReg(VRegClass::Int); block.Append(Opcode::CSet, {Operand::VReg(dst, VRegClass::Int), - Operand::Imm(static_cast(cond))}); + Operand::Imm(static_cast(GetCondCodeForCompareOpcode(bin->GetOpcode())))}); value_vregs[value] = dst; return dst; } @@ -473,7 +428,101 @@ namespace mir value_vregs[value] = dst; return dst; } - // 2的幂次除法(含正负)改用 sdiv,比移位序列更短 + if (val > 0 && (val & (val - 1)) == 0) + { + int shift = 0; + int tmp = val; + while (tmp > 1) + { + tmp >>= 1; + ++shift; + } + int bias = (1 << shift) - 1; + int biased = function.CreateVReg(VRegClass::Int); + if (bias <= 4095) + { + block.Append(Opcode::AddRR, + {Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(bias)}); + } + else + { + int bias_reg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, + {Operand::VReg(bias_reg, VRegClass::Int), + Operand::Imm(bias)}).SetRematerializable(true).SetRematImm(bias); + block.Append(Opcode::AddRR, + {Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::VReg(bias_reg, VRegClass::Int)}); + } + block.Append(Opcode::CmpImm, + {Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(0)}); + int selected = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::Csel, + {Operand::VReg(selected, VRegClass::Int), + Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(static_cast(CondCode::LT))}); + block.Append(Opcode::AsrRR, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(selected, VRegClass::Int), + Operand::Imm(shift)}); + value_vregs[value] = dst; + return dst; + } + if (val < 0 && (-val & (-val - 1)) == 0 && val != -1) + { + int abs_val = -val; + int shift = 0; + int tmp = abs_val; + while (tmp > 1) + { + tmp >>= 1; + ++shift; + } + int bias = (1 << shift) - 1; + int biased = function.CreateVReg(VRegClass::Int); + if (bias <= 4095) + { + block.Append(Opcode::AddRR, + {Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(bias)}); + } + else + { + int bias_reg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, + {Operand::VReg(bias_reg, VRegClass::Int), + Operand::Imm(bias)}).SetRematerializable(true).SetRematImm(bias); + block.Append(Opcode::AddRR, + {Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::VReg(bias_reg, VRegClass::Int)}); + } + block.Append(Opcode::CmpImm, + {Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(0)}); + int selected = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::Csel, + {Operand::VReg(selected, VRegClass::Int), + Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(static_cast(CondCode::LT))}); + int pos_q = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::AsrRR, + {Operand::VReg(pos_q, VRegClass::Int), + Operand::VReg(selected, VRegClass::Int), + Operand::Imm(shift)}); + block.Append(Opcode::NegRR, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(pos_q, VRegClass::Int)}); + value_vregs[value] = dst; + return dst; + } } } @@ -483,43 +532,128 @@ namespace mir if (rhs_const) { int val = rhs_const->GetValue(); - // x % 1 == 0, x % -1 == 0 - if (val == 1 || val == -1) + if (val > 0 && (val & (val - 1)) == 0) { + int bias = val - 1; + int biased = function.CreateVReg(VRegClass::Int); + if (bias <= 4095) + { + block.Append(Opcode::AddRR, + {Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(bias)}); + } + else + { + int bias_reg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, + {Operand::VReg(bias_reg, VRegClass::Int), + Operand::Imm(bias)}).SetRematerializable(true).SetRematImm(bias); + block.Append(Opcode::AddRR, + {Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::VReg(bias_reg, VRegClass::Int)}); + } + int shift = 0; + int tmp = val; + while (tmp > 1) + { + tmp >>= 1; + ++shift; + } + block.Append(Opcode::CmpImm, + {Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(0)}); + int selected = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::Csel, + {Operand::VReg(selected, VRegClass::Int), + Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(static_cast(CondCode::LT))}); + int q_dst = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::AsrRR, + {Operand::VReg(q_dst, VRegClass::Int), + Operand::VReg(selected, VRegClass::Int), + Operand::Imm(shift)}); + int d_reg = function.CreateVReg(VRegClass::Int); block.Append(Opcode::MovImm, + {Operand::VReg(d_reg, VRegClass::Int), + Operand::Imm(val)}).SetRematerializable(true).SetRematImm(val); + block.Append(Opcode::Msub, {Operand::VReg(dst, VRegClass::Int), - Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + Operand::VReg(q_dst, VRegClass::Int), + Operand::VReg(d_reg, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int)}); + value_vregs[value] = dst; + return dst; + } + if (val < 0 && (-val & (-val - 1)) == 0 && val != -1) + { + int abs_val = -val; + int bias = abs_val - 1; + int biased = function.CreateVReg(VRegClass::Int); + if (bias <= 4095) + { + block.Append(Opcode::AddRR, + {Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(bias)}); + } + else + { + int bias_reg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, + {Operand::VReg(bias_reg, VRegClass::Int), + Operand::Imm(bias)}).SetRematerializable(true).SetRematImm(bias); + block.Append(Opcode::AddRR, + {Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::VReg(bias_reg, VRegClass::Int)}); + } + int shift = 0; + int tmp = abs_val; + while (tmp > 1) + { + tmp >>= 1; + ++shift; + } + block.Append(Opcode::CmpImm, + {Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(0)}); + int selected = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::Csel, + {Operand::VReg(selected, VRegClass::Int), + Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(static_cast(CondCode::LT))}); + int asr_result = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::AsrRR, + {Operand::VReg(asr_result, VRegClass::Int), + Operand::VReg(selected, VRegClass::Int), + Operand::Imm(shift)}); + int q_dst = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::NegRR, + {Operand::VReg(q_dst, VRegClass::Int), + Operand::VReg(asr_result, VRegClass::Int)}); + int d_reg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, + {Operand::VReg(d_reg, VRegClass::Int), + Operand::Imm(val)}).SetRematerializable(true).SetRematImm(val); + block.Append(Opcode::Msub, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(q_dst, VRegClass::Int), + Operand::VReg(d_reg, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int)}); value_vregs[value] = dst; return dst; } - // 2的幂次取模(含正负)改用 ModRR(sdiv+msub),比移位序列更短 } } - // Add/Sub 常量折叠到立即数操作码 - int rhs_imm_val; - bool rhs_is_imm = false; - if ((opcode == Opcode::AddRR || opcode == Opcode::SubRR) && - bin->GetRhs() && TryGetConstantInt(bin->GetRhs(), rhs_imm_val) && - rhs_imm_val >= 0 && rhs_imm_val <= 4095) - { - rhs_is_imm = true; - if (opcode == Opcode::AddRR) - opcode = Opcode::AddImm; - else - opcode = Opcode::SubImm; - block.Append(opcode, - {Operand::VReg(dst, VRegClass::Int), - Operand::VReg(lhs, VRegClass::Int), - Operand::Imm(rhs_imm_val)}); - } - else - { - block.Append(opcode, - {Operand::VReg(dst, VRegClass::Int), - Operand::VReg(lhs, VRegClass::Int), - Operand::VReg(rhs, VRegClass::Int)}); - } + block.Append(opcode, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::VReg(rhs, VRegClass::Int)}); value_vregs[value] = dst; return dst; } @@ -824,35 +958,12 @@ namespace mir return; } - // 常量折叠到 CmpImm - int lhs_imm, rhs_imm; - bool lhs_const = TryGetConstantInt(bin.GetLhs(), lhs_imm); - bool rhs_const = TryGetConstantInt(bin.GetRhs(), rhs_imm); - auto imm_fits = [](int imm) { return imm >= 0 && imm <= 4095; }; - - if (rhs_const && imm_fits(rhs_imm)) - { - int lhs = EmitIntValue(bin.GetLhs(), function, value_vregs, - scalar_slots, array_slots, block); - block.Append(Opcode::CmpImm, - {Operand::VReg(lhs, VRegClass::Int), Operand::Imm(rhs_imm)}); - } - else if (lhs_const && imm_fits(lhs_imm)) - { - int rhs = EmitIntValue(bin.GetRhs(), function, value_vregs, - scalar_slots, array_slots, block); - block.Append(Opcode::CmpImm, - {Operand::VReg(rhs, VRegClass::Int), Operand::Imm(lhs_imm)}); - } - else - { - int lhs = EmitIntValue(bin.GetLhs(), function, value_vregs, - scalar_slots, array_slots, block); - int rhs = EmitIntValue(bin.GetRhs(), function, value_vregs, - scalar_slots, array_slots, block); - block.Append(Opcode::CmpRR, - {Operand::VReg(lhs, VRegClass::Int), Operand::VReg(rhs, VRegClass::Int)}); - } + int lhs = EmitIntValue(bin.GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int rhs = EmitIntValue(bin.GetRhs(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::CmpRR, + {Operand::VReg(lhs, VRegClass::Int), Operand::VReg(rhs, VRegClass::Int)}); } static bool TryEmitCondValueToFlags(const ir::Value *value, @@ -1565,7 +1676,6 @@ namespace mir } block.Append(Opcode::Call, {Operand::Symbol(callee->GetName())}); - function.SetHasCall(); if (aligned_stack_arg_bytes > 0) { diff --git a/src/mir/Lowering.cpp.orig b/src/mir/Lowering.cpp.orig new file mode 100644 index 00000000..29bb154a --- /dev/null +++ b/src/mir/Lowering.cpp.orig @@ -0,0 +1,1811 @@ +#include "mir/MIR.h" + +#include +#include +#include +#include + +#include "ir/IR.h" +#include "utils/Log.h" + +namespace mir +{ + namespace + { + + using ValueVRegMap = std::unordered_map; + using LocalScalarMap = std::unordered_map; + using LocalArrayMap = std::unordered_map; + using BlockMap = std::unordered_map; + + static bool TryGetConstantInt(const ir::Value *value, int &out); + + static int GetTypeSize(const std::shared_ptr &type) + { + if (!type) + return 4; + if (type->IsPtrInt32() || type->IsPtrFloat32()) + return 8; + return 4; + } + + static int AlignTo(int value, int align) + { + return ((value + align - 1) / align) * align; + } + + static bool IsPointerValue(const ir::Value *value) + { + if (!value) + return false; + auto type = value->GetType(); + return type && (type->IsPtrInt32() || type->IsPtrFloat32()); + } + + static bool IsPointerType(const std::shared_ptr &type) + { + return type && (type->IsPtrInt32() || type->IsPtrFloat32()); + } + + static bool IsFloatType(const std::shared_ptr &type) + { + return type && type->IsFloat32(); + } + + static bool IsFloatValue(const ir::Value *value) + { + return value && IsFloatType(value->GetType()); + } + + static bool IsIntegerCompareOpcode(ir::Opcode opcode) + { + switch (opcode) + { + case ir::Opcode::Eq: + case ir::Opcode::Ne: + case ir::Opcode::Lt: + case ir::Opcode::Le: + case ir::Opcode::Gt: + case ir::Opcode::Ge: + return true; + default: + return false; + } + } + + static CondCode GetCondCodeForCompareOpcode(ir::Opcode opcode) + { + switch (opcode) + { + case ir::Opcode::Eq: + return CondCode::EQ; + case ir::Opcode::Ne: + return CondCode::NE; + case ir::Opcode::Lt: + return CondCode::LT; + case ir::Opcode::Le: + return CondCode::LE; + case ir::Opcode::Gt: + return CondCode::GT; + case ir::Opcode::Ge: + return CondCode::GE; + default: + throw std::runtime_error(FormatError("mir", "不支持的比较 opcode")); + } + } + + static CondCode NegateCondCode(CondCode cond) + { + switch (cond) + { + case CondCode::EQ: + return CondCode::NE; + case CondCode::NE: + return CondCode::EQ; + case CondCode::LT: + return CondCode::GE; + case CondCode::LE: + return CondCode::GT; + case CondCode::GT: + return CondCode::LE; + case CondCode::GE: + return CondCode::LT; + default: + return CondCode::NE; + } + } + + // 交换比较操作数时反转条件码(aa) + static CondCode SwapCondCode(CondCode cond) + { + switch (cond) + { + case CondCode::LT: + return CondCode::GT; + case CondCode::LE: + return CondCode::GE; + case CondCode::GT: + return CondCode::LT; + case CondCode::GE: + return CondCode::LE; + default: + return cond; // EQ/NE 对称 + } + } + + static PhysReg GetArgWReg(size_t index) + { + static const PhysReg regs[] = { + PhysReg::W0, PhysReg::W1, PhysReg::W2, PhysReg::W3, + PhysReg::W4, PhysReg::W5, PhysReg::W6, PhysReg::W7}; + return index < 8 ? regs[index] : PhysReg::W0; + } + + static PhysReg GetArgXReg(size_t index) + { + static const PhysReg regs[] = { + PhysReg::X0, PhysReg::X1, PhysReg::X2, PhysReg::X3, + PhysReg::X4, PhysReg::X5, PhysReg::X6, PhysReg::X7}; + return index < 8 ? regs[index] : PhysReg::X0; + } + + static PhysReg GetArgSReg(size_t index) + { + static const PhysReg regs[] = { + PhysReg::S0, PhysReg::S1, PhysReg::S2, PhysReg::S3, + PhysReg::S4, PhysReg::S5, PhysReg::S6, PhysReg::S7}; + return index < 8 ? regs[index] : PhysReg::S0; + } + + static bool TryGetConstantInt(const ir::Value *value, int &out) + { + if (auto *constant = dynamic_cast(value)) + { + out = constant->GetValue(); + return true; + } + return false; + } + + static int FloatToBits(float value) + { + int bits = 0; + std::memcpy(&bits, &value, sizeof(bits)); + return bits; + } + + static bool TryGetConstantFloatBits(const ir::Value *value, int &out) + { + if (auto *constant = dynamic_cast(value)) + { + out = FloatToBits(static_cast(constant->GetValue())); + return true; + } + return false; + } + + static const ir::GlobalVariable *AsGlobalScalarObject(const ir::Value *value) + { + auto *global = dynamic_cast(value); + if (!global) + return nullptr; + if (global->IsArray()) + return nullptr; + if (!IsPointerValue(global)) + return nullptr; + return global; + } + + static const ir::GlobalVariable *AsGlobalArrayObject(const ir::Value *value) + { + auto *global = dynamic_cast(value); + if (!global) + return nullptr; + if (!global->IsArray()) + return nullptr; + if (!IsPointerValue(global)) + return nullptr; + return global; + } + + static bool IsZeroIntConstant(const ir::Value *value) + { + int imm = 0; + return TryGetConstantInt(value, imm) && imm == 0; + } + + [[maybe_unused]] static bool IsSolelyConsumedByCondBr(const ir::Instruction &inst) + { + const auto &uses = inst.GetUses(); + if (uses.size() != 1) + return false; + auto *user = uses.front().GetUser(); + return dynamic_cast(user) != nullptr; + } + + static bool IsSolelyConsumedByCanonicalBoolUse(const ir::Instruction &inst) + { + const auto &uses = inst.GetUses(); + if (uses.size() != 1) + return false; + auto *user_inst = dynamic_cast(uses.front().GetUser()); + if (!user_inst) + return false; + if (dynamic_cast(user_inst)) + return true; + if (auto *cast = dynamic_cast(user_inst)) + return cast->GetOpcode() == ir::Opcode::ZExt; + auto *bin = dynamic_cast(user_inst); + if (!bin) + return false; + if (bin->GetOpcode() != ir::Opcode::Eq && bin->GetOpcode() != ir::Opcode::Ne) + return false; + return (bin->GetLhs() == &inst && IsZeroIntConstant(bin->GetRhs())) || + (bin->GetRhs() == &inst && IsZeroIntConstant(bin->GetLhs())); + } + + static bool TryResolveDirectScalarSlot(const ir::Value *ptr, + const LocalScalarMap &scalar_slots, + int &out_slot) + { + auto it = scalar_slots.find(ptr); + if (it != scalar_slots.end()) + { + out_slot = it->second; + return true; + } + auto *gep = dynamic_cast(ptr); + if (!gep) + return false; + int idx = 0; + if (!TryGetConstantInt(gep->GetIndex(), idx)) + return false; + if (idx != 0) + return false; + auto base_it = scalar_slots.find(gep->GetBasePtr()); + if (base_it != scalar_slots.end()) + { + out_slot = base_it->second; + return true; + } + return false; + } + + static int EmitIntValue(const ir::Value *value, MachineFunction &function, + ValueVRegMap &value_vregs, const LocalScalarMap &scalar_slots, + const LocalArrayMap &array_slots, MachineBasicBlock &block); + + static int EmitFloatValue(const ir::Value *value, MachineFunction &function, + ValueVRegMap &value_vregs, MachineBasicBlock &block); + + static int EmitPtrValue(const ir::Value *value, MachineFunction &function, + ValueVRegMap &value_vregs, const LocalScalarMap &scalar_slots, + const LocalArrayMap &array_slots, MachineBasicBlock &block); + + static int EmitIntValue(const ir::Value *value, MachineFunction &function, + ValueVRegMap &value_vregs, const LocalScalarMap &scalar_slots, + const LocalArrayMap &array_slots, MachineBasicBlock &block) + { + if (!value) + { + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Int), Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + return vreg; + } + + auto it = value_vregs.find(value); + if (it != value_vregs.end()) + { + if (function.GetVRegClass(it->second) == VRegClass::Float) + { + int dst = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::FCvtzs, + {Operand::VReg(dst, VRegClass::Int), Operand::VReg(it->second, VRegClass::Float)}); + return dst; + } + return it->second; + } + + int imm = 0; + if (TryGetConstantInt(value, imm)) + { + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Int), Operand::Imm(imm)}).SetRematerializable(true).SetRematImm(imm); + value_vregs[value] = vreg; + return vreg; + } + + if (TryGetConstantFloatBits(value, imm)) + { + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Int), Operand::Imm(imm)}).SetRematerializable(true).SetRematImm(imm); + value_vregs[value] = vreg; + return vreg; + } + + if (auto *global = AsGlobalScalarObject(value)) + { + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::LoadGlobal, + {Operand::VReg(vreg, VRegClass::Int), Operand::Symbol(global->GetName())}); + value_vregs[value] = vreg; + return vreg; + } + + if (auto *cast = dynamic_cast(value)) + { + if (cast->GetOpcode() == ir::Opcode::ZExt) + { + int src = EmitIntValue(cast->GetOperandValue(), function, value_vregs, + scalar_slots, array_slots, block); + value_vregs[value] = src; + return src; + } + if (cast->GetOpcode() == ir::Opcode::FPToSI) + { + int src = EmitFloatValue(cast->GetOperandValue(), function, value_vregs, block); + int dst = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::FCvtzs, + {Operand::VReg(dst, VRegClass::Int), Operand::VReg(src, VRegClass::Float)}); + value_vregs[value] = dst; + return dst; + } + } + + if (auto *bin = dynamic_cast(value)) + { + if (IsIntegerCompareOpcode(bin->GetOpcode())) + { + // 常量折叠到 CmpImm,消除冗余 MovImm + int lhs_imm, rhs_imm; + bool lhs_const = TryGetConstantInt(bin->GetLhs(), lhs_imm); + bool rhs_const = TryGetConstantInt(bin->GetRhs(), rhs_imm); + auto imm_fits = [](int imm) { return imm >= 0 && imm <= 4095; }; + + CondCode cond = GetCondCodeForCompareOpcode(bin->GetOpcode()); + + if (rhs_const && imm_fits(rhs_imm)) + { + int lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::CmpImm, + {Operand::VReg(lhs, VRegClass::Int), Operand::Imm(rhs_imm)}); + } + else if (lhs_const && imm_fits(lhs_imm)) + { + int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::CmpImm, + {Operand::VReg(rhs, VRegClass::Int), Operand::Imm(lhs_imm)}); + cond = SwapCondCode(cond); + } + else + { + int lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::CmpRR, + {Operand::VReg(lhs, VRegClass::Int), Operand::VReg(rhs, VRegClass::Int)}); + } + + int dst = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::CSet, + {Operand::VReg(dst, VRegClass::Int), + Operand::Imm(static_cast(cond))}); + value_vregs[value] = dst; + return dst; + } + + if (IsFloatType(bin->GetType())) + { + return EmitFloatValue(value, function, value_vregs, block); + } + + Opcode opcode = Opcode::AddRR; + switch (bin->GetOpcode()) + { + case ir::Opcode::Add: + opcode = Opcode::AddRR; + break; + case ir::Opcode::Sub: + opcode = Opcode::SubRR; + break; + case ir::Opcode::Mul: + opcode = Opcode::MulRR; + break; + case ir::Opcode::Div: + opcode = Opcode::DivRR; + break; + case ir::Opcode::Mod: + opcode = Opcode::ModRR; + break; + default: + break; + } + + int lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs, + scalar_slots, array_slots, block); + int dst = function.CreateVReg(VRegClass::Int); + + if (opcode == Opcode::MulRR) + { + auto *rhs_const = dynamic_cast(bin->GetRhs()); + if (rhs_const) + { + int val = rhs_const->GetValue(); + if (val > 0 && (val & (val - 1)) == 0) + { + int shift = 0; + while (val > 1) + { + val >>= 1; + ++shift; + } + block.Append(Opcode::ShlRR, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(shift)}); + value_vregs[value] = dst; + return dst; + } + } + } + + if (opcode == Opcode::DivRR) + { + auto *rhs_const = dynamic_cast(bin->GetRhs()); + if (rhs_const) + { + int val = rhs_const->GetValue(); + if (val == 1) + { + value_vregs[value] = lhs; + return lhs; + } + if (val == -1) + { + block.Append(Opcode::NegRR, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int)}); + value_vregs[value] = dst; + return dst; + } + // 2的幂次除法(含正负)改用 sdiv,比移位序列更短 + } + } + + if (opcode == Opcode::ModRR) + { + auto *rhs_const = dynamic_cast(bin->GetRhs()); + if (rhs_const) + { + int val = rhs_const->GetValue(); + // x % 1 == 0, x % -1 == 0 + if (val == 1 || val == -1) + { + block.Append(Opcode::MovImm, + {Operand::VReg(dst, VRegClass::Int), + Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + value_vregs[value] = dst; + return dst; + } + // 2的幂次取模(含正负)改用 ModRR(sdiv+msub),比移位序列更短 + } + } + + // Add/Sub 常量折叠到立即数操作码 + int rhs_imm_val; + bool rhs_is_imm = false; + if ((opcode == Opcode::AddRR || opcode == Opcode::SubRR) && + bin->GetRhs() && TryGetConstantInt(bin->GetRhs(), rhs_imm_val) && + rhs_imm_val >= 0 && rhs_imm_val <= 4095) + { + rhs_is_imm = true; + if (opcode == Opcode::AddRR) + opcode = Opcode::AddImm; + else + opcode = Opcode::SubImm; + block.Append(opcode, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(rhs_imm_val)}); + } + else + { + block.Append(opcode, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::VReg(rhs, VRegClass::Int)}); + } + value_vregs[value] = dst; + return dst; + } + + if (auto *phi = dynamic_cast(value)) + { + auto phi_it = value_vregs.find(value); + if (phi_it != value_vregs.end()) + return phi_it->second; + } + + if (auto *load = dynamic_cast(value)) + { + int scalar_slot = -1; + if (TryResolveDirectScalarSlot(load->GetPtr(), scalar_slots, scalar_slot)) + { + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::LoadStack, + {Operand::VReg(vreg, VRegClass::Int), Operand::FrameIndex(scalar_slot)}); + value_vregs[value] = vreg; + return vreg; + } + + if (auto *global = AsGlobalScalarObject(load->GetPtr())) + { + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::LoadGlobal, + {Operand::VReg(vreg, VRegClass::Int), Operand::Symbol(global->GetName())}); + value_vregs[value] = vreg; + return vreg; + } + + int addr = EmitPtrValue(load->GetPtr(), function, value_vregs, + scalar_slots, array_slots, block); + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::LoadMem, + {Operand::VReg(vreg, VRegClass::Int), Operand::VReg(addr, VRegClass::Ptr)}); + value_vregs[value] = vreg; + return vreg; + } + + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Int), Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + value_vregs[value] = vreg; + return vreg; + } + + static int EmitFloatValue(const ir::Value *value, MachineFunction &function, + ValueVRegMap &value_vregs, MachineBasicBlock &block) + { + if (!value) + { + int vreg = function.CreateVReg(VRegClass::Float); + int wvreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, {Operand::VReg(wvreg, VRegClass::Int), Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + block.Append(Opcode::FMovWS, + {Operand::VReg(vreg, VRegClass::Float), Operand::VReg(wvreg, VRegClass::Int)}); + return vreg; + } + + auto it = value_vregs.find(value); + if (it != value_vregs.end()) + { + if (function.GetVRegClass(it->second) != VRegClass::Float) + { + int dst = function.CreateVReg(VRegClass::Float); + block.Append(Opcode::Scvtf, + {Operand::VReg(dst, VRegClass::Float), Operand::VReg(it->second, VRegClass::Int)}); + return dst; + } + return it->second; + } + + int bits = 0; + if (TryGetConstantFloatBits(value, bits)) + { + int wvreg = function.CreateVReg(VRegClass::Int); + int vreg = function.CreateVReg(VRegClass::Float); + block.Append(Opcode::MovImm, {Operand::VReg(wvreg, VRegClass::Int), Operand::Imm(bits)}).SetRematerializable(true).SetRematImm(bits); + block.Append(Opcode::FMovWS, + {Operand::VReg(vreg, VRegClass::Float), Operand::VReg(wvreg, VRegClass::Int)}); + value_vregs[value] = vreg; + return vreg; + } + + if (auto *global = AsGlobalScalarObject(value)) + { + int vreg = function.CreateVReg(VRegClass::Float); + block.Append(Opcode::LoadGlobal, + {Operand::VReg(vreg, VRegClass::Float), Operand::Symbol(global->GetName())}); + value_vregs[value] = vreg; + return vreg; + } + + if (auto *cast = dynamic_cast(value)) + { + if (cast->GetOpcode() == ir::Opcode::SIToFP) + { + int src = EmitIntValue(cast->GetOperandValue(), function, value_vregs, + LocalScalarMap(), LocalArrayMap(), block); + int dst = function.CreateVReg(VRegClass::Float); + block.Append(Opcode::Scvtf, + {Operand::VReg(dst, VRegClass::Float), Operand::VReg(src, VRegClass::Int)}); + value_vregs[value] = dst; + return dst; + } + } + + if (auto *bin = dynamic_cast(value)) + { + if (IsFloatType(bin->GetType())) + { + Opcode opcode = Opcode::FAddRR; + switch (bin->GetOpcode()) + { + case ir::Opcode::Add: + opcode = Opcode::FAddRR; + break; + case ir::Opcode::Sub: + opcode = Opcode::FSubRR; + break; + case ir::Opcode::Mul: + opcode = Opcode::FMulRR; + break; + case ir::Opcode::Div: + opcode = Opcode::FDivRR; + break; + default: + break; + } + + int lhs = EmitFloatValue(bin->GetLhs(), function, value_vregs, block); + int rhs = EmitFloatValue(bin->GetRhs(), function, value_vregs, block); + int dst = function.CreateVReg(VRegClass::Float); + block.Append(opcode, + {Operand::VReg(dst, VRegClass::Float), + Operand::VReg(lhs, VRegClass::Float), + Operand::VReg(rhs, VRegClass::Float)}); + value_vregs[value] = dst; + return dst; + } + } + + if (auto *phi = dynamic_cast(value)) + { + auto phi_it = value_vregs.find(value); + if (phi_it != value_vregs.end()) + return phi_it->second; + } + + if (auto *load = dynamic_cast(value)) + { + int scalar_slot = -1; + LocalScalarMap dummy_scalar; + if (TryResolveDirectScalarSlot(load->GetPtr(), dummy_scalar, scalar_slot)) + { + int vreg = function.CreateVReg(VRegClass::Float); + block.Append(Opcode::LoadStack, + {Operand::VReg(vreg, VRegClass::Float), Operand::FrameIndex(scalar_slot)}); + value_vregs[value] = vreg; + return vreg; + } + } + + int vreg = function.CreateVReg(VRegClass::Float); + int wvreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, {Operand::VReg(wvreg, VRegClass::Int), Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + block.Append(Opcode::FMovWS, + {Operand::VReg(vreg, VRegClass::Float), Operand::VReg(wvreg, VRegClass::Int)}); + value_vregs[value] = vreg; + return vreg; + } + + static int EmitPtrValue(const ir::Value *value, MachineFunction &function, + ValueVRegMap &value_vregs, const LocalScalarMap &scalar_slots, + const LocalArrayMap &array_slots, MachineBasicBlock &block) + { + auto it = value_vregs.find(value); + if (it != value_vregs.end()) + return it->second; + + if (auto *global_scalar = AsGlobalScalarObject(value)) + { + int vreg = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::LoadGlobalAddr, + {Operand::VReg(vreg, VRegClass::Ptr), Operand::Symbol(global_scalar->GetName())}); + value_vregs[value] = vreg; + return vreg; + } + + if (auto *global_array = AsGlobalArrayObject(value)) + { + int vreg = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::LoadGlobalAddr, + {Operand::VReg(vreg, VRegClass::Ptr), Operand::Symbol(global_array->GetName())}); + value_vregs[value] = vreg; + return vreg; + } + + auto scalar_it = scalar_slots.find(value); + if (scalar_it != scalar_slots.end()) + { + int vreg = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::LoadStackAddr, + {Operand::VReg(vreg, VRegClass::Ptr), Operand::FrameIndex(scalar_it->second)}); + value_vregs[value] = vreg; + return vreg; + } + + auto array_it = array_slots.find(value); + if (array_it != array_slots.end()) + { + int vreg = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::LoadStackAddr, + {Operand::VReg(vreg, VRegClass::Ptr), Operand::FrameIndex(array_it->second)}); + value_vregs[value] = vreg; + return vreg; + } + + auto *gep = dynamic_cast(value); + if (gep) + { + int base = EmitPtrValue(gep->GetBasePtr(), function, value_vregs, + scalar_slots, array_slots, block); + + int idx_imm = 0; + if (TryGetConstantInt(gep->GetIndex(), idx_imm)) + { + const int byte_offset = static_cast(static_cast(idx_imm) * 4u); + if (byte_offset == 0) + { + value_vregs[value] = base; + return base; + } + int dst = function.CreateVReg(VRegClass::Ptr); + int offset_vreg = function.CreateVReg(VRegClass::Ptr); + int abs_off = byte_offset > 0 ? byte_offset : -byte_offset; + block.Append(Opcode::MovImm, {Operand::VReg(offset_vreg, VRegClass::Ptr), Operand::Imm(abs_off)}).SetRematerializable(true).SetRematImm(abs_off); + if (byte_offset > 0) + { + block.Append(Opcode::AddRR, + {Operand::VReg(dst, VRegClass::Ptr), + Operand::VReg(base, VRegClass::Ptr), + Operand::VReg(offset_vreg, VRegClass::Ptr)}); + } + else + { + block.Append(Opcode::SubRR, + {Operand::VReg(dst, VRegClass::Ptr), + Operand::VReg(base, VRegClass::Ptr), + Operand::VReg(offset_vreg, VRegClass::Ptr)}); + } + value_vregs[value] = dst; + return dst; + } + + int idx = EmitIntValue(gep->GetIndex(), function, value_vregs, + scalar_slots, array_slots, block); + int sext = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::Sxtw, + {Operand::VReg(sext, VRegClass::Ptr), Operand::VReg(idx, VRegClass::Int)}); + int shifted = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::ShlRR, + {Operand::VReg(shifted, VRegClass::Ptr), + Operand::VReg(sext, VRegClass::Ptr), + Operand::Imm(2)}); + int dst = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::AddRR, + {Operand::VReg(dst, VRegClass::Ptr), + Operand::VReg(base, VRegClass::Ptr), + Operand::VReg(shifted, VRegClass::Ptr)}); + value_vregs[value] = dst; + return dst; + } + + if (IsPointerValue(value)) + { + auto vreg_it = value_vregs.find(value); + if (vreg_it != value_vregs.end()) + return vreg_it->second; + } + + int vreg = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Ptr), Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + value_vregs[value] = vreg; + return vreg; + } + + static void EmitCompareToFlags(const ir::BinaryInst &bin, + MachineFunction &function, + ValueVRegMap &value_vregs, + const LocalScalarMap &scalar_slots, + const LocalArrayMap &array_slots, + MachineBasicBlock &block) + { + if (IsFloatValue(bin.GetLhs()) || IsFloatValue(bin.GetRhs())) + { + int lhs = EmitFloatValue(bin.GetLhs(), function, value_vregs, block); + int rhs = EmitFloatValue(bin.GetRhs(), function, value_vregs, block); + block.Append(Opcode::FCmpRR, + {Operand::VReg(lhs, VRegClass::Float), Operand::VReg(rhs, VRegClass::Float)}); + return; + } + + // 常量折叠到 CmpImm + int lhs_imm, rhs_imm; + bool lhs_const = TryGetConstantInt(bin.GetLhs(), lhs_imm); + bool rhs_const = TryGetConstantInt(bin.GetRhs(), rhs_imm); + auto imm_fits = [](int imm) { return imm >= 0 && imm <= 4095; }; + + if (rhs_const && imm_fits(rhs_imm)) + { + int lhs = EmitIntValue(bin.GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::CmpImm, + {Operand::VReg(lhs, VRegClass::Int), Operand::Imm(rhs_imm)}); + } + else if (lhs_const && imm_fits(lhs_imm)) + { + int rhs = EmitIntValue(bin.GetRhs(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::CmpImm, + {Operand::VReg(rhs, VRegClass::Int), Operand::Imm(lhs_imm)}); + } + else + { + int lhs = EmitIntValue(bin.GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int rhs = EmitIntValue(bin.GetRhs(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::CmpRR, + {Operand::VReg(lhs, VRegClass::Int), Operand::VReg(rhs, VRegClass::Int)}); + } + } + + static bool TryEmitCondValueToFlags(const ir::Value *value, + MachineFunction &function, + ValueVRegMap &value_vregs, + const LocalScalarMap &scalar_slots, + const LocalArrayMap &array_slots, + MachineBasicBlock &block, + CondCode &true_cond, int depth = 0) + { + if (!value || depth > 8) + return false; + + if (auto *cast = dynamic_cast(value)) + { + if (cast->GetOpcode() == ir::Opcode::ZExt) + { + return TryEmitCondValueToFlags(cast->GetOperandValue(), + function, value_vregs, scalar_slots, array_slots, + block, true_cond, depth + 1); + } + } + + if (auto *bin = dynamic_cast(value)) + { + if (IsIntegerCompareOpcode(bin->GetOpcode())) + { + if (bin->GetOpcode() == ir::Opcode::Eq || bin->GetOpcode() == ir::Opcode::Ne) + { + const ir::Value *inner = nullptr; + if (IsZeroIntConstant(bin->GetLhs())) + inner = bin->GetRhs(); + else if (IsZeroIntConstant(bin->GetRhs())) + inner = bin->GetLhs(); + + if (inner) + { + CondCode inner_cond = CondCode::NE; + if (TryEmitCondValueToFlags(inner, function, value_vregs, + scalar_slots, array_slots, + block, inner_cond, depth + 1)) + { + true_cond = (bin->GetOpcode() == ir::Opcode::Eq) + ? NegateCondCode(inner_cond) + : inner_cond; + return true; + } + } + } + + EmitCompareToFlags(*bin, function, value_vregs, scalar_slots, array_slots, block); + true_cond = GetCondCodeForCompareOpcode(bin->GetOpcode()); + return true; + } + } + + if (IsFloatValue(value)) + { + int vreg = EmitFloatValue(value, function, value_vregs, block); + int zero_w = function.CreateVReg(VRegClass::Int); + int zero_s = function.CreateVReg(VRegClass::Float); + block.Append(Opcode::MovImm, {Operand::VReg(zero_w, VRegClass::Int), Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + block.Append(Opcode::FMovWS, + {Operand::VReg(zero_s, VRegClass::Float), Operand::VReg(zero_w, VRegClass::Int)}); + block.Append(Opcode::FCmpRR, + {Operand::VReg(vreg, VRegClass::Float), Operand::VReg(zero_s, VRegClass::Float)}); + true_cond = CondCode::NE; + return true; + } + + int vreg = EmitIntValue(value, function, value_vregs, scalar_slots, array_slots, block); + int zero = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, {Operand::VReg(zero, VRegClass::Int), Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + block.Append(Opcode::CmpRR, + {Operand::VReg(vreg, VRegClass::Int), Operand::VReg(zero, VRegClass::Int)}); + true_cond = CondCode::NE; + return true; + } + + static void EmitStackPointerAdjust(MachineBasicBlock &block, Opcode opcode, int amount) + { + if (amount <= 0) + return; + block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X14), Operand::Imm(amount)}); + block.Append(opcode, {Operand::Reg(PhysReg::SP), Operand::Reg(PhysReg::SP), Operand::Reg(PhysReg::X14)}); + } + + static int ComputeStackArgumentBytes(const ir::CallInst &call) + { + int total = 0; + size_t gp_idx = 0; + size_t fp_idx = 0; + for (size_t i = 0; i < call.GetNumArgs(); ++i) + { + auto *arg = call.GetArg(i); + auto type = arg ? arg->GetType() : nullptr; + if (IsFloatType(type)) + { + if (fp_idx < 8) + ++fp_idx; + else + total += 8; + } + else + { + if (gp_idx < 8) + ++gp_idx; + else + total += 8; + } + } + return total; + } + + static void LowerFunctionParams(const ir::Function &function, + MachineFunction &machine_func, + ValueVRegMap &value_vregs) + { + if (!machine_func.GetEntryPtr()) + return; + + auto &entry = machine_func.GetEntry(); + const auto ¶ms = function.GetParams(); + size_t gp_idx = 0; + size_t fp_idx = 0; + int callee_stack_offset = 0; + + for (const auto ¶m : params) + { + if (!param) + continue; + + if (IsFloatType(param->GetType())) + { + if (fp_idx < 8) + { + int vreg = machine_func.CreateVReg(VRegClass::Float); + entry.Append(Opcode::MovReg, + {Operand::VReg(vreg, VRegClass::Float), + Operand::Reg(GetArgSReg(fp_idx))}); + value_vregs[param.get()] = vreg; + ++fp_idx; + } + else + { + const int arg_slot = machine_func.CreateCalleeStackArgFrameIndex(GetTypeSize(param->GetType())); + machine_func.GetFrameSlot(arg_slot).offset = callee_stack_offset; + callee_stack_offset += 8; + int vreg = machine_func.CreateVReg(VRegClass::Float); + entry.Append(Opcode::LoadStack, + {Operand::VReg(vreg, VRegClass::Float), Operand::FrameIndex(arg_slot)}); + value_vregs[param.get()] = vreg; + } + } + else if (IsPointerType(param->GetType())) + { + if (gp_idx < 8) + { + int vreg = machine_func.CreateVReg(VRegClass::Ptr); + entry.Append(Opcode::MovReg, + {Operand::VReg(vreg, VRegClass::Ptr), + Operand::Reg(GetArgXReg(gp_idx))}); + value_vregs[param.get()] = vreg; + ++gp_idx; + } + else + { + const int arg_slot = machine_func.CreateCalleeStackArgFrameIndex(GetTypeSize(param->GetType())); + machine_func.GetFrameSlot(arg_slot).offset = callee_stack_offset; + callee_stack_offset += 8; + int vreg = machine_func.CreateVReg(VRegClass::Ptr); + entry.Append(Opcode::LoadStack, + {Operand::VReg(vreg, VRegClass::Ptr), Operand::FrameIndex(arg_slot)}); + value_vregs[param.get()] = vreg; + } + } + else + { + if (gp_idx < 8) + { + int vreg = machine_func.CreateVReg(VRegClass::Int); + entry.Append(Opcode::MovReg, + {Operand::VReg(vreg, VRegClass::Int), + Operand::Reg(GetArgWReg(gp_idx))}); + value_vregs[param.get()] = vreg; + ++gp_idx; + } + else + { + const int arg_slot = machine_func.CreateCalleeStackArgFrameIndex(GetTypeSize(param->GetType())); + machine_func.GetFrameSlot(arg_slot).offset = callee_stack_offset; + callee_stack_offset += 8; + int vreg = machine_func.CreateVReg(VRegClass::Int); + entry.Append(Opcode::LoadStack, + {Operand::VReg(vreg, VRegClass::Int), Operand::FrameIndex(arg_slot)}); + value_vregs[param.get()] = vreg; + } + } + } + } + + static void EmitPhiValueStores(const ir::BasicBlock *src_bb, + const ir::BasicBlock *dst_bb, + MachineFunction &function, + const ValueVRegMap &value_vregs, + const LocalScalarMap &scalar_slots, + const LocalArrayMap &array_slots, + MachineBasicBlock &mir_block) + { + if (!src_bb || !dst_bb) + return; + + for (const auto &inst_ptr : dst_bb->GetInstructions()) + { + auto *phi = dynamic_cast(inst_ptr.get()); + if (!phi) + break; + + auto phi_it = value_vregs.find(phi); + if (phi_it == value_vregs.end()) + continue; + int phi_vreg = phi_it->second; + + const ir::Value *incoming_value = nullptr; + size_t num_ops = phi->GetNumOperands(); + for (size_t i = 0; i + 1 < num_ops; i += 2) + { + auto *val = phi->GetOperand(i); + auto *bb_ptr = dynamic_cast(phi->GetOperand(i + 1)); + if (bb_ptr && bb_ptr == src_bb) + { + incoming_value = val; + break; + } + } + + if (!incoming_value) + continue; + + VRegClass phi_class = function.GetVRegClass(phi_vreg); + + if (phi_class == VRegClass::Float) + { + int src = EmitFloatValue(incoming_value, function, + const_cast(value_vregs), mir_block); + if (phi_vreg == src) { + // self-referencing PHI, skip + } else { + mir_block.Append(Opcode::MovReg, + {Operand::VReg(phi_vreg, VRegClass::Float), + Operand::VReg(src, VRegClass::Float)}); + } + } + else if (phi_class == VRegClass::Ptr) + { + int src = EmitPtrValue(incoming_value, function, + const_cast(value_vregs), + scalar_slots, array_slots, mir_block); + if (phi_vreg == src) { + // self-referencing PHI, skip + } else { + mir_block.Append(Opcode::MovReg, + {Operand::VReg(phi_vreg, VRegClass::Ptr), + Operand::VReg(src, VRegClass::Ptr)}); + } + } + else + { + int src = EmitIntValue(incoming_value, function, + const_cast(value_vregs), + scalar_slots, array_slots, mir_block); + if (phi_vreg == src) { + // self-referencing PHI, skip + } else { + mir_block.Append(Opcode::MovReg, + {Operand::VReg(phi_vreg, VRegClass::Int), + Operand::VReg(src, VRegClass::Int)}); + } + } + } + } + + static void LowerInstruction(const ir::Instruction &inst, + MachineFunction &function, + ValueVRegMap &value_vregs, + LocalScalarMap &scalar_slots, + LocalArrayMap &array_slots, + const BlockMap &block_map, + MachineBasicBlock &block) + { + switch (inst.GetOpcode()) + { + case ir::Opcode::Alloca: + { + auto &alloca = static_cast(inst); + const int elem_size = GetTypeSize(alloca.GetElementType()); + + if (alloca.IsArrayAlloca()) + { + int count = 0; + if (TryGetConstantInt(alloca.GetCount(), count) && count > 0) + array_slots[&inst] = function.CreateFrameIndex(elem_size * count); + else + array_slots[&inst] = function.CreateFrameIndex(elem_size); + } + else + { + const int slot = function.CreateFrameIndex(elem_size); + scalar_slots[&inst] = slot; + } + return; + } + + case ir::Opcode::Load: + { + auto &load = static_cast(inst); + const bool is_ptr = IsPointerType(load.GetType()); + const bool is_float = IsFloatType(load.GetType()); + + int scalar_slot = -1; + if (TryResolveDirectScalarSlot(load.GetPtr(), scalar_slots, scalar_slot)) + { + if (is_ptr) + { + int vreg = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::LoadStack, + {Operand::VReg(vreg, VRegClass::Ptr), Operand::FrameIndex(scalar_slot)}); + value_vregs[&load] = vreg; + } + else if (is_float) + { + int vreg = function.CreateVReg(VRegClass::Float); + block.Append(Opcode::LoadStack, + {Operand::VReg(vreg, VRegClass::Float), Operand::FrameIndex(scalar_slot)}); + value_vregs[&load] = vreg; + } + else + { + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::LoadStack, + {Operand::VReg(vreg, VRegClass::Int), Operand::FrameIndex(scalar_slot)}); + value_vregs[&load] = vreg; + } + return; + } + + if (!is_ptr) + { + if (auto *global = AsGlobalScalarObject(load.GetPtr())) + { + if (is_float) + { + int vreg = function.CreateVReg(VRegClass::Float); + block.Append(Opcode::LoadGlobal, + {Operand::VReg(vreg, VRegClass::Float), Operand::Symbol(global->GetName())}); + value_vregs[&load] = vreg; + } + else + { + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::LoadGlobal, + {Operand::VReg(vreg, VRegClass::Int), Operand::Symbol(global->GetName())}); + value_vregs[&load] = vreg; + } + return; + } + } + + int addr = EmitPtrValue(load.GetPtr(), function, value_vregs, + scalar_slots, array_slots, block); + if (is_ptr) + { + int vreg = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::LoadMem, + {Operand::VReg(vreg, VRegClass::Ptr), Operand::VReg(addr, VRegClass::Ptr)}); + value_vregs[&load] = vreg; + } + else if (is_float) + { + int vreg = function.CreateVReg(VRegClass::Float); + block.Append(Opcode::LoadMem, + {Operand::VReg(vreg, VRegClass::Float), Operand::VReg(addr, VRegClass::Ptr)}); + value_vregs[&load] = vreg; + } + else + { + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::LoadMem, + {Operand::VReg(vreg, VRegClass::Int), Operand::VReg(addr, VRegClass::Ptr)}); + value_vregs[&load] = vreg; + } + return; + } + + case ir::Opcode::Store: + { + auto &store = static_cast(inst); + const bool value_is_ptr = IsPointerType(store.GetValue()->GetType()); + const bool value_is_float = IsFloatType(store.GetValue()->GetType()); + + int scalar_slot = -1; + if (TryResolveDirectScalarSlot(store.GetPtr(), scalar_slots, scalar_slot)) + { + if (value_is_ptr) + { + int vreg = EmitPtrValue(store.GetValue(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::StoreStack, + {Operand::VReg(vreg, VRegClass::Ptr), Operand::FrameIndex(scalar_slot)}); + } + else if (value_is_float) + { + int vreg = EmitFloatValue(store.GetValue(), function, value_vregs, block); + block.Append(Opcode::StoreStack, + {Operand::VReg(vreg, VRegClass::Float), Operand::FrameIndex(scalar_slot)}); + } + else + { + int vreg = EmitIntValue(store.GetValue(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::StoreStack, + {Operand::VReg(vreg, VRegClass::Int), Operand::FrameIndex(scalar_slot)}); + } + return; + } + + if (!value_is_ptr) + { + if (auto *global = AsGlobalScalarObject(store.GetPtr())) + { + if (value_is_float) + { + int vreg = EmitFloatValue(store.GetValue(), function, value_vregs, block); + block.Append(Opcode::StoreGlobal, + {Operand::VReg(vreg, VRegClass::Float), Operand::Symbol(global->GetName())}); + } + else + { + int vreg = EmitIntValue(store.GetValue(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::StoreGlobal, + {Operand::VReg(vreg, VRegClass::Int), Operand::Symbol(global->GetName())}); + } + return; + } + } + + int addr = EmitPtrValue(store.GetPtr(), function, value_vregs, + scalar_slots, array_slots, block); + if (value_is_ptr) + { + int val = EmitPtrValue(store.GetValue(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::StoreMem, + {Operand::VReg(val, VRegClass::Ptr), Operand::VReg(addr, VRegClass::Ptr)}); + } + else if (value_is_float) + { + int val = EmitFloatValue(store.GetValue(), function, value_vregs, block); + block.Append(Opcode::StoreMem, + {Operand::VReg(val, VRegClass::Float), Operand::VReg(addr, VRegClass::Ptr)}); + } + else + { + int val = EmitIntValue(store.GetValue(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::StoreMem, + {Operand::VReg(val, VRegClass::Int), Operand::VReg(addr, VRegClass::Ptr)}); + } + return; + } + + case ir::Opcode::GEP: + return; + + case ir::Opcode::Add: + case ir::Opcode::Sub: + case ir::Opcode::Mul: + case ir::Opcode::Div: + case ir::Opcode::Mod: + { + auto &bin = static_cast(inst); + if (IsFloatType(bin.GetType())) + { + EmitFloatValue(&bin, function, value_vregs, block); + return; + } + EmitIntValue(&bin, function, value_vregs, scalar_slots, array_slots, block); + return; + } + + case ir::Opcode::Eq: + case ir::Opcode::Ne: + case ir::Opcode::Lt: + case ir::Opcode::Le: + case ir::Opcode::Gt: + case ir::Opcode::Ge: + { + auto &bin = static_cast(inst); + if (IsSolelyConsumedByCanonicalBoolUse(bin)) + return; + EmitIntValue(&bin, function, value_vregs, scalar_slots, array_slots, block); + return; + } + + case ir::Opcode::SIToFP: + case ir::Opcode::FPToSI: + case ir::Opcode::ZExt: + { + auto &cast = static_cast(inst); + if (inst.GetOpcode() == ir::Opcode::ZExt) + { + if (IsSolelyConsumedByCanonicalBoolUse(cast)) + return; + } + if (inst.GetOpcode() == ir::Opcode::SIToFP) + { + EmitFloatValue(&inst, function, value_vregs, block); + return; + } + if (inst.GetOpcode() == ir::Opcode::FPToSI) + { + EmitIntValue(&inst, function, value_vregs, scalar_slots, array_slots, block); + return; + } + if (inst.GetOpcode() == ir::Opcode::ZExt) + { + EmitIntValue(&inst, function, value_vregs, scalar_slots, array_slots, block); + return; + } + return; + } + + case ir::Opcode::Phi: + return; + + case ir::Opcode::Br: + { + auto &br = static_cast(inst); + auto it = block_map.find(br.GetTarget()); + if (it != block_map.end() && it->second) + { + EmitPhiValueStores(inst.GetParent(), br.GetTarget(), function, + value_vregs, scalar_slots, array_slots, block); + block.Append(Opcode::Br, {Operand::Label(it->second->GetLabelId())}); + } + return; + } + + case ir::Opcode::CondBr: + { + auto &br = static_cast(inst); + CondCode true_cond = CondCode::NE; + TryEmitCondValueToFlags(br.GetCond(), function, value_vregs, + scalar_slots, array_slots, block, true_cond); + + auto true_it = block_map.find(br.GetTrueTarget()); + if (true_it != block_map.end() && true_it->second) + { + EmitPhiValueStores(inst.GetParent(), br.GetTrueTarget(), function, + value_vregs, scalar_slots, array_slots, block); + block.Append(Opcode::CondBr, + {Operand::Imm(static_cast(true_cond)), + Operand::Label(true_it->second->GetLabelId())}); + } + + auto false_it = block_map.find(br.GetFalseTarget()); + if (false_it != block_map.end() && false_it->second) + { + EmitPhiValueStores(inst.GetParent(), br.GetFalseTarget(), function, + value_vregs, scalar_slots, array_slots, block); + block.Append(Opcode::Br, + {Operand::Label(false_it->second->GetLabelId())}); + } + return; + } + + case ir::Opcode::Call: + { + auto &call = static_cast(inst); + auto *callee = call.GetCallee(); + if (!callee) + { + if (!call.GetType()->IsVoid()) + { + if (IsPointerType(call.GetType())) + { + int vreg = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Ptr), Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + value_vregs[&call] = vreg; + } + else if (IsFloatType(call.GetType())) + { + int vreg = function.CreateVReg(VRegClass::Float); + int wvreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, {Operand::VReg(wvreg, VRegClass::Int), Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + block.Append(Opcode::FMovWS, + {Operand::VReg(vreg, VRegClass::Float), Operand::VReg(wvreg, VRegClass::Int)}); + value_vregs[&call] = vreg; + } + else + { + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Int), Operand::Imm(0)}).SetRematerializable(true).SetRematImm(0); + value_vregs[&call] = vreg; + } + } + return; + } + + std::vector stack_arg_indices; + size_t gp_idx = 0; + size_t fp_idx = 0; + for (size_t i = 0; i < call.GetNumArgs(); ++i) + { + auto *arg = call.GetArg(i); + if (!arg) + continue; + + if (IsFloatType(arg->GetType())) + { + if (fp_idx < 8) + { + int vreg = EmitFloatValue(arg, function, value_vregs, block); + block.Append(Opcode::MovReg, + {Operand::Reg(GetArgSReg(fp_idx)), + Operand::VReg(vreg, VRegClass::Float)}); + ++fp_idx; + } + else + { + stack_arg_indices.push_back(i); + } + } + else if (IsPointerValue(arg)) + { + if (gp_idx < 8) + { + int vreg = EmitPtrValue(arg, function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::MovReg, + {Operand::Reg(GetArgXReg(gp_idx)), + Operand::VReg(vreg, VRegClass::Ptr)}); + ++gp_idx; + } + else + { + stack_arg_indices.push_back(i); + } + } + else + { + if (gp_idx < 8) + { + int vreg = EmitIntValue(arg, function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::MovReg, + {Operand::Reg(GetArgWReg(gp_idx)), + Operand::VReg(vreg, VRegClass::Int)}); + ++gp_idx; + } + else + { + stack_arg_indices.push_back(i); + } + } + } + + const int raw_stack_arg_bytes = ComputeStackArgumentBytes(call); + const int aligned_stack_arg_bytes = AlignTo(raw_stack_arg_bytes, 16); + if (aligned_stack_arg_bytes > 0) + { + EmitStackPointerAdjust(block, Opcode::SubRR, aligned_stack_arg_bytes); + + int offset = 0; + for (size_t idx : stack_arg_indices) + { + auto *arg = call.GetArg(idx); + const int arg_size = GetTypeSize(arg ? arg->GetType() : nullptr); + if (!arg) + { + offset += 8; + continue; + } + + const int slot = function.CreateStackArgFrameIndex(arg_size); + function.GetFrameSlot(slot).offset = offset; + if (IsFloatType(arg->GetType())) + { + int vreg = EmitFloatValue(arg, function, value_vregs, block); + block.Append(Opcode::StoreStack, + {Operand::VReg(vreg, VRegClass::Float), Operand::FrameIndex(slot)}); + } + else if (IsPointerValue(arg)) + { + int vreg = EmitPtrValue(arg, function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::StoreStack, + {Operand::VReg(vreg, VRegClass::Ptr), Operand::FrameIndex(slot)}); + } + else + { + int vreg = EmitIntValue(arg, function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::StoreStack, + {Operand::VReg(vreg, VRegClass::Int), Operand::FrameIndex(slot)}); + } + offset += 8; + } + } + + block.Append(Opcode::Call, {Operand::Symbol(callee->GetName())}); + function.SetHasCall(); + + if (aligned_stack_arg_bytes > 0) + { + EmitStackPointerAdjust(block, Opcode::AddRR, aligned_stack_arg_bytes); + } + + if (!call.GetType()->IsVoid()) + { + if (IsPointerType(call.GetType())) + { + int vreg = function.CreateVReg(VRegClass::Ptr); + block.Append(Opcode::MovReg, + {Operand::VReg(vreg, VRegClass::Ptr), Operand::Reg(PhysReg::X0)}); + value_vregs[&call] = vreg; + } + else if (IsFloatType(call.GetType())) + { + int vreg = function.CreateVReg(VRegClass::Float); + block.Append(Opcode::MovReg, + {Operand::VReg(vreg, VRegClass::Float), Operand::Reg(PhysReg::S0)}); + value_vregs[&call] = vreg; + } + else + { + int vreg = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::MovReg, + {Operand::VReg(vreg, VRegClass::Int), Operand::Reg(PhysReg::W0)}); + value_vregs[&call] = vreg; + } + } + return; + } + + case ir::Opcode::Ret: + { + auto &ret = static_cast(inst); + if (ret.HasValue()) + { + if (ret.GetValue()->GetType()->IsPtrInt32() || + ret.GetValue()->GetType()->IsPtrFloat32()) + { + int vreg = EmitPtrValue(ret.GetValue(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::MovReg, + {Operand::Reg(PhysReg::X0), Operand::VReg(vreg, VRegClass::Ptr)}); + } + else if (IsFloatValue(ret.GetValue())) + { + int vreg = EmitFloatValue(ret.GetValue(), function, value_vregs, block); + block.Append(Opcode::MovReg, + {Operand::Reg(PhysReg::S0), Operand::VReg(vreg, VRegClass::Float)}); + } + else + { + int vreg = EmitIntValue(ret.GetValue(), function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::MovReg, + {Operand::Reg(PhysReg::W0), Operand::VReg(vreg, VRegClass::Int)}); + } + } + block.Append(Opcode::Ret); + return; + } + + default: + return; + } + } + + static void LowerOneFunction(const ir::Function &ir_function, + MachineFunction &machine_func) + { + ValueVRegMap value_vregs; + LocalScalarMap scalar_slots; + LocalArrayMap array_slots; + BlockMap block_map; + + const auto *entry = ir_function.GetEntry(); + if (!entry) + { + throw std::runtime_error( + FormatError("mir", "IR 函数缺少入口基本块: " + ir_function.GetName())); + } + + block_map.emplace(entry, &machine_func.GetEntry()); + for (const auto &bb : ir_function.GetBlocks()) + { + if (!bb || bb.get() == entry) + continue; + block_map.emplace(bb.get(), &machine_func.CreateBlock(bb->GetName())); + } + + LowerFunctionParams(ir_function, machine_func, value_vregs); + + for (const auto &bb : ir_function.GetBlocks()) + { + if (!bb) + continue; + for (const auto &inst : bb->GetInstructions()) + { + auto *phi = dynamic_cast(inst.get()); + if (!phi) + break; + VRegClass vc = VRegClass::Int; + if (IsFloatType(phi->GetType())) + vc = VRegClass::Float; + else if (IsPointerType(phi->GetType())) + vc = VRegClass::Ptr; + int phi_vreg = machine_func.CreateVReg(vc); + value_vregs[phi] = phi_vreg; + } + } + + for (const auto &bb : ir_function.GetBlocks()) + { + if (!bb) + continue; + auto it = block_map.find(bb.get()); + if (it == block_map.end() || !it->second) + continue; + auto &mir_block = *it->second; + + std::vector to_remove; + for (auto &pair : value_vregs) + { + if (dynamic_cast(pair.first) || + dynamic_cast(pair.first) || + AsGlobalScalarObject(pair.first) || + AsGlobalArrayObject(pair.first) || + dynamic_cast(pair.first) || + dynamic_cast(pair.first)) + { + to_remove.push_back(pair.first); + } + } + for (auto *v : to_remove) + value_vregs.erase(v); + + for (const auto &inst : bb->GetInstructions()) + { + LowerInstruction(*inst, machine_func, + value_vregs, scalar_slots, array_slots, + block_map, mir_block); + } + } + } + + static void LowerGlobals(const ir::Module &module, + MachineModule &machine_module) + { + for (const auto &global : module.GetGlobals()) + { + if (!global) + continue; + if (!IsPointerValue(global.get())) + continue; + + if (global->IsArray()) + { + if (global->IsPtrFloat32()) + { + std::vector init_bits; + if (global->HasInitValues()) + { + const auto &init_values = global->GetInitFloatValues(); + init_bits.reserve(init_values.size()); + for (double v : init_values) + { + init_bits.push_back(FloatToBits(static_cast(v))); + } + } + machine_module.AddGlobalArrayI32(global->GetName(), + global->GetArraySize(), + std::move(init_bits)); + } + else if (global->HasInitValues()) + { + machine_module.AddGlobalArrayI32(global->GetName(), + global->GetArraySize(), + global->GetInitValues()); + } + else + { + machine_module.AddGlobalArrayI32(global->GetName(), + global->GetArraySize()); + } + continue; + } + + if (global->IsPtrFloat32()) + { + machine_module.AddGlobalI32(global->GetName(), + FloatToBits(static_cast(global->GetInitFloatValue()))); + } + else + { + machine_module.AddGlobalI32(global->GetName(), global->GetInitValue()); + } + } + } + + } // namespace + + std::unique_ptr LowerModuleToMIR(const ir::Module &module) + { + DefaultContext(); + + auto machine_module = std::make_unique(); + LowerGlobals(module, *machine_module); + + for (const auto &func : module.GetFunctions()) + { + if (!func || func->IsExternal()) + continue; + + auto &machine_func = machine_module->CreateFunction(func->GetName()); + LowerOneFunction(*func, machine_func); + } + + return machine_module; + } + + std::unique_ptr LowerToMIR(const ir::Module &module) + { + auto machine_module = LowerModuleToMIR(module); + if (!machine_module) + { + throw std::runtime_error(FormatError("mir", "LowerModuleToMIR 失败")); + } + + auto &functions = machine_module->GetFunctions(); + for (auto &func : functions) + { + if (func && func->GetName() == "main") + { + return std::move(func); + } + } + + throw std::runtime_error(FormatError("mir", "未找到 main 函数对应的 MIR")); + } + +} // namespace mir diff --git a/src/mir/Lowering.cpp.rej b/src/mir/Lowering.cpp.rej new file mode 100644 index 00000000..aa6a3c92 --- /dev/null +++ b/src/mir/Lowering.cpp.rej @@ -0,0 +1,26 @@ +--- src/mir/Lowering.cpp ++++ src/mir/Lowering.cpp +@@ -339,10 +339,19 @@ + { + if (IsIntegerCompareOpcode(bin->GetOpcode())) + { + int lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); +- int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs, +- scalar_slots, array_slots, block); +- block.Append(Opcode::CmpRR, +- {Operand::VReg(lhs, VRegClass::Int), Operand::VReg(rhs, VRegClass::Int)}); ++ int rhs_imm; ++ if (TryGetConstantInt(bin->GetRhs(), rhs_imm)) ++ { ++ block.Append(Opcode::CmpImm, ++ {Operand::VReg(lhs, VRegClass::Int), Operand::Imm(rhs_imm)}); ++ } ++ else ++ { ++ int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs, ++ scalar_slots, array_slots, block); ++ block.Append(Opcode::CmpRR, ++ {Operand::VReg(lhs, VRegClass::Int), Operand::VReg(rhs, VRegClass::Int)}); ++ } + diff --git a/src/mir/MIRBasicBlock.cpp b/src/mir/MIRBasicBlock.cpp index ea9ea208..8ae7b02d 100644 --- a/src/mir/MIRBasicBlock.cpp +++ b/src/mir/MIRBasicBlock.cpp @@ -15,22 +15,4 @@ namespace mir return instructions_.back(); } - void MachineBasicBlock::InsertInst(int local_idx, MachineInstr inst) - { - if (local_idx < 0 || local_idx > (int)instructions_.size()) return; - instructions_.insert(instructions_.begin() + local_idx, std::move(inst)); - } - - void MachineBasicBlock::ReplaceVReg(int local_idx, int old_vreg, - int new_vreg) - { - if (local_idx < 0 || local_idx >= (int)instructions_.size()) return; - for (auto &op : instructions_[local_idx].GetOperands()) - { - if (op.GetKind() == Operand::Kind::VReg && - op.GetVRegId() == old_vreg) - op = Operand::VReg(new_vreg, op.GetVRegClass()); - } - } - } // namespace mir \ No newline at end of file diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index f8b57859..8a99870d 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -1,9 +1,6 @@ -#if 0 #include "mir/MIR.h" #include -#include -#include #include #include #include @@ -48,8 +45,8 @@ namespace mir return reg >= PhysReg::S0 && reg <= PhysReg::S31; } - static const int GP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 15, 16, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28}; - static const int GP_NUM_ALLOCATABLE = 18; + static const int GP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 15, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28}; + static const int GP_NUM_ALLOCATABLE = 16; static const int FP_ALLOCATABLE[] = {8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; static const int FP_NUM_ALLOCATABLE = 24; static const int FP_CALLER_SAVED[] = {8, 9, 10, 11, 12, 13, 14, 15}; @@ -120,8 +117,6 @@ namespace mir case Opcode::AddRR: case Opcode::SubRR: - case Opcode::AddImm: - case Opcode::SubImm: case Opcode::MulRR: case Opcode::DivRR: case Opcode::ModRR: @@ -217,6 +212,18 @@ namespace mir } break; + case Opcode::Csneg: + if (ops.size() >= 3) + { + if (ops[0].GetKind() == Operand::Kind::VReg) + result.defs.push_back(ops[0].GetVRegId()); + if (ops[1].GetKind() == Operand::Kind::VReg) + result.uses.push_back(ops[1].GetVRegId()); + if (ops[2].GetKind() == Operand::Kind::VReg) + result.uses.push_back(ops[2].GetVRegId()); + } + break; + case Opcode::Smull: if (ops.size() >= 3) { @@ -317,14 +324,7 @@ namespace mir std::unordered_set use; }; - struct LivenessResult - { - std::vector block_liveness; - std::unordered_map interval_length; - std::unordered_map ref_count; - }; - - static LivenessResult ComputeBlockLiveness(MachineFunction &function) + static std::vector ComputeBlockLiveness(MachineFunction &function) { auto &blocks = function.GetBlocks(); const size_t num_blocks = blocks.size(); @@ -336,8 +336,6 @@ namespace mir label_to_block[blocks[i]->GetLabelId()] = i; } - std::unordered_map ref_count; - for (size_t i = 0; i < num_blocks; ++i) { for (const auto &inst : blocks[i]->GetInstructions()) @@ -347,13 +345,9 @@ namespace mir { if (bl[i].def.find(u) == bl[i].def.end()) bl[i].use.insert(u); - ref_count[u]++; } for (int d : du.defs) - { bl[i].def.insert(d); - ref_count[d]++; - } } } @@ -405,28 +399,7 @@ namespace mir } } - // 在最终稳定的 liveness 上统计 interval_length - // 反向扫描每个块,统计每个 vreg 在多少条指令处活跃 - std::unordered_map interval_length; - for (size_t i = 0; i < num_blocks; ++i) - { - std::unordered_set live = bl[i].live_out; - for (int v : live) - interval_length[v]++; - const auto &insts = blocks[i]->GetInstructions(); - for (auto it = insts.rbegin(); it != insts.rend(); ++it) - { - auto du = GetInstDefUse(*it, function); - for (int d : du.defs) - live.erase(d); - for (int u : du.uses) - live.insert(u); - for (int v : live) - interval_length[v]++; - } - } - - return {std::move(bl), std::move(interval_length), std::move(ref_count)}; + return bl; } struct InterferenceGraph @@ -487,24 +460,25 @@ namespace mir auto &block = blocks[bi]; std::unordered_set live = block_liveness[bi].live_out; + std::vector gp_live; + for (int v : live) + { + if (IsGPClass(function.GetVRegClass(v))) + gp_live.push_back(v); + } + for (size_t i = 0; i < gp_live.size(); ++i) + { + for (size_t j = i + 1; j < gp_live.size(); ++j) + { + graph.AddEdge(gp_live[i], gp_live[j]); + } + } + const auto &insts = block->GetInstructions(); for (auto it = insts.rbegin(); it != insts.rend(); ++it) { auto du = GetInstDefUse(*it, function); - // MovReg: 暂时从 live 中移除 use 操作数,使 def/use 之间不产生干涉边 - bool is_movreg = (it->GetOpcode() == Opcode::MovReg); - std::vector saved_uses; - if (is_movreg && du.defs.size() == 1 && du.uses.size() == 1) - { - int use_vreg = du.uses[0]; - if (live.count(use_vreg) && IsGPClass(function.GetVRegClass(use_vreg))) - { - live.erase(use_vreg); - saved_uses.push_back(use_vreg); - } - } - for (int d : du.defs) { if (!IsGPClass(function.GetVRegClass(d))) @@ -516,9 +490,6 @@ namespace mir } live.erase(d); } - // 恢复 MovReg 的 use - for (int u : saved_uses) - live.insert(u); for (int u : du.uses) { if (IsGPClass(function.GetVRegClass(u))) @@ -554,7 +525,6 @@ namespace mir } } } - } static void BuildInterferenceForFP( @@ -576,24 +546,25 @@ namespace mir auto &block = blocks[bi]; std::unordered_set live = block_liveness[bi].live_out; + std::vector fp_live; + for (int v : live) + { + if (function.GetVRegClass(v) == VRegClass::Float) + fp_live.push_back(v); + } + for (size_t i = 0; i < fp_live.size(); ++i) + { + for (size_t j = i + 1; j < fp_live.size(); ++j) + { + graph.AddEdge(fp_live[i], fp_live[j]); + } + } + const auto &insts = block->GetInstructions(); for (auto it = insts.rbegin(); it != insts.rend(); ++it) { auto du = GetInstDefUse(*it, function); - // MovReg: 暂时从 live 中移除 use 操作数,使 def/use 之间不产生干涉边 - bool is_movreg = (it->GetOpcode() == Opcode::MovReg); - std::vector saved_uses; - if (is_movreg && du.defs.size() == 1 && du.uses.size() == 1) - { - int use_vreg = du.uses[0]; - if (live.count(use_vreg) && function.GetVRegClass(use_vreg) == VRegClass::Float) - { - live.erase(use_vreg); - saved_uses.push_back(use_vreg); - } - } - for (int d : du.defs) { if (function.GetVRegClass(d) != VRegClass::Float) @@ -605,9 +576,6 @@ namespace mir } live.erase(d); } - // 恢复 MovReg 的 use - for (int u : saved_uses) - live.insert(u); for (int u : du.uses) { if (function.GetVRegClass(u) == VRegClass::Float) @@ -643,13 +611,7 @@ namespace mir static GraphColoringResult ColorGraph( InterferenceGraph &graph, const std::vector &allocatable_regs, - MachineFunction & /*function*/, - int caller_saved_threshold, - const std::unordered_map &interval_length, - const std::unordered_map &ref_count, - const std::set &rematerializable_vregs, - const std::unordered_map &move_preferences, - const std::unordered_map &vreg_loop_depth) + MachineFunction & /*function*/) { const int K = static_cast(allocatable_regs.size()); GraphColoringResult result; @@ -698,152 +660,21 @@ namespace mir } } - // === Coalesce 数据结构 === - struct MovePair { int u; int v; }; - std::unordered_map> move_adj; - std::vector coalesce_candidates; - std::vector constrained_pairs; - std::vector stale_pairs; // 被 GiveUpPhase 冻结的 move - std::unordered_set merged_set; - std::unordered_map rep; - - for (const auto &pair : move_preferences) - { - int u = pair.first; - int v = pair.second; - if (graph.nodes.count(u) && graph.nodes.count(v)) - { - move_adj[u].insert(v); - move_adj[v].insert(u); - coalesce_candidates.push_back({u, v}); - } - } - - auto FindRep = [&](int n, auto &&self) -> int { - auto it = rep.find(n); - if (it != rep.end()) - return self(it->second, self); - return n; - }; - auto GetRep = [&](int n) { return FindRep(n, FindRep); }; - - auto HasMovePair = [&](int n) -> bool { - return !move_adj[n].empty(); - }; - std::vector simplify_worklist; - std::vector held_nodes; // 被 hold 的 move 相关低度数节点 for (int v : remaining) { if (degree[v] < K) - { - if (HasMovePair(v)) - held_nodes.push_back(v); // hold 住等 coalesce 先尝试合并 - else - simplify_worklist.push_back(v); - } + simplify_worklist.push_back(v); } std::vector stack; - // === MergePhase: 尝试合并 move 相关的节点对 === - auto MergeInto = [&](int u, int v) { - rep[v] = u; - merged_set.insert(v); - remaining.erase(v); - for (int n : adj[v]) - { - int alias_n = GetRep(n); - if (alias_n != u && alias_n != v) - { - adj[u].insert(alias_n); - adj[alias_n].insert(u); - } - } - // 转移 v 的 move_adj 到 u,同时更新其他端点的引用 - for (int m : move_adj[v]) - { - if (m != u) - { - move_adj[u].insert(m); - move_adj[m].erase(v); - move_adj[m].insert(u); - } - } - move_adj[v].clear(); - // 合并后重新计算 u 的度数(邻接集已包含 v 的邻接) - int new_deg = 0; - for (int n : adj[u]) - if (remaining.count(GetRep(n))) new_deg++; - degree[u] = new_deg; - if (degree[u] < K && !HasMovePair(u)) - simplify_worklist.push_back(u); - }; - - auto MergePhase = [&]() -> bool { - if (coalesce_candidates.empty()) return false; - MovePair m = coalesce_candidates.back(); - coalesce_candidates.pop_back(); - - int u = GetRep(m.u); - int v = GetRep(m.v); - if (u == v) return true; - - bool interferes = false; - for (int n : adj[u]) - if (GetRep(n) == v) { interferes = true; break; } - if (interferes) - { - constrained_pairs.push_back(m); - if (degree[u] < K && !HasMovePair(u)) simplify_worklist.push_back(u); - if (degree[v] < K && !HasMovePair(v)) simplify_worklist.push_back(v); - return true; - } - - std::set union_adj; - for (int n : adj[u]) union_adj.insert(GetRep(n)); - for (int n : adj[v]) union_adj.insert(GetRep(n)); - union_adj.erase(u); - union_adj.erase(v); - int high_deg_count = 0; - for (int n : union_adj) - // 阈值 2*K:只有度数极高的邻接节点才算"高风险" - // K=16 时仅 degree >= 32 的节点计入,大幅提高合并成功率 - if (degree[GetRep(n)] >= 2 * K) high_deg_count++; - - if (high_deg_count < K) - { - MergeInto(u, v); - return true; - } - else - { - constrained_pairs.push_back(m); - return true; - } - }; - - // 当节点 n 的度数从 K 降到 K-1 时,重新激活相关的搁置 move pair - auto ReactivatePairs = [&](int n) { - int rep_n = GetRep(n); - std::vector still_constrained; - for (auto &pair : constrained_pairs) - { - if (GetRep(pair.u) == rep_n || GetRep(pair.v) == rep_n) - coalesce_candidates.push_back(pair); - else - still_constrained.push_back(pair); - } - constrained_pairs = std::move(still_constrained); - }; - while (!remaining.empty()) { - if (!simplify_worklist.empty()) + while (!simplify_worklist.empty()) { int v = simplify_worklist.back(); simplify_worklist.pop_back(); - v = GetRep(v); if (!remaining.count(v)) continue; stack.push_back(v); @@ -852,83 +683,22 @@ namespace mir { if (remaining.count(n)) { - int old_deg = degree[n]; degree[n]--; - if (old_deg == K && degree[n] == K - 1) - { - ReactivatePairs(n); - if (HasMovePair(GetRep(n))) - held_nodes.push_back(n); - else - simplify_worklist.push_back(n); - } + if (degree[n] == K - 1) + simplify_worklist.push_back(n); } } } - else if (!coalesce_candidates.empty()) - { - MergePhase(); - } - else if (!held_nodes.empty()) - { - // 释放一个被 hold 的节点:放弃其所有 move,推入 simplify - int held_v = held_nodes.back(); - held_nodes.pop_back(); - int rep_v = GetRep(held_v); - if (remaining.count(rep_v) && HasMovePair(rep_v)) - { - // 放弃 rep_v 的所有 move pair - for (int other : move_adj[rep_v]) - move_adj[other].erase(rep_v); - // 从 coalesce_candidates 中移除涉及 rep_v 的 pair - std::vector remaining_pairs; - for (auto &pair : coalesce_candidates) - { - if (GetRep(pair.u) == rep_v || GetRep(pair.v) == rep_v) - stale_pairs.push_back(pair); - else - remaining_pairs.push_back(pair); - } - coalesce_candidates = std::move(remaining_pairs); - move_adj[rep_v].clear(); - } - if (remaining.count(rep_v)) - simplify_worklist.push_back(rep_v); - continue; - } - else - { - // spill cost: len(活跃指令数)*5 + ref(def+use总次数)*15 - degree(干涉度数)*25 - // cost 越小越优先 spill —— 短区间、少引用、高冲突的变量更适合溢出 - // 权重基于经验调节:degree 项主导,len/ref 项作为 tiebreaker - auto GetSpillCost = [&](int v) -> int { - int len = 0; - auto lit = interval_length.find(v); - if (lit != interval_length.end()) len = lit->second; - int ref = 0; - auto rit = ref_count.find(v); - if (rit != ref_count.end()) ref = rit->second; - int d = degree[v]; - // 循环深度加权:depth=0 → ×1, depth=1 → ×10, depth=2 → ×100 - int depth = 0; - auto dit = vreg_loop_depth.find(v); - if (dit != vreg_loop_depth.end()) depth = dit->second; - int loop_mult = 1; - for (int i = 0; i < depth; i++) loop_mult *= 10; - int cost = (len * 5 + ref * 15) * loop_mult - d * 25; - if (rematerializable_vregs.count(v)) - cost -= 100000; - return cost; - }; + if (!remaining.empty()) + { int spill_candidate = -1; - int min_cost = std::numeric_limits::max(); + int max_degree = -1; for (int v : remaining) { - int cost = GetSpillCost(v); - if (cost < min_cost) + if (degree[v] > max_degree) { - min_cost = cost; + max_degree = degree[v]; spill_candidate = v; } } @@ -940,16 +710,9 @@ namespace mir { if (remaining.count(n)) { - int old_deg = degree[n]; degree[n]--; - if (old_deg == K && degree[n] == K - 1) - { - ReactivatePairs(n); - if (HasMovePair(GetRep(n))) - held_nodes.push_back(n); - else - simplify_worklist.push_back(n); - } + if (degree[n] == K - 1) + simplify_worklist.push_back(n); } } } @@ -957,7 +720,7 @@ namespace mir { break; } - } + } } std::unordered_map colored = precolored; @@ -975,46 +738,14 @@ namespace mir } int assigned_color = -1; - // 尝试使用 move 偏好颜色(如果偏好变量已分配且无冲突) - auto pref_it = move_preferences.find(v); - if (pref_it != move_preferences.end()) - { - int pref_v = pref_it->second; - auto col_it = colored.find(pref_v); - if (col_it != colored.end()) - { - int pref_color = col_it->second; - if (used_colors.find(pref_color) == used_colors.end()) - assigned_color = pref_color; - } - } - - // 第一遍:优先选 caller-saved 颜色 (c < caller_saved_threshold) - if (assigned_color < 0) + for (int c : allocatable_regs) { - for (int c : allocatable_regs) - { - if (c >= caller_saved_threshold) break; if (used_colors.find(c) == used_colors.end()) { assigned_color = c; break; } } - // 第二遍:若 caller-saved 无可用,选 callee-saved - if (assigned_color < 0) - { - for (int c : allocatable_regs) - { - if (c < caller_saved_threshold) continue; - if (used_colors.find(c) == used_colors.end()) - { - assigned_color = c; - break; - } - } - } - } if (assigned_color >= 0) { @@ -1027,18 +758,6 @@ namespace mir } } - // 将合并节点的颜色设为其代表节点的颜色 - for (int v : merged_set) - { - int rep_v = GetRep(v); - auto it = colored.find(rep_v); - if (it != colored.end()) - { - colored[v] = it->second; - result.assignment[v] = it->second; - } - } - return result; } @@ -1127,128 +846,17 @@ namespace mir return FP_ALLOCATABLE[0]; } - // 为 spilled vreg 分配 frame slot,不重叠活区间的 vreg 共享同一 slot - static std::unordered_map AssignSpillSlots( - MachineFunction &function, - const std::set &spilled, - const std::vector &block_liveness) - { - if (spilled.empty()) - return {}; - - // 计算每个 spilled vreg 的活跃块集合 - std::unordered_map> live_blocks; - for (int v : spilled) - { - for (size_t bi = 0; bi < block_liveness.size(); ++bi) - { - const auto &bl = block_liveness[bi]; - if (bl.live_in.count(v) || bl.live_out.count(v) || - bl.def.count(v) || bl.use.count(v)) - { - live_blocks[v].insert(static_cast(bi)); - } - } - // vreg 不在 liveness 中(可能来自之前轮次的新 vreg), - // 保守假设在所有块活跃,不共享 slot - if (live_blocks[v].empty()) - { - for (size_t bi = 0; bi < block_liveness.size(); ++bi) - live_blocks[v].insert(static_cast(bi)); - } - } - - // 按活跃块数量降序排列——长活区间优先分配 - std::vector sorted_spilled(spilled.begin(), spilled.end()); - std::sort(sorted_spilled.begin(), sorted_spilled.end(), - [&](int a, int b) { - return live_blocks[a].size() > live_blocks[b].size(); - }); - - // 贪心分配 slot - struct SlotInfo - { - int frame_idx; - int size; // 4 或 8 - std::unordered_set owners; - }; - std::vector slots; - std::unordered_map result; - - for (int v : sorted_spilled) - { - VRegClass vc = function.GetVRegClass(v); - int size = (vc == VRegClass::Ptr) ? 8 : 4; - const auto &my_blocks = live_blocks[v]; - - int assigned = -1; - for (size_t si = 0; si < slots.size(); ++si) - { - // 大小必须兼容:同大小或 slot 更大 - if (slots[si].size < size) - continue; - - bool conflict = false; - for (int owner : slots[si].owners) - { - const auto &owner_blocks = live_blocks[owner]; - // 检查块集合是否相交 - for (int b : my_blocks) - { - if (owner_blocks.count(b)) - { - conflict = true; - break; - } - } - if (conflict) - break; - } - - if (!conflict) - { - assigned = static_cast(si); - break; - } - } - - if (assigned < 0) - { - int fidx = function.CreateFrameIndex(size); - assigned = static_cast(slots.size()); - slots.push_back({fidx, size, {}}); - } - - slots[assigned].owners.insert(v); - result[v] = slots[assigned].frame_idx; - } - - return result; - } - static void RewriteWithAllocation( MachineFunction &function, const std::unordered_map &gp_assignment, const std::unordered_map &fp_assignment, - const std::set &spilled, - const std::unordered_map &vreg_def_inst = {}, - const std::vector *block_liveness = nullptr) + const std::set &spilled) { std::unordered_map spill_slots; - if (!spilled.empty()) + for (int v : spilled) { - if (block_liveness && !block_liveness->empty()) - { - spill_slots = AssignSpillSlots(function, spilled, *block_liveness); - } - else - { - for (int v : spilled) - { - int size = (function.GetVRegClass(v) == VRegClass::Ptr) ? 8 : 4; - spill_slots[v] = function.CreateFrameIndex(size); - } - } + int size = (function.GetVRegClass(v) == VRegClass::Ptr) ? 8 : 4; + spill_slots[v] = function.CreateFrameIndex(size); } for (auto &block : function.GetBlocks()) @@ -1269,11 +877,6 @@ namespace mir int slot = spill_slots[u]; int reload_reg_num = -1; - // 检查是否可再物化(仅GP寄存器支持,AArch64不支持mov sN,#imm) - auto def_it = vreg_def_inst.find(u); - bool can_remat = (vc != VRegClass::Float) && - (def_it != vreg_def_inst.end() && def_it->second->IsRematerializable()); - if (vc == VRegClass::Float) { reload_reg_num = PickFPScratchReg(used_scratch_fp, du, fp_assignment); @@ -1286,21 +889,9 @@ namespace mir } PhysReg reload_reg = NumberToPhysReg(reload_reg_num, vc); - - if (can_remat) - { - // 再物化:用 scratch 寄存器直接生成 MovImm - new_insts.push_back( - MachineInstr(Opcode::MovImm, - {Operand::Reg(reload_reg), Operand::Imm(def_it->second->GetRematImm())})); - } - else - { - // 常规:从栈加载 - new_insts.push_back( - MachineInstr(Opcode::LoadStack, - {Operand::Reg(reload_reg), Operand::FrameIndex(slot)})); - } + new_insts.push_back( + MachineInstr(Opcode::LoadStack, + {Operand::Reg(reload_reg), Operand::FrameIndex(slot)})); for (auto &op : inst.GetOperands()) { @@ -1365,31 +956,23 @@ namespace mir VRegClass vc = function.GetVRegClass(d); int slot = spill_slots[d]; - // 可再物化变量:不需要 StoreStack,use 点会重新生成 MovImm(仅GP) - auto def_it = vreg_def_inst.find(d); - bool can_remat = (vc != VRegClass::Float) && - (def_it != vreg_def_inst.end() && def_it->second->IsRematerializable()); - - if (!can_remat) + const auto &last_inst = new_insts.back(); + PhysReg spill_reg = PhysReg::W0; + for (const auto &op : last_inst.GetOperands()) { - const auto &last_inst = new_insts.back(); - PhysReg spill_reg = PhysReg::W0; - for (const auto &op : last_inst.GetOperands()) + if (op.GetKind() == Operand::Kind::Reg) { - if (op.GetKind() == Operand::Kind::Reg) - { - PhysReg r = op.GetReg(); - if (vc == VRegClass::Float && IsFPReg(r)) - { spill_reg = r; break; } - else if (vc != VRegClass::Float && IsGPReg(r)) - { spill_reg = r; break; } - } + PhysReg r = op.GetReg(); + if (vc == VRegClass::Float && IsFPReg(r)) + { spill_reg = r; break; } + else if (vc != VRegClass::Float && IsGPReg(r)) + { spill_reg = r; break; } } - - new_insts.push_back( - MachineInstr(Opcode::StoreStack, - {Operand::Reg(spill_reg), Operand::FrameIndex(slot)})); } + + new_insts.push_back( + MachineInstr(Opcode::StoreStack, + {Operand::Reg(spill_reg), Operand::FrameIndex(slot)})); } } } @@ -1403,184 +986,20 @@ namespace mir if (function.GetNumVRegs() == 0) return; - // 大 vreg 数函数回退到线性扫描:块级活跃分析在密集 def-use - // 模式下导致 spill 重写后干涉边缺失。线性扫描的指令级精确 - // 区间能正确处理(如 30_many_dimensions)。 - if (function.GetNumVRegs() > 100) - { - RunLinearScanRegAlloc(function); - return; - } - - // 限制 spill 轮次为 1:block-level liveness 下多轮 spill 创建 - // 的 reload vreg 与保守修复(block_defs 全干涉)交互,产生错误 - // 寄存器分配。1 轮足够——循环外 RewriteWithAllocation 用 scratch - // 寄存器处理所有剩余 spill。修复:04_arr_defn3/05_arr_defn4 段错误 - // 及 09_BFS bad_alloc。 - const int MAX_SPILL_ROUNDS = (function.GetNumVRegs() > 120) ? 3 : 10; + const int MAX_SPILL_ROUNDS = 10; for (int round = 0; round < MAX_SPILL_ROUNDS; ++round) { - // 构建 VReg → 定义指令映射(用于再物化判断) - std::unordered_map vreg_def_inst; - for (auto &block : function.GetBlocks()) - { - for (auto &inst : block->GetInstructions()) - { - auto du = GetInstDefUse(inst, function); - for (int d : du.defs) - { - vreg_def_inst[d] = &inst; - } - } - } - - auto liveness = ComputeBlockLiveness(function); - - // === 回边检测:计算基本块的循环嵌套深度 === - size_t num_blocks = function.GetBlocks().size(); - std::vector loop_depth(num_blocks, 0); - { - std::unordered_map blk_label_to_idx; - for (size_t i = 0; i < liveness.block_liveness.size(); ++i) - blk_label_to_idx[function.GetBlocks()[i]->GetLabelId()] = i; - - std::vector dfs_state(num_blocks, 0); // 0=未访问, 1=栈中, 2=已完成 - - std::function dfs = [&](size_t cur) { - dfs_state[cur] = 1; - const auto &insts = function.GetBlocks()[cur]->GetInstructions(); - for (const auto &inst : insts) - { - size_t succ = static_cast(-1); - if (inst.GetOpcode() == Opcode::Br && inst.GetOperands().size() >= 1 && - inst.GetOperands()[0].GetKind() == Operand::Kind::Label) - { - auto it = blk_label_to_idx.find(inst.GetOperands()[0].GetLabel()); - if (it != blk_label_to_idx.end()) succ = it->second; - } - if (inst.GetOpcode() == Opcode::CondBr && inst.GetOperands().size() >= 2 && - inst.GetOperands()[1].GetKind() == Operand::Kind::Label) - { - auto it = blk_label_to_idx.find(inst.GetOperands()[1].GetLabel()); - if (it != blk_label_to_idx.end()) succ = it->second; - } - if (succ == static_cast(-1)) continue; - if (dfs_state[succ] == 1) - { - // 回边:cur → succ 且 succ 在栈中 - loop_depth[succ]++; - loop_depth[cur]++; - } - else if (dfs_state[succ] == 0) - { - dfs(succ); - if (loop_depth[succ] > 0) - loop_depth[cur] = std::max(loop_depth[cur], loop_depth[succ]); - } - } - dfs_state[cur] = 2; - }; - dfs(0); - } - - // 为每个 vreg 确定其最大循环深度(取定义所在块的 loop_depth) - std::unordered_map vreg_loop_depth; - for (size_t bi = 0; bi < num_blocks; ++bi) - { - const auto &insts = function.GetBlocks()[bi]->GetInstructions(); - for (const auto &inst : insts) - { - auto du = GetInstDefUse(inst, function); - for (int d : du.defs) - { - auto it = vreg_loop_depth.find(d); - int cur_depth = loop_depth[bi]; - if (it == vreg_loop_depth.end() || cur_depth > it->second) - vreg_loop_depth[d] = cur_depth; - } - } - } - - // 构建可再物化 vreg 集合(MovImm 常量) - std::set rematerializable_vregs; - for (const auto &pair : vreg_def_inst) - { - if (pair.second->IsRematerializable()) - rematerializable_vregs.insert(pair.first); - } - - // 收集 MovReg 的 move 偏好映射 - std::unordered_map move_preferences; - for (auto &block : function.GetBlocks()) - { - for (auto &inst : block->GetInstructions()) - { - if (inst.GetOpcode() == Opcode::MovReg) - { - const auto &ops = inst.GetOperands(); - if (ops.size() >= 2 && - ops[0].GetKind() == Operand::Kind::VReg && - ops[1].GetKind() == Operand::Kind::VReg) - { - int def_vreg = ops[0].GetVRegId(); - int use_vreg = ops[1].GetVRegId(); - if (function.GetVRegClass(def_vreg) == function.GetVRegClass(use_vreg)) - move_preferences[def_vreg] = use_vreg; - } - } - } - } - - // 检测并打破 move 偏好循环 - { - std::unordered_set visited; - for (auto &pair : move_preferences) - { - int cur = pair.first; - if (visited.count(cur)) continue; - std::unordered_set path; - std::vector chain; - int node = cur; - while (node != 0 && !visited.count(node)) - { - visited.insert(node); - path.insert(node); - chain.push_back(node); - auto it = move_preferences.find(node); - if (it == move_preferences.end()) break; - int next = it->second; - if (path.count(next)) - { - move_preferences.erase(chain.back()); - break; - } - node = next; - } - } - } - - // 大函数丢弃 move 偏好以保持分配稳定性 - // 条件1: move偏好数量 >100 直接跳过 - // 条件2: vreg数 * move偏好数 > 600 (conv2d: 71*15=1065) - int mv = static_cast(move_preferences.size()); - if (mv > 100 || static_cast(function.GetNumVRegs()) * mv > 600) - move_preferences.clear(); + auto block_liveness = ComputeBlockLiveness(function); std::vector gp_alloc(GP_ALLOCATABLE, GP_ALLOCATABLE + GP_NUM_ALLOCATABLE); std::vector fp_alloc(FP_ALLOCATABLE, FP_ALLOCATABLE + FP_NUM_ALLOCATABLE); InterferenceGraph gp_graph, fp_graph; - BuildInterferenceForGP(function, liveness.block_liveness, gp_alloc, gp_graph); - BuildInterferenceForFP(function, liveness.block_liveness, fp_alloc, fp_graph); - - auto gp_result = ColorGraph(gp_graph, gp_alloc, function, 19, - liveness.interval_length, liveness.ref_count, - rematerializable_vregs, move_preferences, - vreg_loop_depth); - auto fp_result = ColorGraph(fp_graph, fp_alloc, function, 16, - liveness.interval_length, liveness.ref_count, - rematerializable_vregs, move_preferences, - vreg_loop_depth); + BuildInterferenceForGP(function, block_liveness, gp_alloc, gp_graph); + BuildInterferenceForFP(function, block_liveness, fp_alloc, fp_graph); + + auto gp_result = ColorGraph(gp_graph, gp_alloc, function); + auto fp_result = ColorGraph(fp_graph, fp_alloc, function); if (gp_result.spilled.empty() && fp_result.spilled.empty()) { @@ -1604,28 +1023,6 @@ namespace mir } RewriteWithAllocation(function, gp_assign, fp_assign, {}); - - // 消除冗余 MovReg(源和目标分配到同一物理寄存器) - for (auto &block : function.GetBlocks()) - { - auto &insts = block->GetInstructions(); - std::vector filtered; - filtered.reserve(insts.size()); - for (auto &inst : insts) - { - if (inst.GetOpcode() == Opcode::MovReg) - { - const auto &ops = inst.GetOperands(); - if (ops.size() >= 2 && - ops[0].GetKind() == Operand::Kind::Reg && - ops[1].GetKind() == Operand::Kind::Reg && - ops[0].GetReg() == ops[1].GetReg()) - continue; - } - filtered.push_back(std::move(inst)); - } - insts = std::move(filtered); - } return; } @@ -1633,9 +1030,12 @@ namespace mir for (int v : fp_result.spilled) all_spilled.insert(v); - // 共享 spill slot:不重叠活区间的 spilled vreg 复用同一 frame slot - std::unordered_map spill_slots = - AssignSpillSlots(function, all_spilled, liveness.block_liveness); + std::unordered_map spill_slots; + for (int v : all_spilled) + { + int size = (function.GetVRegClass(v) == VRegClass::Ptr) ? 8 : 4; + spill_slots[v] = function.CreateFrameIndex(size); + } for (auto &block : function.GetBlocks()) { @@ -1650,25 +1050,9 @@ namespace mir { VRegClass vc = function.GetVRegClass(u); int new_vreg = function.CreateVReg(vc); - - // 检查是否可再物化(仅GP寄存器支持,AArch64不支持mov sN,#imm) - auto def_it = vreg_def_inst.find(u); - if (vc != VRegClass::Float && - def_it != vreg_def_inst.end() && def_it->second->IsRematerializable()) - { - // 再物化:直接生成 MovImm - new_insts.push_back( - MachineInstr(Opcode::MovImm, - {Operand::VReg(new_vreg, vc), Operand::Imm(def_it->second->GetRematImm())})); - } - else - { - // 常规:从栈加载 - new_insts.push_back( - MachineInstr(Opcode::LoadStack, - {Operand::VReg(new_vreg, vc), Operand::FrameIndex(spill_slots[u])})); - } - + new_insts.push_back( + MachineInstr(Opcode::LoadStack, + {Operand::VReg(new_vreg, vc), Operand::FrameIndex(spill_slots[u])})); for (auto &op : inst.GetOperands()) { if (op.GetKind() == Operand::Kind::VReg && op.GetVRegId() == u) @@ -1695,18 +1079,9 @@ namespace mir const_cast(op) = Operand::VReg(new_vreg, vc); } } - - // 可再物化变量:不需要 StoreStack,use 点会重新生成 MovImm(仅GP) - auto def_it = vreg_def_inst.find(d); - bool can_remat = (vc != VRegClass::Float) && - (def_it != vreg_def_inst.end() && def_it->second->IsRematerializable()); - - if (!can_remat) - { - new_insts.push_back( - MachineInstr(Opcode::StoreStack, - {Operand::VReg(new_vreg, vc), Operand::FrameIndex(spill_slots[d])})); - } + new_insts.push_back( + MachineInstr(Opcode::StoreStack, + {Operand::VReg(new_vreg, vc), Operand::FrameIndex(spill_slots[d])})); } } } @@ -1714,160 +1089,14 @@ namespace mir } } - // 循环外:构建 VReg → 定义指令映射(用于再物化判断) - std::unordered_map vreg_def_inst; - for (auto &block : function.GetBlocks()) - { - for (auto &inst : block->GetInstructions()) - { - auto du = GetInstDefUse(inst, function); - for (int d : du.defs) - { - vreg_def_inst[d] = &inst; - } - } - } - - auto liveness = ComputeBlockLiveness(function); - - // === 回边检测:计算基本块的循环嵌套深度 === - size_t num_blocks = function.GetBlocks().size(); - std::vector loop_depth(num_blocks, 0); - { - std::unordered_map blk_label_to_idx; - for (size_t i = 0; i < liveness.block_liveness.size(); ++i) - blk_label_to_idx[function.GetBlocks()[i]->GetLabelId()] = i; - - std::vector dfs_state(num_blocks, 0); - - std::function dfs = [&](size_t cur) { - dfs_state[cur] = 1; - const auto &insts = function.GetBlocks()[cur]->GetInstructions(); - for (const auto &inst : insts) - { - size_t succ = static_cast(-1); - if (inst.GetOpcode() == Opcode::Br && inst.GetOperands().size() >= 1 && - inst.GetOperands()[0].GetKind() == Operand::Kind::Label) - { - auto it = blk_label_to_idx.find(inst.GetOperands()[0].GetLabel()); - if (it != blk_label_to_idx.end()) succ = it->second; - } - if (inst.GetOpcode() == Opcode::CondBr && inst.GetOperands().size() >= 2 && - inst.GetOperands()[1].GetKind() == Operand::Kind::Label) - { - auto it = blk_label_to_idx.find(inst.GetOperands()[1].GetLabel()); - if (it != blk_label_to_idx.end()) succ = it->second; - } - if (succ == static_cast(-1)) continue; - if (dfs_state[succ] == 1) - { - loop_depth[succ]++; - loop_depth[cur]++; - } - else if (dfs_state[succ] == 0) - { - dfs(succ); - if (loop_depth[succ] > 0) - loop_depth[cur] = std::max(loop_depth[cur], loop_depth[succ]); - } - } - dfs_state[cur] = 2; - }; - dfs(0); - } - - std::unordered_map vreg_loop_depth; - for (size_t bi = 0; bi < num_blocks; ++bi) - { - const auto &insts = function.GetBlocks()[bi]->GetInstructions(); - for (const auto &inst : insts) - { - auto du = GetInstDefUse(inst, function); - for (int d : du.defs) - { - auto it = vreg_loop_depth.find(d); - int cur_depth = loop_depth[bi]; - if (it == vreg_loop_depth.end() || cur_depth > it->second) - vreg_loop_depth[d] = cur_depth; - } - } - } - - // 构建可再物化 vreg 集合 - std::set rematerializable_vregs; - for (const auto &pair : vreg_def_inst) - { - if (pair.second->IsRematerializable()) - rematerializable_vregs.insert(pair.first); - } - - // 收集 MovReg 的 move 偏好映射 - std::unordered_map move_preferences; - for (auto &block : function.GetBlocks()) - { - for (auto &inst : block->GetInstructions()) - { - if (inst.GetOpcode() == Opcode::MovReg) - { - const auto &ops = inst.GetOperands(); - if (ops.size() >= 2 && - ops[0].GetKind() == Operand::Kind::VReg && - ops[1].GetKind() == Operand::Kind::VReg) - { - int def_vreg = ops[0].GetVRegId(); - int use_vreg = ops[1].GetVRegId(); - if (function.GetVRegClass(def_vreg) == function.GetVRegClass(use_vreg)) - move_preferences[def_vreg] = use_vreg; - } - } - } - } - - // 检测并打破 move 偏好循环 - { - std::unordered_set visited; - for (auto &pair : move_preferences) - { - int cur = pair.first; - if (visited.count(cur)) continue; - std::unordered_set path; - std::vector chain; - int node = cur; - while (node != 0 && !visited.count(node)) - { - visited.insert(node); - path.insert(node); - chain.push_back(node); - auto it = move_preferences.find(node); - if (it == move_preferences.end()) break; - int next = it->second; - if (path.count(next)) - { - move_preferences.erase(chain.back()); - break; - } - node = next; - } - } - } - - int mv = static_cast(move_preferences.size()); - if (mv > 100 || static_cast(function.GetNumVRegs()) * mv > 600) - move_preferences.clear(); - + auto block_liveness = ComputeBlockLiveness(function); std::vector gp_alloc(GP_ALLOCATABLE, GP_ALLOCATABLE + GP_NUM_ALLOCATABLE); std::vector fp_alloc(FP_ALLOCATABLE, FP_ALLOCATABLE + FP_NUM_ALLOCATABLE); InterferenceGraph gp_graph, fp_graph; - BuildInterferenceForGP(function, liveness.block_liveness, gp_alloc, gp_graph); - BuildInterferenceForFP(function, liveness.block_liveness, fp_alloc, fp_graph); - auto gp_result = ColorGraph(gp_graph, gp_alloc, function, 19, - liveness.interval_length, liveness.ref_count, - rematerializable_vregs, move_preferences, - vreg_loop_depth); - auto fp_result = ColorGraph(fp_graph, fp_alloc, function, 16, - liveness.interval_length, liveness.ref_count, - rematerializable_vregs, move_preferences, - vreg_loop_depth); + BuildInterferenceForGP(function, block_liveness, gp_alloc, gp_graph); + BuildInterferenceForFP(function, block_liveness, fp_alloc, fp_graph); + auto gp_result = ColorGraph(gp_graph, gp_alloc, function); + auto fp_result = ColorGraph(fp_graph, fp_alloc, function); std::set all_spilled = gp_result.spilled; for (int v : fp_result.spilled) all_spilled.insert(v); @@ -1887,30 +1116,7 @@ namespace mir function.AddCalleeSavedReg(NumberToPhysReg(pair.second, VRegClass::Float)); } } - RewriteWithAllocation(function, gp_assign, fp_assign, all_spilled, vreg_def_inst, - &liveness.block_liveness); - - // 消除冗余 MovReg(源和目标分配到同一物理寄存器) - for (auto &block : function.GetBlocks()) - { - auto &insts = block->GetInstructions(); - std::vector filtered; - filtered.reserve(insts.size()); - for (auto &inst : insts) - { - if (inst.GetOpcode() == Opcode::MovReg) - { - const auto &ops = inst.GetOperands(); - if (ops.size() >= 2 && - ops[0].GetKind() == Operand::Kind::Reg && - ops[1].GetKind() == Operand::Kind::Reg && - ops[0].GetReg() == ops[1].GetReg()) - continue; - } - filtered.push_back(std::move(inst)); - } - insts = std::move(filtered); - } + RewriteWithAllocation(function, gp_assign, fp_assign, all_spilled); } } // namespace @@ -1930,4 +1136,3 @@ namespace mir } } // namespace mir -#endif diff --git a/src/mir/analysis/CFGAnalysis.cpp b/src/mir/analysis/CFGAnalysis.cpp new file mode 100644 index 00000000..0f235726 --- /dev/null +++ b/src/mir/analysis/CFGAnalysis.cpp @@ -0,0 +1,177 @@ +#include "mir/analysis/CFGAnalysis.h" + +#include "mir/MIR.h" + +namespace mir +{ + + namespace + { + + MachineBasicBlock *FindBlockByLabel(MachineFunction &function, + int label_id) + { + if (label_id < 0) + return nullptr; + for (auto &block : function.GetBlocks()) + { + if (block && block->GetLabelId() == label_id) + return block.get(); + } + return nullptr; + } + + void BuildSuccessors(MachineFunction &function, + CFGAnalysisResult &result) + { + for (auto &block : function.GetBlocks()) + { + if (!block) + continue; + const auto &insts = block->GetInstructions(); + if (insts.empty()) + continue; + + auto &succs = result.successors[block.get()]; + for (const auto &inst : insts) + { + if (inst.GetOpcode() == Opcode::Br) + { + const auto &ops = inst.GetOperands(); + if (!ops.empty() && ops[0].GetKind() == Operand::Kind::Label) + { + auto *target = FindBlockByLabel(function, ops[0].GetLabel()); + if (target) + { + bool dup = false; + for (auto *s : succs) + if (s == target) + { + dup = true; + break; + } + if (!dup) + succs.push_back(target); + } + } + } + else if (inst.GetOpcode() == Opcode::CondBr) + { + const auto &ops = inst.GetOperands(); + if (ops.size() >= 2 && ops[1].GetKind() == Operand::Kind::Label) + { + auto *target = FindBlockByLabel(function, ops[1].GetLabel()); + if (target) + { + bool dup = false; + for (auto *s : succs) + if (s == target) + { + dup = true; + break; + } + if (!dup) + succs.push_back(target); + } + } + } + } + } + } + + void BuildPredecessors(CFGAnalysisResult &result) + { + for (auto &kv : result.successors) + { + auto *src = kv.first; + for (auto *dst : kv.second) + { + result.predecessors[dst].push_back(src); + } + } + } + + void BuildEdges(CFGAnalysisResult &result) + { + for (auto &kv : result.successors) + { + auto *src = kv.first; + for (auto *dst : kv.second) + { + CFGEdge edge; + edge.src = src; + edge.dst = dst; + result.edges.push_back(edge); + } + } + } + + void EstimateBlockFrequencies(MachineFunction &function, + CFGAnalysisResult &result) + { + if (function.GetBlocks().empty()) + return; + + auto *entry = function.GetEntryPtr(); + if (!entry) + return; + + result.block_freq[entry] = 1.0; + for (auto &block : function.GetBlocks()) + { + if (block && block.get() != entry) + result.block_freq[block.get()] = 0.0; + } + + for (int iter = 0; iter < 20; ++iter) + { + for (auto &block : function.GetBlocks()) + { + if (!block) + continue; + auto it = result.successors.find(block.get()); + if (it == result.successors.end() || it->second.empty()) + continue; + + double freq = result.block_freq[block.get()]; + if (freq <= 0.0) + continue; + + double per_succ = freq / static_cast(it->second.size()); + for (auto *succ : it->second) + { + result.block_freq[succ] += per_succ; + } + } + } + } + + void ComputeEdgeWeights(CFGAnalysisResult &result) + { + for (auto &edge : result.edges) + { + auto it = result.successors.find(edge.src); + if (it == result.successors.end() || it->second.empty()) + continue; + double src_freq = 0.0; + auto fit = result.block_freq.find(edge.src); + if (fit != result.block_freq.end()) + src_freq = fit->second; + edge.weight = src_freq / static_cast(it->second.size()); + } + } + + } // namespace + + CFGAnalysisResult AnalyzeCFG(MachineFunction &function) + { + CFGAnalysisResult result; + BuildSuccessors(function, result); + BuildPredecessors(result); + BuildEdges(result); + EstimateBlockFrequencies(function, result); + ComputeEdgeWeights(result); + return result; + } + +} // namespace mir