diff --git a/src/include/ir/IR.h b/src/include/ir/IR.h index 87a35e0e..61a2929c 100644 --- a/src/include/ir/IR.h +++ b/src/include/ir/IR.h @@ -422,6 +422,12 @@ class BasicBlock : public Value { void RemoveInstruction(Instruction* inst) { for (auto it = instructions_.begin(); it != instructions_.end(); ++it) { if (it->get() == inst) { + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + auto* op = inst->GetOperand(i); + if (auto* op_inst = dynamic_cast(op)) { + op_inst->RemoveUse(inst, i); + } + } instructions_.erase(it); break; } diff --git a/src/include/ir/passes/PassManager.h b/src/include/ir/passes/PassManager.h index 286389e5..a4915d41 100644 --- a/src/include/ir/passes/PassManager.h +++ b/src/include/ir/passes/PassManager.h @@ -9,6 +9,7 @@ namespace ir { void RunMem2Reg(Module& module); +void RunInline(Module* module); void RunLICM(Module* module); void RunConstFold(Module& module); void RunConstProp(Module& module); @@ -26,27 +27,40 @@ class PassManagerModule { } RunMem2Reg(*module_); + RunInline(module_); - RunLICM(module_); + bool inline_changed = true; + int inline_rounds = 3; + while (inline_changed && inline_rounds > 0) { + inline_changed = false; + inline_rounds--; - bool changed = true; - int max_iterations = 10; - int iterations = 0; + RunLICM(module_); - while (changed && iterations < max_iterations) { - changed = false; - iterations++; + bool changed = true; + int max_iterations = 10; + int iterations = 0; - auto before = SerializeModule(*module_); + while (changed && iterations < max_iterations) { + changed = false; + iterations++; - RunConstFold(*module_); - RunConstProp(*module_); - RunCFGSimplify(*module_); - RunCSE(*module_); - RunDCE(*module_); + auto before = SerializeModule(*module_); - auto after = SerializeModule(*module_); - changed = (before != after); + RunConstFold(*module_); + RunConstProp(*module_); + RunCFGSimplify(*module_); + RunCSE(*module_); + RunDCE(*module_); + + auto after = SerializeModule(*module_); + changed = (before != after); + } + + auto before_inline = SerializeModule(*module_); + RunInline(module_); + auto after_inline = SerializeModule(*module_); + inline_changed = (before_inline != after_inline); } } @@ -70,12 +84,29 @@ class PassManager { RunMem2Reg(*module); - RunConstFold(*module); - RunDCE(*module); - RunCFGSimplify(*module); + for (int round = 0; round < 3; ++round) { + RunInline(module); + RunMem2Reg(*module); + + RunLICM(module); + + for (int i = 0; i < 10; ++i) { + RunConstFold(*module); + RunConstProp(*module); + RunCFGSimplify(*module); + RunCSE(*module); + RunDCE(*module); + } + } } private: + std::string SerializeModule(const Module& module) { + std::ostringstream oss; + IRPrinter printer; + printer.Print(module, oss); + return oss.str(); + } }; } // namespace ir diff --git a/src/ir/passes/CFGSimplify.cpp b/src/ir/passes/CFGSimplify.cpp index 76fbc24c..413726c3 100644 --- a/src/ir/passes/CFGSimplify.cpp +++ b/src/ir/passes/CFGSimplify.cpp @@ -13,7 +13,44 @@ namespace ir { namespace { -// 从入口块开始进行 BFS/DFS,标记所有可达的基本块 +void RemoveUsesOfInst(Instruction* inst) { + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + auto* op = inst->GetOperand(i); + if (auto* op_inst = dynamic_cast(op)) { + op_inst->RemoveUse(inst, i); + } + } +} + +void SafeEraseInstructions(std::vector>& insts, + const std::vector& to_delete) { + for (auto* inst : to_delete) { + RemoveUsesOfInst(inst); + } + auto new_end = std::remove_if(insts.begin(), insts.end(), + [&to_delete](const std::unique_ptr& inst_ptr) { + return std::find(to_delete.begin(), to_delete.end(), inst_ptr.get()) != to_delete.end(); + } + ); + insts.erase(new_end, insts.end()); +} + +void SafeEraseBlock(std::vector>& blocks, + const std::unordered_set& to_remove) { + for (auto& bb_ptr : blocks) { + if (to_remove.find(bb_ptr.get()) == to_remove.end()) continue; + for (auto& inst_ptr : bb_ptr->GetInstructions()) { + RemoveUsesOfInst(inst_ptr.get()); + } + } + auto new_end = std::remove_if(blocks.begin(), blocks.end(), + [&to_remove](const std::unique_ptr& bb_ptr) { + return to_remove.find(bb_ptr.get()) != to_remove.end(); + } + ); + blocks.erase(new_end, blocks.end()); +} + std::unordered_set FindReachableBlocks(Function* func) { std::unordered_set reachable; std::vector worklist; @@ -92,7 +129,7 @@ void RunCFGSimplify(Module& module) { if (unreachable.find(bb) != unreachable.end()) continue; std::vector> phi_replacements; - std::vector phi_to_delete; + std::vector phi_to_delete; for (auto& inst_ptr : bb->GetInstructions()) { auto* phi = dynamic_cast(inst_ptr.get()); @@ -123,23 +160,11 @@ void RunCFGSimplify(Module& module) { } auto& insts = const_cast>&>(bb->GetInstructions()); - auto new_end = std::remove_if(insts.begin(), insts.end(), - [&phi_to_delete](const std::unique_ptr& inst_ptr) { - return std::find(phi_to_delete.begin(), phi_to_delete.end(), inst_ptr.get()) != phi_to_delete.end(); - } - ); - insts.erase(new_end, insts.end()); + SafeEraseInstructions(insts, phi_to_delete); } size_t old_size = blocks.size(); - blocks.erase( - std::remove_if(blocks.begin(), blocks.end(), - [&reachable](const std::unique_ptr& bb_ptr) { - return reachable.find(bb_ptr.get()) == reachable.end(); - } - ), - blocks.end() - ); + SafeEraseBlock(blocks, unreachable); if (blocks.size() != old_size) { changed = true; } @@ -187,13 +212,15 @@ void RunCFGSimplify(Module& module) { phi->ReplaceAllUsesWith(val); } + std::vector phi_to_delete; + for (auto& inst_ptr : bb->GetInstructions()) { + auto* phi = dynamic_cast(inst_ptr.get()); + if (!phi) break; + phi_to_delete.push_back(phi); + } + auto& insts = const_cast>&>(bb->GetInstructions()); - auto new_end = std::remove_if(insts.begin(), insts.end(), - [](const std::unique_ptr& inst_ptr) { - return dynamic_cast(inst_ptr.get()) != nullptr; - } - ); - insts.erase(new_end, insts.end()); + SafeEraseInstructions(insts, phi_to_delete); } } } diff --git a/src/ir/passes/CMakeLists.txt b/src/ir/passes/CMakeLists.txt index ffd5cf47..92448474 100644 --- a/src/ir/passes/CMakeLists.txt +++ b/src/ir/passes/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(ir_passes STATIC PassManager.cpp Mem2Reg.cpp + Inline.cpp LICM.cpp ConstFold.cpp ConstProp.cpp diff --git a/src/ir/passes/Inline.cpp b/src/ir/passes/Inline.cpp new file mode 100644 index 00000000..f396ec09 --- /dev/null +++ b/src/ir/passes/Inline.cpp @@ -0,0 +1,621 @@ +#include "ir/IR.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ir { + +namespace { + +constexpr bool kDebugInline = false; +constexpr int kMaxInlineSize = 200; +constexpr int kMaxMultiBlockInlineSize = 50; + +bool IsRecursive(Function* func) { + if (!func || func->IsExternal()) return true; + for (auto& bb : func->GetBlocks()) { + for (auto& inst : bb->GetInstructions()) { + if (auto* call = dynamic_cast(inst.get())) { + if (call->GetCallee() == func) return true; + } + } + } + return false; +} + +int CountInstructions(Function* func) { + int count = 0; + for (auto& bb : func->GetBlocks()) { + count += bb->GetInstructions().size(); + } + return count; +} + +Value* MapValue(Value* v, const std::unordered_map& value_map) { + auto it = value_map.find(v); + if (it != value_map.end()) return it->second; + return v; +} + +void CloneInstruction(Instruction* inst, + const std::unordered_map& value_map, + std::vector>& out) { + std::unique_ptr cloned; + + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: + case Opcode::Eq: + case Opcode::Ne: + case Opcode::Lt: + case Opcode::Le: + case Opcode::Gt: + case Opcode::Ge: { + auto* bin = static_cast(inst); + Value* lhs = MapValue(bin->GetLhs(), value_map); + Value* rhs = MapValue(bin->GetRhs(), value_map); + cloned = std::make_unique( + inst->GetOpcode(), inst->GetType(), lhs, rhs, + inst->GetName() + ".inl"); + break; + } + case Opcode::SIToFP: + case Opcode::FPToSI: + case Opcode::ZExt: { + auto* cast = static_cast(inst); + Value* operand = MapValue(cast->GetOperandValue(), value_map); + cloned = std::make_unique( + inst->GetOpcode(), inst->GetType(), operand, + inst->GetName() + ".inl"); + break; + } + case Opcode::Load: { + auto* load = static_cast(inst); + Value* ptr = MapValue(load->GetPtr(), value_map); + cloned = std::make_unique( + load->GetType(), ptr, inst->GetName() + ".inl"); + break; + } + case Opcode::Store: { + auto* store = static_cast(inst); + Value* val = MapValue(store->GetValue(), value_map); + Value* ptr = MapValue(store->GetPtr(), value_map); + cloned = std::make_unique(Type::GetVoidType(), val, ptr); + break; + } + case Opcode::GEP: { + auto* gep = static_cast(inst); + Value* base = MapValue(gep->GetBasePtr(), value_map); + Value* index = MapValue(gep->GetIndex(), value_map); + cloned = std::make_unique( + gep->GetType(), base, index, inst->GetName() + ".inl"); + break; + } + case Opcode::Call: { + auto* orig_call = static_cast(inst); + Function* callee_func = orig_call->GetCallee(); + std::vector args; + for (size_t i = 0; i < orig_call->GetNumArgs(); ++i) { + args.push_back(MapValue(orig_call->GetArg(i), value_map)); + } + cloned = std::make_unique( + orig_call->GetType(), callee_func, args, + inst->GetName() + ".inl"); + break; + } + case Opcode::Alloca: { + auto* alloca_inst = static_cast(inst); + if (alloca_inst->IsArrayAlloca()) { + Value* count = MapValue(alloca_inst->GetCount(), value_map); + cloned = std::make_unique( + alloca_inst->GetElementType(), + alloca_inst->GetName() + ".inl", count); + } else { + cloned = std::make_unique( + alloca_inst->GetElementType(), + alloca_inst->GetName() + ".inl"); + } + break; + } + default: + break; + } + + if (cloned) { + out.push_back(std::move(cloned)); + } +} + +bool InlineCall(CallInst* call, Function* callee, Function* caller, + BasicBlock* call_bb, Module* module) { + if (kDebugInline) { + std::cerr << "[Inline] Inlining " << callee->GetName() + << " (" << callee->GetBlocks().size() << " blocks)" + << " into " << caller->GetName() << std::endl; + } + + bool is_single_block = (callee->GetBlocks().size() == 1); + + std::unordered_map value_map; + + for (auto& gv : module->GetGlobals()) { + value_map[gv.get()] = gv.get(); + } + for (auto& other_func : module->GetFunctions()) { + value_map[other_func.get()] = other_func.get(); + } + for (auto& arg : caller->GetParams()) { + value_map[arg.get()] = arg.get(); + } + { + auto& blocks = caller->GetBlocks(); + for (size_t bi = 0; bi < blocks.size(); ++bi) { + auto& insts = blocks[bi]->GetInstructions(); + for (size_t ii = 0; ii < insts.size(); ++ii) { + value_map[insts[ii].get()] = insts[ii].get(); + } + } + } + + for (size_t i = 0; i < callee->GetParams().size(); ++i) { + auto* formal_arg = callee->GetParams()[i].get(); + auto* actual_arg = call->GetArg(i); + value_map[formal_arg] = actual_arg; + } + + auto& call_bb_insts = const_cast>&>( + call_bb->GetInstructions()); + + size_t call_idx = 0; + for (size_t i = 0; i < call_bb_insts.size(); ++i) { + if (call_bb_insts[i].get() == call) { + call_idx = i; + break; + } + } + + if (is_single_block) { + auto* callee_entry = callee->GetEntry(); + Value* return_value = nullptr; + + std::vector> cloned_insts; + std::vector> alloca_insts; + + for (auto& inst : callee_entry->GetInstructions()) { + if (inst->GetOpcode() == Opcode::Alloca) { + std::vector> tmp; + CloneInstruction(inst.get(), value_map, tmp); + if (!tmp.empty()) { + value_map[inst.get()] = tmp.back().get(); + alloca_insts.push_back(std::move(tmp.back())); + } + continue; + } + + if (inst->GetOpcode() == Opcode::Ret) { + auto* ret_inst = static_cast(inst.get()); + if (ret_inst->HasValue()) { + return_value = MapValue(ret_inst->GetValue(), value_map); + } + continue; + } + + std::vector> tmp; + CloneInstruction(inst.get(), value_map, tmp); + if (!tmp.empty()) { + value_map[inst.get()] = tmp.back().get(); + cloned_insts.push_back(std::move(tmp.back())); + } + } + + if (return_value) { + call->ReplaceAllUsesWith(return_value); + } else if (!call->GetType()->IsVoid()) { + call->ReplaceAllUsesWith(module->GetContext().GetConstInt(0)); + } + + auto* entry_bb = caller->GetEntry(); + auto& entry_insts = const_cast>&>( + entry_bb->GetInstructions()); + size_t alloca_insert_pos = 0; + for (size_t i = 0; i < entry_insts.size(); ++i) { + if (entry_insts[i]->GetOpcode() == Opcode::Alloca) { + alloca_insert_pos = i + 1; + } else { + break; + } + } + for (auto& alloca : alloca_insts) { + alloca->SetParent(entry_bb); + entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca)); + alloca_insert_pos++; + } + + size_t insert_pos = call_idx; + for (auto& cloned : cloned_insts) { + cloned->SetParent(call_bb); + call_bb_insts.insert(call_bb_insts.begin() + insert_pos, std::move(cloned)); + insert_pos++; + } + + for (size_t i = 0; i < call_bb_insts.size(); ++i) { + if (call_bb_insts[i].get() == call) { + for (size_t oi = 0; oi < call->GetNumOperands(); ++oi) { + auto* op = call->GetOperand(oi); + if (auto* op_inst = dynamic_cast(op)) { + op_inst->RemoveUse(call, oi); + } + } + call_bb_insts.erase(call_bb_insts.begin() + i); + break; + } + } + + return true; + } + + // === Multi-block inlining === + + // 1. Create after_bb: move instructions after call from call_bb to after_bb + BasicBlock* after_bb = caller->CreateBlock(call_bb->GetName() + ".after"); + + std::vector> after_insts; + for (size_t i = call_idx + 1; i < call_bb_insts.size(); ++i) { + after_insts.push_back(std::move(call_bb_insts[i])); + } + call_bb_insts.resize(call_idx + 1); + + for (auto& inst : after_insts) { + inst->SetParent(after_bb); + after_bb->GetMutablePredecessors(); + } + auto& after_bb_insts = const_cast>&>( + after_bb->GetInstructions()); + for (auto& inst : after_insts) { + after_bb_insts.push_back(std::move(inst)); + } + + // 1b. Fix phi nodes: any phi that had call_bb as predecessor should now use after_bb + for (auto& bb : caller->GetBlocks()) { + for (auto& inst : bb->GetInstructions()) { + if (inst->GetOpcode() != Opcode::Phi) break; + auto* phi = static_cast(inst.get()); + size_t num_ops = phi->GetNumOperands(); + for (size_t i = 0; i + 1 < num_ops; i += 2) { + auto* bb_ptr = dynamic_cast(phi->GetOperand(i + 1)); + if (bb_ptr == call_bb) { + phi->SetOperand(i + 1, after_bb); + } + } + } + } + + // 2. Create cloned blocks for callee + std::unordered_map bb_map; + std::vector cloned_bbs; + for (auto& bb : callee->GetBlocks()) { + BasicBlock* cloned_bb = caller->CreateBlock(bb->GetName() + ".inl"); + bb_map[bb.get()] = cloned_bb; + cloned_bbs.push_back(cloned_bb); + } + BasicBlock* cloned_entry = bb_map[callee->GetEntry()]; + + // 2b. Reorder blocks: move cloned blocks and after_bb right after call_bb + // IMPORTANT: after_bb must come AFTER all cloned blocks, because + // after_bb may use values defined in the cloned blocks (e.g., call results + // from nested inlines). The lowering processes blocks in order, so values + // must be defined before they are used. + { + auto& blocks = const_cast>&>(caller->GetBlocks()); + std::vector move_indices; + for (auto* cb : cloned_bbs) { + for (size_t i = 0; i < blocks.size(); ++i) { + if (blocks[i].get() == cb) { move_indices.push_back(i); break; } + } + } + for (size_t i = 0; i < blocks.size(); ++i) { + if (blocks[i].get() == after_bb) { move_indices.push_back(i); break; } + } + + size_t call_bb_idx = 0; + for (size_t i = 0; i < blocks.size(); ++i) { + if (blocks[i].get() == call_bb) { call_bb_idx = i; break; } + } + + std::vector> extracted; + for (auto idx : move_indices) { + extracted.push_back(std::move(blocks[idx])); + } + + size_t insert_pos = call_bb_idx + 1; + for (auto& b : extracted) { + blocks.insert(blocks.begin() + insert_pos, std::move(b)); + insert_pos++; + } + + blocks.erase(std::remove_if(blocks.begin(), blocks.end(), + [](const std::unique_ptr& b) { return b == nullptr; }), + blocks.end()); + } + + // 4. Create alloca for return value (if non-void) + AllocaInst* ret_alloca = nullptr; + bool has_return = !call->GetType()->IsVoid(); + if (has_return) { + auto* entry_bb = caller->GetEntry(); + auto& entry_insts = const_cast>&>( + entry_bb->GetInstructions()); + auto alloca = std::make_unique(call->GetType(), "__ret.inl"); + alloca->SetParent(entry_bb); + ret_alloca = static_cast(alloca.get()); + size_t alloca_insert_pos = 0; + for (size_t i = 0; i < entry_insts.size(); ++i) { + if (entry_insts[i]->GetOpcode() == Opcode::Alloca) { + alloca_insert_pos = i + 1; + } else { + break; + } + } + entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca)); + } + + // 5. Clone all instructions from callee blocks into cloned blocks + // Pass 1: Create cloned instructions with original operands, build value_map + std::vector> alloca_insts; + std::vector> remap_list; + + for (auto& bb : callee->GetBlocks()) { + BasicBlock* cloned_bb = bb_map[bb.get()]; + auto& cloned_insts = const_cast>&>( + cloned_bb->GetInstructions()); + + for (auto& inst : bb->GetInstructions()) { + if (inst->GetOpcode() == Opcode::Alloca) { + std::vector> tmp; + CloneInstruction(inst.get(), value_map, tmp); + if (!tmp.empty()) { + value_map[inst.get()] = tmp.back().get(); + alloca_insts.push_back(std::move(tmp.back())); + } + continue; + } + + if (inst->GetOpcode() == Opcode::Phi) { + auto* phi = static_cast(inst.get()); + auto new_phi = std::make_unique(phi->GetType(), phi->GetName() + ".inl"); + new_phi->SetParent(cloned_bb); + value_map[inst.get()] = new_phi.get(); + cloned_insts.push_back(std::move(new_phi)); + continue; + } + + if (inst->IsTerminator()) continue; + + std::vector> tmp; + CloneInstruction(inst.get(), value_map, tmp); + if (!tmp.empty()) { + tmp.back()->SetParent(cloned_bb); + value_map[inst.get()] = tmp.back().get(); + remap_list.push_back({inst.get(), tmp.back().get()}); + cloned_insts.push_back(std::move(tmp.back())); + } + } + } + + // Pass 1b: Remap operands of cloned instructions now that value_map is complete + for (auto& [orig, cloned] : remap_list) { + for (size_t i = 0; i < orig->GetNumOperands(); ++i) { + Value* orig_op = orig->GetOperand(i); + Value* mapped = MapValue(orig_op, value_map); + if (mapped != orig_op) { + cloned->SetOperand(i, mapped); + } + } + } + + // Pass 2: fill phi operands and handle terminators + for (auto& bb : callee->GetBlocks()) { + BasicBlock* cloned_bb = bb_map[bb.get()]; + auto& cloned_insts = const_cast>&>( + cloned_bb->GetInstructions()); + + for (auto& inst : bb->GetInstructions()) { + if (inst->GetOpcode() == Opcode::Phi) { + auto* orig_phi = static_cast(inst.get()); + auto* cloned_phi = static_cast(value_map[orig_phi]); + if (!cloned_phi) continue; + + for (size_t i = 0; i < orig_phi->GetNumOperands(); i += 2) { + Value* val = MapValue(orig_phi->GetOperand(i), value_map); + auto* orig_pred = static_cast(orig_phi->GetOperand(i + 1)); + auto pred_it = bb_map.find(orig_pred); + BasicBlock* pred = (pred_it != bb_map.end()) ? pred_it->second : orig_pred; + cloned_phi->AddOperand(val); + cloned_phi->AddOperand(pred); + } + continue; + } + + if (inst->GetOpcode() == Opcode::Ret) { + auto* ret_inst = static_cast(inst.get()); + if (ret_inst->HasValue() && has_return) { + Value* ret_val = MapValue(ret_inst->GetValue(), value_map); + auto store = std::make_unique( + Type::GetVoidType(), ret_val, ret_alloca); + store->SetParent(cloned_bb); + cloned_insts.push_back(std::move(store)); + } + auto br = std::make_unique(Type::GetVoidType(), after_bb); + br->SetParent(cloned_bb); + cloned_insts.push_back(std::move(br)); + continue; + } + + if (inst->GetOpcode() == Opcode::Br) { + auto* br = static_cast(inst.get()); + auto it = bb_map.find(br->GetTarget()); + BasicBlock* target = (it != bb_map.end()) ? it->second : br->GetTarget(); + auto new_br = std::make_unique(Type::GetVoidType(), target); + new_br->SetParent(cloned_bb); + cloned_insts.push_back(std::move(new_br)); + continue; + } + + if (inst->GetOpcode() == Opcode::CondBr) { + auto* cbr = static_cast(inst.get()); + Value* cond = MapValue(cbr->GetCond(), value_map); + auto true_it = bb_map.find(cbr->GetTrueTarget()); + BasicBlock* true_target = (true_it != bb_map.end()) ? true_it->second : cbr->GetTrueTarget(); + auto false_it = bb_map.find(cbr->GetFalseTarget()); + BasicBlock* false_target = (false_it != bb_map.end()) ? false_it->second : cbr->GetFalseTarget(); + auto new_cbr = std::make_unique( + Type::GetVoidType(), cond, true_target, false_target); + new_cbr->SetParent(cloned_bb); + cloned_insts.push_back(std::move(new_cbr)); + continue; + } + } + } + + // 7. Insert alloca_insts into caller entry + { + auto* entry_bb = caller->GetEntry(); + auto& entry_insts = const_cast>&>( + entry_bb->GetInstructions()); + size_t alloca_insert_pos = 0; + for (size_t i = 0; i < entry_insts.size(); ++i) { + if (entry_insts[i]->GetOpcode() == Opcode::Alloca) { + alloca_insert_pos = i + 1; + } else { + break; + } + } + for (auto& alloca : alloca_insts) { + alloca->SetParent(entry_bb); + entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca)); + alloca_insert_pos++; + } + } + + // 8-9. Handle return value and remove call + auto call_type = call->GetType(); + + if (has_return) { + auto load_ret = std::make_unique( + call_type, ret_alloca, "__ret.load.inl"); + load_ret->SetParent(after_bb); + Value* ret_val = load_ret.get(); + after_bb_insts.insert(after_bb_insts.begin(), std::move(load_ret)); + + call->ReplaceAllUsesWith(ret_val); + } else { + call->ReplaceAllUsesWith(module->GetContext().GetConstInt(0)); + } + + // Remove the call and add branch to cloned_entry + for (size_t i = 0; i < call_bb_insts.size(); ++i) { + if (call_bb_insts[i].get() == call) { + for (size_t oi = 0; oi < call->GetNumOperands(); ++oi) { + auto* op = call->GetOperand(oi); + if (auto* op_inst = dynamic_cast(op)) { + op_inst->RemoveUse(call, oi); + } + } + call_bb_insts.erase(call_bb_insts.begin() + i); + break; + } + } + + auto br_to_entry = std::make_unique(Type::GetVoidType(), cloned_entry); + br_to_entry->SetParent(call_bb); + call_bb_insts.push_back(std::move(br_to_entry)); + + if (kDebugInline) { + std::cerr << "[Inline] Done inlining " << callee->GetName() << std::endl; + } + + return true; +} + +} // namespace + +void RunInline(Module* module) { + if (!module) return; + + std::unordered_map func_sizes; + std::unordered_set recursive_funcs; + + for (auto& func : module->GetFunctions()) { + if (func->IsExternal()) continue; + func_sizes[func->GetName()] = CountInstructions(func.get()); + if (IsRecursive(func.get())) { + recursive_funcs.insert(func->GetName()); + } + } + + struct InlineSite { + CallInst* call; + Function* caller; + BasicBlock* call_bb; + }; + + std::vector inline_sites; + + for (auto& caller : module->GetFunctions()) { + if (caller->IsExternal()) continue; + + for (auto& bb : caller->GetBlocks()) { + for (auto& inst : bb->GetInstructions()) { + auto* call = dynamic_cast(inst.get()); + if (!call) continue; + + auto* callee = call->GetCallee(); + if (!callee) continue; + if (callee->IsExternal()) continue; + if (recursive_funcs.count(callee->GetName())) continue; + + auto size_it = func_sizes.find(callee->GetName()); + int callee_size = (size_it != func_sizes.end()) ? size_it->second : 9999; + + if (callee_size > kMaxInlineSize) continue; + if (callee == caller.get()) continue; + if (callee->GetBlocks().size() > 1 && callee_size > kMaxMultiBlockInlineSize) continue; + + inline_sites.push_back({call, caller.get(), bb.get()}); + } + } + } + + for (auto& site : inline_sites) { + auto* callee = site.call->GetCallee(); + if (!callee) continue; + + bool still_valid = false; + BasicBlock* actual_bb = nullptr; + for (auto& bb : site.caller->GetBlocks()) { + for (auto& inst : bb->GetInstructions()) { + if (inst.get() == site.call) { + still_valid = true; + actual_bb = bb.get(); + break; + } + } + if (still_valid) break; + } + if (!still_valid) continue; + + InlineCall(site.call, callee, site.caller, actual_bb, module); + } +} + +} // namespace ir diff --git a/src/ir/passes/Mem2Reg.cpp b/src/ir/passes/Mem2Reg.cpp index c004ba63..29f68817 100644 --- a/src/ir/passes/Mem2Reg.cpp +++ b/src/ir/passes/Mem2Reg.cpp @@ -738,7 +738,7 @@ void RunMem2Reg(Module& module) { // PHI 节点在 llc -O0 下会生成 StoreStack 操作,可能导致性能下降 // 阈值设置:基本块数量的 1/4,最小 10,最大 30 int block_count = func->GetBlocks().size(); - int phi_threshold = std::max(50, block_count); + int phi_threshold = std::max(2000, block_count*20); if (total_phi_count > phi_threshold) { if (kDebugMem2Reg) { std::cerr << "[Mem2Reg] Skipping function " << func->GetName()