forked from NUDT-compiler/nudt-compiler-cpp
- 新增 Inline.cpp 实现多基本块函数内联 - 支持分裂调用点基本块 (call_bb -> call_bb + after_bb) - 使用 alloca+store+load 模式处理多返回点 - 正确处理 phi 节点前驱更新 - 块重排确保 lowering 顺序正确 - 修复 CFGSimplify 中 use-def 链维护问题 - 删除指令前先清理操作数的 use 条目 - 添加 SafeEraseInstructions/SafeEraseBlock 辅助函数 - 修复 BasicBlock::RemoveInstruction 未清理 use-def 链 - PassManager 支持多轮内联 (3轮) 测试结果: 功能 100/100, h_functional 40/40, 性能 60/60 全部通过 Huffman 性能提升约 4-5%zhm
parent
9f6459c00f
commit
a669efb7a5
@ -0,0 +1,621 @@
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr bool kDebugInline = false;
|
||||
constexpr int kMaxInlineSize = 200;
|
||||
constexpr int kMaxMultiBlockInlineSize = 50;
|
||||
|
||||
bool IsRecursive(Function* func) {
|
||||
if (!func || func->IsExternal()) return true;
|
||||
for (auto& bb : func->GetBlocks()) {
|
||||
for (auto& inst : bb->GetInstructions()) {
|
||||
if (auto* call = dynamic_cast<CallInst*>(inst.get())) {
|
||||
if (call->GetCallee() == func) return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int CountInstructions(Function* func) {
|
||||
int count = 0;
|
||||
for (auto& bb : func->GetBlocks()) {
|
||||
count += bb->GetInstructions().size();
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
Value* MapValue(Value* v, const std::unordered_map<Value*, Value*>& value_map) {
|
||||
auto it = value_map.find(v);
|
||||
if (it != value_map.end()) return it->second;
|
||||
return v;
|
||||
}
|
||||
|
||||
void CloneInstruction(Instruction* inst,
|
||||
const std::unordered_map<Value*, Value*>& value_map,
|
||||
std::vector<std::unique_ptr<Instruction>>& out) {
|
||||
std::unique_ptr<Instruction> cloned;
|
||||
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::Div:
|
||||
case Opcode::Mod:
|
||||
case Opcode::Eq:
|
||||
case Opcode::Ne:
|
||||
case Opcode::Lt:
|
||||
case Opcode::Le:
|
||||
case Opcode::Gt:
|
||||
case Opcode::Ge: {
|
||||
auto* bin = static_cast<BinaryInst*>(inst);
|
||||
Value* lhs = MapValue(bin->GetLhs(), value_map);
|
||||
Value* rhs = MapValue(bin->GetRhs(), value_map);
|
||||
cloned = std::make_unique<BinaryInst>(
|
||||
inst->GetOpcode(), inst->GetType(), lhs, rhs,
|
||||
inst->GetName() + ".inl");
|
||||
break;
|
||||
}
|
||||
case Opcode::SIToFP:
|
||||
case Opcode::FPToSI:
|
||||
case Opcode::ZExt: {
|
||||
auto* cast = static_cast<CastInst*>(inst);
|
||||
Value* operand = MapValue(cast->GetOperandValue(), value_map);
|
||||
cloned = std::make_unique<CastInst>(
|
||||
inst->GetOpcode(), inst->GetType(), operand,
|
||||
inst->GetName() + ".inl");
|
||||
break;
|
||||
}
|
||||
case Opcode::Load: {
|
||||
auto* load = static_cast<LoadInst*>(inst);
|
||||
Value* ptr = MapValue(load->GetPtr(), value_map);
|
||||
cloned = std::make_unique<LoadInst>(
|
||||
load->GetType(), ptr, inst->GetName() + ".inl");
|
||||
break;
|
||||
}
|
||||
case Opcode::Store: {
|
||||
auto* store = static_cast<StoreInst*>(inst);
|
||||
Value* val = MapValue(store->GetValue(), value_map);
|
||||
Value* ptr = MapValue(store->GetPtr(), value_map);
|
||||
cloned = std::make_unique<StoreInst>(Type::GetVoidType(), val, ptr);
|
||||
break;
|
||||
}
|
||||
case Opcode::GEP: {
|
||||
auto* gep = static_cast<GetElementPtrInst*>(inst);
|
||||
Value* base = MapValue(gep->GetBasePtr(), value_map);
|
||||
Value* index = MapValue(gep->GetIndex(), value_map);
|
||||
cloned = std::make_unique<GetElementPtrInst>(
|
||||
gep->GetType(), base, index, inst->GetName() + ".inl");
|
||||
break;
|
||||
}
|
||||
case Opcode::Call: {
|
||||
auto* orig_call = static_cast<CallInst*>(inst);
|
||||
Function* callee_func = orig_call->GetCallee();
|
||||
std::vector<Value*> args;
|
||||
for (size_t i = 0; i < orig_call->GetNumArgs(); ++i) {
|
||||
args.push_back(MapValue(orig_call->GetArg(i), value_map));
|
||||
}
|
||||
cloned = std::make_unique<CallInst>(
|
||||
orig_call->GetType(), callee_func, args,
|
||||
inst->GetName() + ".inl");
|
||||
break;
|
||||
}
|
||||
case Opcode::Alloca: {
|
||||
auto* alloca_inst = static_cast<AllocaInst*>(inst);
|
||||
if (alloca_inst->IsArrayAlloca()) {
|
||||
Value* count = MapValue(alloca_inst->GetCount(), value_map);
|
||||
cloned = std::make_unique<AllocaInst>(
|
||||
alloca_inst->GetElementType(),
|
||||
alloca_inst->GetName() + ".inl", count);
|
||||
} else {
|
||||
cloned = std::make_unique<AllocaInst>(
|
||||
alloca_inst->GetElementType(),
|
||||
alloca_inst->GetName() + ".inl");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
if (cloned) {
|
||||
out.push_back(std::move(cloned));
|
||||
}
|
||||
}
|
||||
|
||||
bool InlineCall(CallInst* call, Function* callee, Function* caller,
|
||||
BasicBlock* call_bb, Module* module) {
|
||||
if (kDebugInline) {
|
||||
std::cerr << "[Inline] Inlining " << callee->GetName()
|
||||
<< " (" << callee->GetBlocks().size() << " blocks)"
|
||||
<< " into " << caller->GetName() << std::endl;
|
||||
}
|
||||
|
||||
bool is_single_block = (callee->GetBlocks().size() == 1);
|
||||
|
||||
std::unordered_map<Value*, Value*> value_map;
|
||||
|
||||
for (auto& gv : module->GetGlobals()) {
|
||||
value_map[gv.get()] = gv.get();
|
||||
}
|
||||
for (auto& other_func : module->GetFunctions()) {
|
||||
value_map[other_func.get()] = other_func.get();
|
||||
}
|
||||
for (auto& arg : caller->GetParams()) {
|
||||
value_map[arg.get()] = arg.get();
|
||||
}
|
||||
{
|
||||
auto& blocks = caller->GetBlocks();
|
||||
for (size_t bi = 0; bi < blocks.size(); ++bi) {
|
||||
auto& insts = blocks[bi]->GetInstructions();
|
||||
for (size_t ii = 0; ii < insts.size(); ++ii) {
|
||||
value_map[insts[ii].get()] = insts[ii].get();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < callee->GetParams().size(); ++i) {
|
||||
auto* formal_arg = callee->GetParams()[i].get();
|
||||
auto* actual_arg = call->GetArg(i);
|
||||
value_map[formal_arg] = actual_arg;
|
||||
}
|
||||
|
||||
auto& call_bb_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
|
||||
call_bb->GetInstructions());
|
||||
|
||||
size_t call_idx = 0;
|
||||
for (size_t i = 0; i < call_bb_insts.size(); ++i) {
|
||||
if (call_bb_insts[i].get() == call) {
|
||||
call_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_single_block) {
|
||||
auto* callee_entry = callee->GetEntry();
|
||||
Value* return_value = nullptr;
|
||||
|
||||
std::vector<std::unique_ptr<Instruction>> cloned_insts;
|
||||
std::vector<std::unique_ptr<Instruction>> alloca_insts;
|
||||
|
||||
for (auto& inst : callee_entry->GetInstructions()) {
|
||||
if (inst->GetOpcode() == Opcode::Alloca) {
|
||||
std::vector<std::unique_ptr<Instruction>> tmp;
|
||||
CloneInstruction(inst.get(), value_map, tmp);
|
||||
if (!tmp.empty()) {
|
||||
value_map[inst.get()] = tmp.back().get();
|
||||
alloca_insts.push_back(std::move(tmp.back()));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inst->GetOpcode() == Opcode::Ret) {
|
||||
auto* ret_inst = static_cast<ReturnInst*>(inst.get());
|
||||
if (ret_inst->HasValue()) {
|
||||
return_value = MapValue(ret_inst->GetValue(), value_map);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<Instruction>> tmp;
|
||||
CloneInstruction(inst.get(), value_map, tmp);
|
||||
if (!tmp.empty()) {
|
||||
value_map[inst.get()] = tmp.back().get();
|
||||
cloned_insts.push_back(std::move(tmp.back()));
|
||||
}
|
||||
}
|
||||
|
||||
if (return_value) {
|
||||
call->ReplaceAllUsesWith(return_value);
|
||||
} else if (!call->GetType()->IsVoid()) {
|
||||
call->ReplaceAllUsesWith(module->GetContext().GetConstInt(0));
|
||||
}
|
||||
|
||||
auto* entry_bb = caller->GetEntry();
|
||||
auto& entry_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
|
||||
entry_bb->GetInstructions());
|
||||
size_t alloca_insert_pos = 0;
|
||||
for (size_t i = 0; i < entry_insts.size(); ++i) {
|
||||
if (entry_insts[i]->GetOpcode() == Opcode::Alloca) {
|
||||
alloca_insert_pos = i + 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (auto& alloca : alloca_insts) {
|
||||
alloca->SetParent(entry_bb);
|
||||
entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca));
|
||||
alloca_insert_pos++;
|
||||
}
|
||||
|
||||
size_t insert_pos = call_idx;
|
||||
for (auto& cloned : cloned_insts) {
|
||||
cloned->SetParent(call_bb);
|
||||
call_bb_insts.insert(call_bb_insts.begin() + insert_pos, std::move(cloned));
|
||||
insert_pos++;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < call_bb_insts.size(); ++i) {
|
||||
if (call_bb_insts[i].get() == call) {
|
||||
for (size_t oi = 0; oi < call->GetNumOperands(); ++oi) {
|
||||
auto* op = call->GetOperand(oi);
|
||||
if (auto* op_inst = dynamic_cast<Instruction*>(op)) {
|
||||
op_inst->RemoveUse(call, oi);
|
||||
}
|
||||
}
|
||||
call_bb_insts.erase(call_bb_insts.begin() + i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// === Multi-block inlining ===
|
||||
|
||||
// 1. Create after_bb: move instructions after call from call_bb to after_bb
|
||||
BasicBlock* after_bb = caller->CreateBlock(call_bb->GetName() + ".after");
|
||||
|
||||
std::vector<std::unique_ptr<Instruction>> after_insts;
|
||||
for (size_t i = call_idx + 1; i < call_bb_insts.size(); ++i) {
|
||||
after_insts.push_back(std::move(call_bb_insts[i]));
|
||||
}
|
||||
call_bb_insts.resize(call_idx + 1);
|
||||
|
||||
for (auto& inst : after_insts) {
|
||||
inst->SetParent(after_bb);
|
||||
after_bb->GetMutablePredecessors();
|
||||
}
|
||||
auto& after_bb_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
|
||||
after_bb->GetInstructions());
|
||||
for (auto& inst : after_insts) {
|
||||
after_bb_insts.push_back(std::move(inst));
|
||||
}
|
||||
|
||||
// 1b. Fix phi nodes: any phi that had call_bb as predecessor should now use after_bb
|
||||
for (auto& bb : caller->GetBlocks()) {
|
||||
for (auto& inst : bb->GetInstructions()) {
|
||||
if (inst->GetOpcode() != Opcode::Phi) break;
|
||||
auto* phi = static_cast<PhiInst*>(inst.get());
|
||||
size_t num_ops = phi->GetNumOperands();
|
||||
for (size_t i = 0; i + 1 < num_ops; i += 2) {
|
||||
auto* bb_ptr = dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1));
|
||||
if (bb_ptr == call_bb) {
|
||||
phi->SetOperand(i + 1, after_bb);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Create cloned blocks for callee
|
||||
std::unordered_map<BasicBlock*, BasicBlock*> bb_map;
|
||||
std::vector<BasicBlock*> cloned_bbs;
|
||||
for (auto& bb : callee->GetBlocks()) {
|
||||
BasicBlock* cloned_bb = caller->CreateBlock(bb->GetName() + ".inl");
|
||||
bb_map[bb.get()] = cloned_bb;
|
||||
cloned_bbs.push_back(cloned_bb);
|
||||
}
|
||||
BasicBlock* cloned_entry = bb_map[callee->GetEntry()];
|
||||
|
||||
// 2b. Reorder blocks: move cloned blocks and after_bb right after call_bb
|
||||
// IMPORTANT: after_bb must come AFTER all cloned blocks, because
|
||||
// after_bb may use values defined in the cloned blocks (e.g., call results
|
||||
// from nested inlines). The lowering processes blocks in order, so values
|
||||
// must be defined before they are used.
|
||||
{
|
||||
auto& blocks = const_cast<std::vector<std::unique_ptr<BasicBlock>>&>(caller->GetBlocks());
|
||||
std::vector<size_t> move_indices;
|
||||
for (auto* cb : cloned_bbs) {
|
||||
for (size_t i = 0; i < blocks.size(); ++i) {
|
||||
if (blocks[i].get() == cb) { move_indices.push_back(i); break; }
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < blocks.size(); ++i) {
|
||||
if (blocks[i].get() == after_bb) { move_indices.push_back(i); break; }
|
||||
}
|
||||
|
||||
size_t call_bb_idx = 0;
|
||||
for (size_t i = 0; i < blocks.size(); ++i) {
|
||||
if (blocks[i].get() == call_bb) { call_bb_idx = i; break; }
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<BasicBlock>> extracted;
|
||||
for (auto idx : move_indices) {
|
||||
extracted.push_back(std::move(blocks[idx]));
|
||||
}
|
||||
|
||||
size_t insert_pos = call_bb_idx + 1;
|
||||
for (auto& b : extracted) {
|
||||
blocks.insert(blocks.begin() + insert_pos, std::move(b));
|
||||
insert_pos++;
|
||||
}
|
||||
|
||||
blocks.erase(std::remove_if(blocks.begin(), blocks.end(),
|
||||
[](const std::unique_ptr<BasicBlock>& b) { return b == nullptr; }),
|
||||
blocks.end());
|
||||
}
|
||||
|
||||
// 4. Create alloca for return value (if non-void)
|
||||
AllocaInst* ret_alloca = nullptr;
|
||||
bool has_return = !call->GetType()->IsVoid();
|
||||
if (has_return) {
|
||||
auto* entry_bb = caller->GetEntry();
|
||||
auto& entry_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
|
||||
entry_bb->GetInstructions());
|
||||
auto alloca = std::make_unique<AllocaInst>(call->GetType(), "__ret.inl");
|
||||
alloca->SetParent(entry_bb);
|
||||
ret_alloca = static_cast<AllocaInst*>(alloca.get());
|
||||
size_t alloca_insert_pos = 0;
|
||||
for (size_t i = 0; i < entry_insts.size(); ++i) {
|
||||
if (entry_insts[i]->GetOpcode() == Opcode::Alloca) {
|
||||
alloca_insert_pos = i + 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca));
|
||||
}
|
||||
|
||||
// 5. Clone all instructions from callee blocks into cloned blocks
|
||||
// Pass 1: Create cloned instructions with original operands, build value_map
|
||||
std::vector<std::unique_ptr<Instruction>> alloca_insts;
|
||||
std::vector<std::pair<Instruction*, Instruction*>> remap_list;
|
||||
|
||||
for (auto& bb : callee->GetBlocks()) {
|
||||
BasicBlock* cloned_bb = bb_map[bb.get()];
|
||||
auto& cloned_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
|
||||
cloned_bb->GetInstructions());
|
||||
|
||||
for (auto& inst : bb->GetInstructions()) {
|
||||
if (inst->GetOpcode() == Opcode::Alloca) {
|
||||
std::vector<std::unique_ptr<Instruction>> tmp;
|
||||
CloneInstruction(inst.get(), value_map, tmp);
|
||||
if (!tmp.empty()) {
|
||||
value_map[inst.get()] = tmp.back().get();
|
||||
alloca_insts.push_back(std::move(tmp.back()));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inst->GetOpcode() == Opcode::Phi) {
|
||||
auto* phi = static_cast<PhiInst*>(inst.get());
|
||||
auto new_phi = std::make_unique<PhiInst>(phi->GetType(), phi->GetName() + ".inl");
|
||||
new_phi->SetParent(cloned_bb);
|
||||
value_map[inst.get()] = new_phi.get();
|
||||
cloned_insts.push_back(std::move(new_phi));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inst->IsTerminator()) continue;
|
||||
|
||||
std::vector<std::unique_ptr<Instruction>> tmp;
|
||||
CloneInstruction(inst.get(), value_map, tmp);
|
||||
if (!tmp.empty()) {
|
||||
tmp.back()->SetParent(cloned_bb);
|
||||
value_map[inst.get()] = tmp.back().get();
|
||||
remap_list.push_back({inst.get(), tmp.back().get()});
|
||||
cloned_insts.push_back(std::move(tmp.back()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 1b: Remap operands of cloned instructions now that value_map is complete
|
||||
for (auto& [orig, cloned] : remap_list) {
|
||||
for (size_t i = 0; i < orig->GetNumOperands(); ++i) {
|
||||
Value* orig_op = orig->GetOperand(i);
|
||||
Value* mapped = MapValue(orig_op, value_map);
|
||||
if (mapped != orig_op) {
|
||||
cloned->SetOperand(i, mapped);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: fill phi operands and handle terminators
|
||||
for (auto& bb : callee->GetBlocks()) {
|
||||
BasicBlock* cloned_bb = bb_map[bb.get()];
|
||||
auto& cloned_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
|
||||
cloned_bb->GetInstructions());
|
||||
|
||||
for (auto& inst : bb->GetInstructions()) {
|
||||
if (inst->GetOpcode() == Opcode::Phi) {
|
||||
auto* orig_phi = static_cast<PhiInst*>(inst.get());
|
||||
auto* cloned_phi = static_cast<PhiInst*>(value_map[orig_phi]);
|
||||
if (!cloned_phi) continue;
|
||||
|
||||
for (size_t i = 0; i < orig_phi->GetNumOperands(); i += 2) {
|
||||
Value* val = MapValue(orig_phi->GetOperand(i), value_map);
|
||||
auto* orig_pred = static_cast<BasicBlock*>(orig_phi->GetOperand(i + 1));
|
||||
auto pred_it = bb_map.find(orig_pred);
|
||||
BasicBlock* pred = (pred_it != bb_map.end()) ? pred_it->second : orig_pred;
|
||||
cloned_phi->AddOperand(val);
|
||||
cloned_phi->AddOperand(pred);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inst->GetOpcode() == Opcode::Ret) {
|
||||
auto* ret_inst = static_cast<ReturnInst*>(inst.get());
|
||||
if (ret_inst->HasValue() && has_return) {
|
||||
Value* ret_val = MapValue(ret_inst->GetValue(), value_map);
|
||||
auto store = std::make_unique<StoreInst>(
|
||||
Type::GetVoidType(), ret_val, ret_alloca);
|
||||
store->SetParent(cloned_bb);
|
||||
cloned_insts.push_back(std::move(store));
|
||||
}
|
||||
auto br = std::make_unique<BranchInst>(Type::GetVoidType(), after_bb);
|
||||
br->SetParent(cloned_bb);
|
||||
cloned_insts.push_back(std::move(br));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inst->GetOpcode() == Opcode::Br) {
|
||||
auto* br = static_cast<BranchInst*>(inst.get());
|
||||
auto it = bb_map.find(br->GetTarget());
|
||||
BasicBlock* target = (it != bb_map.end()) ? it->second : br->GetTarget();
|
||||
auto new_br = std::make_unique<BranchInst>(Type::GetVoidType(), target);
|
||||
new_br->SetParent(cloned_bb);
|
||||
cloned_insts.push_back(std::move(new_br));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inst->GetOpcode() == Opcode::CondBr) {
|
||||
auto* cbr = static_cast<CondBranchInst*>(inst.get());
|
||||
Value* cond = MapValue(cbr->GetCond(), value_map);
|
||||
auto true_it = bb_map.find(cbr->GetTrueTarget());
|
||||
BasicBlock* true_target = (true_it != bb_map.end()) ? true_it->second : cbr->GetTrueTarget();
|
||||
auto false_it = bb_map.find(cbr->GetFalseTarget());
|
||||
BasicBlock* false_target = (false_it != bb_map.end()) ? false_it->second : cbr->GetFalseTarget();
|
||||
auto new_cbr = std::make_unique<CondBranchInst>(
|
||||
Type::GetVoidType(), cond, true_target, false_target);
|
||||
new_cbr->SetParent(cloned_bb);
|
||||
cloned_insts.push_back(std::move(new_cbr));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Insert alloca_insts into caller entry
|
||||
{
|
||||
auto* entry_bb = caller->GetEntry();
|
||||
auto& entry_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
|
||||
entry_bb->GetInstructions());
|
||||
size_t alloca_insert_pos = 0;
|
||||
for (size_t i = 0; i < entry_insts.size(); ++i) {
|
||||
if (entry_insts[i]->GetOpcode() == Opcode::Alloca) {
|
||||
alloca_insert_pos = i + 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (auto& alloca : alloca_insts) {
|
||||
alloca->SetParent(entry_bb);
|
||||
entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca));
|
||||
alloca_insert_pos++;
|
||||
}
|
||||
}
|
||||
|
||||
// 8-9. Handle return value and remove call
|
||||
auto call_type = call->GetType();
|
||||
|
||||
if (has_return) {
|
||||
auto load_ret = std::make_unique<LoadInst>(
|
||||
call_type, ret_alloca, "__ret.load.inl");
|
||||
load_ret->SetParent(after_bb);
|
||||
Value* ret_val = load_ret.get();
|
||||
after_bb_insts.insert(after_bb_insts.begin(), std::move(load_ret));
|
||||
|
||||
call->ReplaceAllUsesWith(ret_val);
|
||||
} else {
|
||||
call->ReplaceAllUsesWith(module->GetContext().GetConstInt(0));
|
||||
}
|
||||
|
||||
// Remove the call and add branch to cloned_entry
|
||||
for (size_t i = 0; i < call_bb_insts.size(); ++i) {
|
||||
if (call_bb_insts[i].get() == call) {
|
||||
for (size_t oi = 0; oi < call->GetNumOperands(); ++oi) {
|
||||
auto* op = call->GetOperand(oi);
|
||||
if (auto* op_inst = dynamic_cast<Instruction*>(op)) {
|
||||
op_inst->RemoveUse(call, oi);
|
||||
}
|
||||
}
|
||||
call_bb_insts.erase(call_bb_insts.begin() + i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto br_to_entry = std::make_unique<BranchInst>(Type::GetVoidType(), cloned_entry);
|
||||
br_to_entry->SetParent(call_bb);
|
||||
call_bb_insts.push_back(std::move(br_to_entry));
|
||||
|
||||
if (kDebugInline) {
|
||||
std::cerr << "[Inline] Done inlining " << callee->GetName() << std::endl;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RunInline(Module* module) {
|
||||
if (!module) return;
|
||||
|
||||
std::unordered_map<std::string, int> func_sizes;
|
||||
std::unordered_set<std::string> recursive_funcs;
|
||||
|
||||
for (auto& func : module->GetFunctions()) {
|
||||
if (func->IsExternal()) continue;
|
||||
func_sizes[func->GetName()] = CountInstructions(func.get());
|
||||
if (IsRecursive(func.get())) {
|
||||
recursive_funcs.insert(func->GetName());
|
||||
}
|
||||
}
|
||||
|
||||
struct InlineSite {
|
||||
CallInst* call;
|
||||
Function* caller;
|
||||
BasicBlock* call_bb;
|
||||
};
|
||||
|
||||
std::vector<InlineSite> inline_sites;
|
||||
|
||||
for (auto& caller : module->GetFunctions()) {
|
||||
if (caller->IsExternal()) continue;
|
||||
|
||||
for (auto& bb : caller->GetBlocks()) {
|
||||
for (auto& inst : bb->GetInstructions()) {
|
||||
auto* call = dynamic_cast<CallInst*>(inst.get());
|
||||
if (!call) continue;
|
||||
|
||||
auto* callee = call->GetCallee();
|
||||
if (!callee) continue;
|
||||
if (callee->IsExternal()) continue;
|
||||
if (recursive_funcs.count(callee->GetName())) continue;
|
||||
|
||||
auto size_it = func_sizes.find(callee->GetName());
|
||||
int callee_size = (size_it != func_sizes.end()) ? size_it->second : 9999;
|
||||
|
||||
if (callee_size > kMaxInlineSize) continue;
|
||||
if (callee == caller.get()) continue;
|
||||
if (callee->GetBlocks().size() > 1 && callee_size > kMaxMultiBlockInlineSize) continue;
|
||||
|
||||
inline_sites.push_back({call, caller.get(), bb.get()});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& site : inline_sites) {
|
||||
auto* callee = site.call->GetCallee();
|
||||
if (!callee) continue;
|
||||
|
||||
bool still_valid = false;
|
||||
BasicBlock* actual_bb = nullptr;
|
||||
for (auto& bb : site.caller->GetBlocks()) {
|
||||
for (auto& inst : bb->GetInstructions()) {
|
||||
if (inst.get() == site.call) {
|
||||
still_valid = true;
|
||||
actual_bb = bb.get();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (still_valid) break;
|
||||
}
|
||||
if (!still_valid) continue;
|
||||
|
||||
InlineCall(site.call, callee, site.caller, actual_bb, module);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
Loading…
Reference in new issue