diff --git a/src/include/ir/IR.h b/src/include/ir/IR.h
index 36f33120..f49697b9 100644
--- a/src/include/ir/IR.h
+++ b/src/include/ir/IR.h
@@ -431,6 +431,7 @@ class BasicBlock : public Value {
   }
   std::unique_ptr<Instruction> TakeInstruction(Instruction* inst);
   void InsertInstructionBeforeTerminator(std::unique_ptr<Instruction> inst);
+  void InsertBefore(Instruction* before, std::unique_ptr<Instruction> inst);
 
  private:
   Function* parent_ = nullptr;
diff --git a/src/include/ir/passes/PassManager.h b/src/include/ir/passes/PassManager.h
index d67d23cf..22670ae9 100644
--- a/src/include/ir/passes/PassManager.h
+++ b/src/include/ir/passes/PassManager.h
@@ -10,6 +10,7 @@ namespace ir {
 
 void RunMem2Reg(Module& module);
 void RunLICM(Module* module);
+void RunInline(Module& module);
 void RunConstFold(Module& module);
 void RunConstProp(Module& module);
 void RunDCE(Module& module);
@@ -25,6 +26,8 @@ class PassManager {
 
     RunMem2Reg(*module);
 
+    RunInline(*module);
+
     RunLICM(module);
 
     bool changed = true;
diff --git a/src/ir/BasicBlock.cpp b/src/ir/BasicBlock.cpp
index 6fd9042a..49a20152 100644
--- a/src/ir/BasicBlock.cpp
+++ b/src/ir/BasicBlock.cpp
@@ -66,4 +66,18 @@ void BasicBlock::InsertInstructionBeforeTerminator(std::unique_ptr<Instruction>
   instructions_.insert(instructions_.begin() + pos, std::move(inst));
 }
 
+// Inserts `inst` into this block immediately before `before` and makes this
+// block its parent. If `inst` is null, or `before` is not an instruction of
+// this block, nothing is inserted and `inst` is destroyed on return.
+void BasicBlock::InsertBefore(Instruction* before,
+                              std::unique_ptr<Instruction> inst) {
+  if (!inst) return;
+  for (auto it = instructions_.begin(); it != instructions_.end(); ++it) {
+    if (it->get() != before) continue;
+    inst->SetParent(this);
+    instructions_.insert(it, std::move(inst));
+    return;
+  }
+}
+
 } // namespace ir
diff --git a/src/ir/passes/CMakeLists.txt b/src/ir/passes/CMakeLists.txt
index bf51491d..31318ab9 100644
--- a/src/ir/passes/CMakeLists.txt
+++ b/src/ir/passes/CMakeLists.txt
@@ -5,6 +5,7 @@ add_library(ir_passes STATIC
   CSE.cpp
   DCE.cpp
   CFGSimplify.cpp
+  Inline.cpp
   IRVerifier.cpp
 )
 
diff --git a/src/ir/passes/Inline.cpp b/src/ir/passes/Inline.cpp
new file mode 100644
index 00000000..c80fec2f
--- /dev/null
+++ b/src/ir/passes/Inline.cpp
@@
-0,0 +1,233 @@
+// Conservative function inlining:
+//  - Bottom-up and iterative: each round only inlines leaf functions (no
+//    calls) that consist of a single basic block.
+//  - Removing a call can turn the caller into a new leaf, so rounds repeat
+//    until nothing changes, bounded by kMaxRounds in RunInline.
+//  - Inside one function, scanning restarts after every successful inline
+//    until no inlineable call is left.
+//
+// NOTE(review): template arguments and #include targets were missing from the
+// reviewed patch text. The names restored below (BinaryInst, LoadInst,
+// StoreInst, GEPInst, AllocaInst, CastInst, ReturnInst, the container element
+// types and the standard headers) are assumptions -- confirm against ir/IR.h.
+
+#include "ir/IR.h"
+
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace ir {
+
+namespace {
+
+// Maps a callee value (parameter or instruction) to its replacement in the
+// caller.
+using ValueMap = std::unordered_map<Value*, Value*>;
+
+// Returns the replacement recorded for `v`, or `v` itself when there is none.
+Value* Remap(const ValueMap& value_map, Value* v) {
+  auto it = value_map.find(v);
+  return (it != value_map.end()) ? it->second : v;
+}
+
+// Returns true for the instructions this pass is willing to inline:
+// arithmetic, comparisons, casts and non-array allocas.
+//  - Calls are excluded so that only leaf functions qualify.
+//  - Load/Store/GEP are excluded: inlining memory operations could change the
+//    order of side effects on global variables.
+//  - Every other opcode is excluded because CloneInst has no case for it, so
+//    inlining never has to stop half-way through a callee.
+bool IsInlineableInst(Instruction* inst) {
+  switch (inst->GetOpcode()) {
+    case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
+    case Opcode::Div: case Opcode::Mod:
+    case Opcode::Eq: case Opcode::Ne: case Opcode::Lt:
+    case Opcode::Le: case Opcode::Gt: case Opcode::Ge:
+    case Opcode::SIToFP: case Opcode::FPToSI: case Opcode::ZExt:
+      return true;
+    case Opcode::Alloca:
+      return !static_cast<AllocaInst*>(inst)->IsArrayAlloca();
+    default:
+      return false;
+  }
+}
+
+// Decides whether `func` may be inlined: it is not external, has exactly one
+// basic block, that block ends in Ret, and every instruction before the Ret
+// passes IsInlineableInst.
+bool IsInlineable(Function* func) {
+  if (func->IsExternal()) return false;
+  const auto& blocks = func->GetBlocks();
+  if (blocks.size() != 1) return false;
+
+  const auto& insts = blocks[0]->GetInstructions();
+  if (insts.empty() || insts.back()->GetOpcode() != Opcode::Ret) return false;
+
+  for (size_t i = 0; i + 1 < insts.size(); ++i) {
+    if (!IsInlineableInst(insts[i].get())) return false;
+  }
+  return true;
+}
+
+// Clones one instruction, rewriting its operands through `value_map`.
+// Every clone except Store is constructed with ctx.NextTemp().
+// Returns nullptr for opcodes that have no case below.
+std::unique_ptr<Instruction> CloneInst(Instruction* inst,
+                                       const ValueMap& value_map,
+                                       Context& ctx) {
+  auto map = [&](Value* v) { return Remap(value_map, v); };
+
+  Opcode op = inst->GetOpcode();
+
+  switch (op) {
+    case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
+    case Opcode::Div: case Opcode::Mod:
+    case Opcode::Eq: case Opcode::Ne: case Opcode::Lt:
+    case Opcode::Le: case Opcode::Gt: case Opcode::Ge: {
+      auto* bin = static_cast<BinaryInst*>(inst);
+      return std::make_unique<BinaryInst>(op, inst->GetType(),
+                                          map(bin->GetLhs()),
+                                          map(bin->GetRhs()),
+                                          ctx.NextTemp());
+    }
+    case Opcode::Load: {
+      auto* load = static_cast<LoadInst*>(inst);
+      return std::make_unique<LoadInst>(inst->GetType(), map(load->GetPtr()),
+                                        ctx.NextTemp());
+    }
+    case Opcode::Store: {
+      auto* store = static_cast<StoreInst*>(inst);
+      return std::make_unique<StoreInst>(inst->GetType(),
+                                         map(store->GetValue()),
+                                         map(store->GetPtr()));
+    }
+    case Opcode::GEP: {
+      auto* gep = static_cast<GEPInst*>(inst);
+      return std::make_unique<GEPInst>(
+          inst->GetType(), map(gep->GetBasePtr()), map(gep->GetIndex()),
+          ctx.NextTemp());
+    }
+    case Opcode::Call: {
+      auto* call = static_cast<CallInst*>(inst);
+      std::vector<Value*> new_args;
+      new_args.reserve(call->GetNumArgs());
+      for (size_t i = 0; i < call->GetNumArgs(); ++i)
+        new_args.push_back(map(call->GetArg(i)));
+      return std::make_unique<CallInst>(inst->GetType(), call->GetCallee(),
+                                        new_args, ctx.NextTemp());
+    }
+    case Opcode::Alloca: {
+      auto* alloca = static_cast<AllocaInst*>(inst);
+      Value* count = alloca->GetCount();
+      return std::make_unique<AllocaInst>(alloca->GetElementType(),
+                                          ctx.NextTemp(),
+                                          count ? map(count) : nullptr);
+    }
+    case Opcode::SIToFP: case Opcode::FPToSI: case Opcode::ZExt: {
+      auto* cast = static_cast<CastInst*>(inst);
+      return std::make_unique<CastInst>(op, inst->GetType(),
+                                        map(cast->GetOperandValue()),
+                                        ctx.NextTemp());
+    }
+    default:
+      return nullptr;
+  }
+}
+
+// Inlines `callee` at the call site `call` without splitting the basic block,
+// so the caller keeps its block structure and can later become inlineable
+// itself. The callee body is cloned in front of the call, uses of the call
+// are redirected to the returned value, and the call is removed.
+//
+// Returns false and leaves the caller untouched when the call is recursive,
+// `callee` does not pass IsInlineable, or the argument count differs from the
+// callee's parameter count.
+bool InlineCall(CallInst* call, Function* callee, Context& ctx) {
+  BasicBlock* bb = call->GetParent();
+  Function* caller = bb->GetParent();
+
+  // Validate everything up front so that a rejected call never leaves a
+  // half-inlined body behind in the caller.
+  if (callee == caller || !IsInlineable(callee)) return false;
+  const auto& params = callee->GetParams();
+  if (call->GetNumArgs() != params.size()) return false;
+
+  // 1. Map every callee parameter to the matching call argument.
+  ValueMap value_map;
+  for (size_t i = 0; i < params.size(); ++i)
+    value_map[params[i].get()] = call->GetArg(i);
+
+  // 2. Clone the callee's instructions (all but the final Ret) and insert
+  //    them in front of the call, in their original order.
+  const auto& body = callee->GetEntry()->GetInstructions();
+  for (size_t i = 0; i + 1 < body.size(); ++i) {
+    auto cloned = CloneInst(body[i].get(), value_map, ctx);
+    // Cannot fail: IsInlineableInst only admits opcodes CloneInst handles.
+    if (!cloned) return false;
+    value_map[body[i].get()] = cloned.get();
+    bb->InsertBefore(call, std::move(cloned));
+  }
+
+  // 3. Redirect users of the call to the value the callee returns (if any),
+  //    then remove the call.
+  auto* ret = static_cast<ReturnInst*>(body.back().get());
+  Value* ret_val =
+      ret->HasValue() ? Remap(value_map, ret->GetValue()) : nullptr;
+  if (ret_val) call->ReplaceAllUsesWith(ret_val);
+  call->ClearOperands();  // drop use records first to avoid dangling pointers
+  bb->TakeInstruction(call);  // the returned owner is discarded here
+
+  return true;
+}
+
+// Inlines the first call in `func` whose callee is in `inlineable`. Returns
+// true when a call was inlined; the caller must then rescan, because the
+// instruction list that was being iterated has been modified.
+bool InlineFirstCall(Function* func,
+                     const std::unordered_set<Function*>& inlineable,
+                     Context& ctx) {
+  for (const auto& bb : func->GetBlocks()) {
+    for (const auto& inst : bb->GetInstructions()) {
+      if (inst->GetOpcode() != Opcode::Call) continue;
+      auto* call = static_cast<CallInst*>(inst.get());
+      auto* callee = call->GetCallee();
+      if (inlineable.count(callee) == 0) continue;
+      // Returning immediately keeps the range-for away from the modified
+      // list; a rejected call leaves the list unchanged, so scanning goes on.
+      if (InlineCall(call, callee, ctx)) return true;
+    }
+  }
+  return false;
+}
+
+}  // namespace
+
+// Runs the inlining pass over `module`. Each round recomputes the set of
+// inlineable functions and inlines every call to one of them; rounds repeat
+// until a round changes nothing, no function is inlineable, or kMaxRounds is
+// reached.
+void RunInline(Module& module) {
+  constexpr int kMaxRounds = 16;
+
+  for (int round = 0; round < kMaxRounds; ++round) {
+    std::unordered_set<Function*> inlineable;
+    for (const auto& func : module.GetFunctions()) {
+      if (IsInlineable(func.get())) inlineable.insert(func.get());
+    }
+    if (inlineable.empty()) break;
+
+    bool changed = false;
+    for (const auto& func : module.GetFunctions()) {
+      if (func->IsExternal()) continue;
+      // Each success removes one call and adds none, so this loop terminates.
+      while (InlineFirstCall(func.get(), inlineable, module.GetContext())) {
+        changed = true;
+      }
+    }
+    if (!changed) break;
+  }
+}
+
+}  // namespace ir