|
|
|
|
@ -0,0 +1,195 @@
|
|
|
|
|
// 保守函数内联:
|
|
|
|
|
// - 自底向上迭代内联,每次只内联无调用(leaf)的单基本块函数
|
|
|
|
|
// - 内联后消除外层函数的 call,使其可能变为新 leaf,迭代至收敛
|
|
|
|
|
// - 每个函数内反复扫描直到清空所有可内联 call,4 轮即可收敛
|
|
|
|
|
|
|
|
|
|
#include "ir/IR.h"
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
namespace ir {
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
|
|
// 判断函数是否可被内联:单基本块、无函数调用(leaf)、无数组 alloca、以 Ret 结尾
|
|
|
|
|
bool IsInlineable(Function* func) {
|
|
|
|
|
if (func->IsExternal()) return false;
|
|
|
|
|
const auto& blocks = func->GetBlocks();
|
|
|
|
|
if (blocks.size() != 1) return false;
|
|
|
|
|
|
|
|
|
|
auto* bb = blocks[0].get();
|
|
|
|
|
const auto& insts = bb->GetInstructions();
|
|
|
|
|
if (insts.empty()) return false;
|
|
|
|
|
|
|
|
|
|
for (const auto& inst : insts) {
|
|
|
|
|
if (inst->GetOpcode() == Opcode::Call) return false;
|
|
|
|
|
// 只内联纯算术/逻辑函数,不内联含内存操作的函数
|
|
|
|
|
// Load/Store/GEP 的函数内联可能导致全局变量副作用顺序问题
|
|
|
|
|
if (inst->GetOpcode() == Opcode::Load ||
|
|
|
|
|
inst->GetOpcode() == Opcode::Store ||
|
|
|
|
|
inst->GetOpcode() == Opcode::GEP)
|
|
|
|
|
return false;
|
|
|
|
|
if (inst->GetOpcode() == Opcode::Alloca) {
|
|
|
|
|
auto* alloca = static_cast<AllocaInst*>(inst.get());
|
|
|
|
|
if (alloca->IsArrayAlloca()) return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return insts.back()->GetOpcode() == Opcode::Ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 克隆一条指令,通过 value_map 映射操作数
|
|
|
|
|
std::unique_ptr<Instruction> CloneInst(
|
|
|
|
|
Instruction* inst,
|
|
|
|
|
const std::unordered_map<Value*, Value*>& value_map,
|
|
|
|
|
Context& ctx) {
|
|
|
|
|
auto map = [&](Value* v) -> Value* {
|
|
|
|
|
auto it = value_map.find(v);
|
|
|
|
|
return (it != value_map.end()) ? it->second : v;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Opcode op = inst->GetOpcode();
|
|
|
|
|
|
|
|
|
|
switch (op) {
|
|
|
|
|
case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
|
|
|
|
|
case Opcode::Div: case Opcode::Mod:
|
|
|
|
|
case Opcode::Eq: case Opcode::Ne: case Opcode::Lt:
|
|
|
|
|
case Opcode::Le: case Opcode::Gt: case Opcode::Ge: {
|
|
|
|
|
auto* bin = static_cast<BinaryInst*>(inst);
|
|
|
|
|
return std::make_unique<BinaryInst>(op, inst->GetType(),
|
|
|
|
|
map(bin->GetLhs()),
|
|
|
|
|
map(bin->GetRhs()),
|
|
|
|
|
ctx.NextTemp());
|
|
|
|
|
}
|
|
|
|
|
case Opcode::Load: {
|
|
|
|
|
auto* load = static_cast<LoadInst*>(inst);
|
|
|
|
|
return std::make_unique<LoadInst>(inst->GetType(), map(load->GetPtr()),
|
|
|
|
|
ctx.NextTemp());
|
|
|
|
|
}
|
|
|
|
|
case Opcode::Store: {
|
|
|
|
|
auto* store = static_cast<StoreInst*>(inst);
|
|
|
|
|
return std::make_unique<StoreInst>(inst->GetType(),
|
|
|
|
|
map(store->GetValue()),
|
|
|
|
|
map(store->GetPtr()));
|
|
|
|
|
}
|
|
|
|
|
case Opcode::GEP: {
|
|
|
|
|
auto* gep = static_cast<GetElementPtrInst*>(inst);
|
|
|
|
|
return std::make_unique<GetElementPtrInst>(
|
|
|
|
|
inst->GetType(), map(gep->GetBasePtr()), map(gep->GetIndex()),
|
|
|
|
|
ctx.NextTemp());
|
|
|
|
|
}
|
|
|
|
|
case Opcode::Call: {
|
|
|
|
|
auto* call = static_cast<CallInst*>(inst);
|
|
|
|
|
std::vector<Value*> new_args;
|
|
|
|
|
for (size_t i = 0; i < call->GetNumArgs(); ++i)
|
|
|
|
|
new_args.push_back(map(call->GetArg(i)));
|
|
|
|
|
return std::make_unique<CallInst>(inst->GetType(), call->GetCallee(),
|
|
|
|
|
new_args, ctx.NextTemp());
|
|
|
|
|
}
|
|
|
|
|
case Opcode::Alloca: {
|
|
|
|
|
auto* alloca = static_cast<AllocaInst*>(inst);
|
|
|
|
|
Value* count = alloca->GetCount();
|
|
|
|
|
return std::make_unique<AllocaInst>(alloca->GetElementType(),
|
|
|
|
|
ctx.NextTemp(),
|
|
|
|
|
count ? map(count) : nullptr);
|
|
|
|
|
}
|
|
|
|
|
case Opcode::SIToFP: case Opcode::FPToSI: case Opcode::ZExt: {
|
|
|
|
|
auto* cast = static_cast<CastInst*>(inst);
|
|
|
|
|
return std::make_unique<CastInst>(op, inst->GetType(),
|
|
|
|
|
map(cast->GetOperandValue()),
|
|
|
|
|
ctx.NextTemp());
|
|
|
|
|
}
|
|
|
|
|
default:
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 内联一次 call 调用。不分裂基本块,保持单 BB 结构以便级联内联。
|
|
|
|
|
bool InlineCall(CallInst* call, Function* callee, Context& ctx) {
|
|
|
|
|
BasicBlock* bb = call->GetParent();
|
|
|
|
|
Function* caller = bb->GetParent();
|
|
|
|
|
|
|
|
|
|
if (callee == caller) return false;
|
|
|
|
|
|
|
|
|
|
// 1. 构建值映射:被调用者参数 -> 调用实参
|
|
|
|
|
std::unordered_map<Value*, Value*> value_map;
|
|
|
|
|
const auto& params = callee->GetParams();
|
|
|
|
|
for (size_t i = 0; i < params.size(); ++i)
|
|
|
|
|
value_map[params[i].get()] = call->GetArg(i);
|
|
|
|
|
|
|
|
|
|
// 2. 克隆被调用者指令(Ret 除外),用 InsertBefore 插入到 call 之前
|
|
|
|
|
Value* ret_val = nullptr;
|
|
|
|
|
const auto& callee_insts = callee->GetEntry()->GetInstructions();
|
|
|
|
|
for (const auto& inst : callee_insts) {
|
|
|
|
|
if (inst->GetOpcode() == Opcode::Ret) {
|
|
|
|
|
auto* ret = static_cast<ReturnInst*>(inst.get());
|
|
|
|
|
if (ret->HasValue())
|
|
|
|
|
ret_val = value_map.count(ret->GetValue())
|
|
|
|
|
? value_map[ret->GetValue()]
|
|
|
|
|
: ret->GetValue();
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
auto cloned = CloneInst(inst.get(), value_map, ctx);
|
|
|
|
|
if (!cloned) return false;
|
|
|
|
|
value_map[inst.get()] = cloned.get();
|
|
|
|
|
bb->InsertBefore(call, std::move(cloned));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 3. 替换 call 的使用并删除
|
|
|
|
|
if (ret_val) call->ReplaceAllUsesWith(ret_val);
|
|
|
|
|
call->ClearOperands(); // 清理操作数的 use 记录,防止悬空指针
|
|
|
|
|
bb->TakeInstruction(call);
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// Conservative bottom-up inlining over the whole module.
//
// Each round:
//   1. collects the functions that IsInlineable accepts (leaf, single basic
//      block, no memory operations);
//   2. inlines every call to one of them in every non-external function.
// Removing calls can turn a caller into a leaf, so rounds repeat until a
// round changes nothing, no function is inlineable, or kMaxRounds is reached.
void RunInline(Module& module) {
  constexpr int kMaxRounds = 16;

  for (int round = 0; round < kMaxRounds; ++round) {
    // The candidate set is fixed for the whole round. Candidates contain no
    // calls, so their bodies are not rewritten while the round is running.
    std::unordered_set<Function*> inlineable;
    for (const auto& func : module.GetFunctions()) {
      if (IsInlineable(func.get())) inlineable.insert(func.get());
    }
    if (inlineable.empty()) break;

    bool changed = false;
    for (const auto& func : module.GetFunctions()) {
      if (func->IsExternal()) continue;

      for (const auto& bb : func->GetBlocks()) {
        const auto& insts = bb->GetInstructions();
        // Index-based scan: InlineCall edits this instruction list in place,
        // so positions are re-read through `insts` after every change.
        size_t i = 0;
        while (i < insts.size()) {
          Instruction* inst = insts[i].get();
          if (inst->GetOpcode() == Opcode::Call) {
            auto* call = static_cast<CallInst*>(inst);
            Function* callee = call->GetCallee();
            if (inlineable.count(callee) &&
                InlineCall(call, callee, module.GetContext())) {
              changed = true;
              // The call at position i has been replaced by the copied callee
              // body. That body contains no calls (the callee is a leaf), so
              // resuming at the same index finds the same calls a full rescan
              // of the function would, without the quadratic cost.
              continue;
            }
          }
          ++i;
        }
      }
    }

    if (!changed) break;
  }
}
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|