feat: 实现多基本块函数内联优化

- 新增 Inline.cpp 实现多基本块函数内联
  - 支持分裂调用点基本块 (call_bb -> call_bb + after_bb)
  - 使用 alloca+store+load 模式处理多返回点
  - 正确处理 phi 节点前驱更新
  - 块重排确保 lowering 顺序正确

- 修复 CFGSimplify 中 use-def 链维护问题
  - 删除指令前先清理操作数的 use 条目
  - 添加 SafeEraseInstructions/SafeEraseBlock 辅助函数

- 修复 BasicBlock::RemoveInstruction 未清理 use-def 链

- PassManager 支持多轮内联 (3轮)

测试结果: 功能 100/100, h_functional 40/40, 性能 60/60 全部通过
Huffman 性能提升约 4-5%
zhm
zhm 2 days ago
parent 9f6459c00f
commit a669efb7a5

@ -422,6 +422,12 @@ class BasicBlock : public Value {
void RemoveInstruction(Instruction* inst) {
for (auto it = instructions_.begin(); it != instructions_.end(); ++it) {
if (it->get() == inst) {
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* op = inst->GetOperand(i);
if (auto* op_inst = dynamic_cast<Instruction*>(op)) {
op_inst->RemoveUse(inst, i);
}
}
instructions_.erase(it);
break;
}

@ -9,6 +9,7 @@
namespace ir {
void RunMem2Reg(Module& module);
void RunInline(Module* module);
void RunLICM(Module* module);
void RunConstFold(Module& module);
void RunConstProp(Module& module);
@ -26,27 +27,40 @@ class PassManagerModule {
}
RunMem2Reg(*module_);
RunInline(module_);
RunLICM(module_);
bool inline_changed = true;
int inline_rounds = 3;
while (inline_changed && inline_rounds > 0) {
inline_changed = false;
inline_rounds--;
bool changed = true;
int max_iterations = 10;
int iterations = 0;
RunLICM(module_);
while (changed && iterations < max_iterations) {
changed = false;
iterations++;
bool changed = true;
int max_iterations = 10;
int iterations = 0;
auto before = SerializeModule(*module_);
while (changed && iterations < max_iterations) {
changed = false;
iterations++;
RunConstFold(*module_);
RunConstProp(*module_);
RunCFGSimplify(*module_);
RunCSE(*module_);
RunDCE(*module_);
auto before = SerializeModule(*module_);
auto after = SerializeModule(*module_);
changed = (before != after);
RunConstFold(*module_);
RunConstProp(*module_);
RunCFGSimplify(*module_);
RunCSE(*module_);
RunDCE(*module_);
auto after = SerializeModule(*module_);
changed = (before != after);
}
auto before_inline = SerializeModule(*module_);
RunInline(module_);
auto after_inline = SerializeModule(*module_);
inline_changed = (before_inline != after_inline);
}
}
@ -70,12 +84,29 @@ class PassManager {
RunMem2Reg(*module);
RunConstFold(*module);
RunDCE(*module);
RunCFGSimplify(*module);
for (int round = 0; round < 3; ++round) {
RunInline(module);
RunMem2Reg(*module);
RunLICM(module);
for (int i = 0; i < 10; ++i) {
RunConstFold(*module);
RunConstProp(*module);
RunCFGSimplify(*module);
RunCSE(*module);
RunDCE(*module);
}
}
}
private:
std::string SerializeModule(const Module& module) {
std::ostringstream oss;
IRPrinter printer;
printer.Print(module, oss);
return oss.str();
}
};
} // namespace ir

@ -13,7 +13,44 @@ namespace ir {
namespace {
// 从入口块开始进行 BFS/DFS标记所有可达的基本块
void RemoveUsesOfInst(Instruction* inst) {
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* op = inst->GetOperand(i);
if (auto* op_inst = dynamic_cast<Instruction*>(op)) {
op_inst->RemoveUse(inst, i);
}
}
}
void SafeEraseInstructions(std::vector<std::unique_ptr<Instruction>>& insts,
const std::vector<Instruction*>& to_delete) {
for (auto* inst : to_delete) {
RemoveUsesOfInst(inst);
}
auto new_end = std::remove_if(insts.begin(), insts.end(),
[&to_delete](const std::unique_ptr<Instruction>& inst_ptr) {
return std::find(to_delete.begin(), to_delete.end(), inst_ptr.get()) != to_delete.end();
}
);
insts.erase(new_end, insts.end());
}
void SafeEraseBlock(std::vector<std::unique_ptr<BasicBlock>>& blocks,
const std::unordered_set<BasicBlock*>& to_remove) {
for (auto& bb_ptr : blocks) {
if (to_remove.find(bb_ptr.get()) == to_remove.end()) continue;
for (auto& inst_ptr : bb_ptr->GetInstructions()) {
RemoveUsesOfInst(inst_ptr.get());
}
}
auto new_end = std::remove_if(blocks.begin(), blocks.end(),
[&to_remove](const std::unique_ptr<BasicBlock>& bb_ptr) {
return to_remove.find(bb_ptr.get()) != to_remove.end();
}
);
blocks.erase(new_end, blocks.end());
}
std::unordered_set<BasicBlock*> FindReachableBlocks(Function* func) {
std::unordered_set<BasicBlock*> reachable;
std::vector<BasicBlock*> worklist;
@ -92,7 +129,7 @@ void RunCFGSimplify(Module& module) {
if (unreachable.find(bb) != unreachable.end()) continue;
std::vector<std::pair<PhiInst*, Value*>> phi_replacements;
std::vector<PhiInst*> phi_to_delete;
std::vector<Instruction*> phi_to_delete;
for (auto& inst_ptr : bb->GetInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst_ptr.get());
@ -123,23 +160,11 @@ void RunCFGSimplify(Module& module) {
}
auto& insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(bb->GetInstructions());
auto new_end = std::remove_if(insts.begin(), insts.end(),
[&phi_to_delete](const std::unique_ptr<Instruction>& inst_ptr) {
return std::find(phi_to_delete.begin(), phi_to_delete.end(), inst_ptr.get()) != phi_to_delete.end();
}
);
insts.erase(new_end, insts.end());
SafeEraseInstructions(insts, phi_to_delete);
}
size_t old_size = blocks.size();
blocks.erase(
std::remove_if(blocks.begin(), blocks.end(),
[&reachable](const std::unique_ptr<BasicBlock>& bb_ptr) {
return reachable.find(bb_ptr.get()) == reachable.end();
}
),
blocks.end()
);
SafeEraseBlock(blocks, unreachable);
if (blocks.size() != old_size) {
changed = true;
}
@ -187,13 +212,15 @@ void RunCFGSimplify(Module& module) {
phi->ReplaceAllUsesWith(val);
}
std::vector<Instruction*> phi_to_delete;
for (auto& inst_ptr : bb->GetInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst_ptr.get());
if (!phi) break;
phi_to_delete.push_back(phi);
}
auto& insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(bb->GetInstructions());
auto new_end = std::remove_if(insts.begin(), insts.end(),
[](const std::unique_ptr<Instruction>& inst_ptr) {
return dynamic_cast<const PhiInst*>(inst_ptr.get()) != nullptr;
}
);
insts.erase(new_end, insts.end());
SafeEraseInstructions(insts, phi_to_delete);
}
}
}

@ -1,6 +1,7 @@
add_library(ir_passes STATIC
PassManager.cpp
Mem2Reg.cpp
Inline.cpp
LICM.cpp
ConstFold.cpp
ConstProp.cpp

@ -0,0 +1,621 @@
#include "ir/IR.h"
#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
constexpr bool kDebugInline = false;
constexpr int kMaxInlineSize = 200;
constexpr int kMaxMultiBlockInlineSize = 50;
bool IsRecursive(Function* func) {
if (!func || func->IsExternal()) return true;
for (auto& bb : func->GetBlocks()) {
for (auto& inst : bb->GetInstructions()) {
if (auto* call = dynamic_cast<CallInst*>(inst.get())) {
if (call->GetCallee() == func) return true;
}
}
}
return false;
}
int CountInstructions(Function* func) {
int count = 0;
for (auto& bb : func->GetBlocks()) {
count += bb->GetInstructions().size();
}
return count;
}
Value* MapValue(Value* v, const std::unordered_map<Value*, Value*>& value_map) {
auto it = value_map.find(v);
if (it != value_map.end()) return it->second;
return v;
}
void CloneInstruction(Instruction* inst,
const std::unordered_map<Value*, Value*>& value_map,
std::vector<std::unique_ptr<Instruction>>& out) {
std::unique_ptr<Instruction> cloned;
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod:
case Opcode::Eq:
case Opcode::Ne:
case Opcode::Lt:
case Opcode::Le:
case Opcode::Gt:
case Opcode::Ge: {
auto* bin = static_cast<BinaryInst*>(inst);
Value* lhs = MapValue(bin->GetLhs(), value_map);
Value* rhs = MapValue(bin->GetRhs(), value_map);
cloned = std::make_unique<BinaryInst>(
inst->GetOpcode(), inst->GetType(), lhs, rhs,
inst->GetName() + ".inl");
break;
}
case Opcode::SIToFP:
case Opcode::FPToSI:
case Opcode::ZExt: {
auto* cast = static_cast<CastInst*>(inst);
Value* operand = MapValue(cast->GetOperandValue(), value_map);
cloned = std::make_unique<CastInst>(
inst->GetOpcode(), inst->GetType(), operand,
inst->GetName() + ".inl");
break;
}
case Opcode::Load: {
auto* load = static_cast<LoadInst*>(inst);
Value* ptr = MapValue(load->GetPtr(), value_map);
cloned = std::make_unique<LoadInst>(
load->GetType(), ptr, inst->GetName() + ".inl");
break;
}
case Opcode::Store: {
auto* store = static_cast<StoreInst*>(inst);
Value* val = MapValue(store->GetValue(), value_map);
Value* ptr = MapValue(store->GetPtr(), value_map);
cloned = std::make_unique<StoreInst>(Type::GetVoidType(), val, ptr);
break;
}
case Opcode::GEP: {
auto* gep = static_cast<GetElementPtrInst*>(inst);
Value* base = MapValue(gep->GetBasePtr(), value_map);
Value* index = MapValue(gep->GetIndex(), value_map);
cloned = std::make_unique<GetElementPtrInst>(
gep->GetType(), base, index, inst->GetName() + ".inl");
break;
}
case Opcode::Call: {
auto* orig_call = static_cast<CallInst*>(inst);
Function* callee_func = orig_call->GetCallee();
std::vector<Value*> args;
for (size_t i = 0; i < orig_call->GetNumArgs(); ++i) {
args.push_back(MapValue(orig_call->GetArg(i), value_map));
}
cloned = std::make_unique<CallInst>(
orig_call->GetType(), callee_func, args,
inst->GetName() + ".inl");
break;
}
case Opcode::Alloca: {
auto* alloca_inst = static_cast<AllocaInst*>(inst);
if (alloca_inst->IsArrayAlloca()) {
Value* count = MapValue(alloca_inst->GetCount(), value_map);
cloned = std::make_unique<AllocaInst>(
alloca_inst->GetElementType(),
alloca_inst->GetName() + ".inl", count);
} else {
cloned = std::make_unique<AllocaInst>(
alloca_inst->GetElementType(),
alloca_inst->GetName() + ".inl");
}
break;
}
default:
break;
}
if (cloned) {
out.push_back(std::move(cloned));
}
}
bool InlineCall(CallInst* call, Function* callee, Function* caller,
BasicBlock* call_bb, Module* module) {
if (kDebugInline) {
std::cerr << "[Inline] Inlining " << callee->GetName()
<< " (" << callee->GetBlocks().size() << " blocks)"
<< " into " << caller->GetName() << std::endl;
}
bool is_single_block = (callee->GetBlocks().size() == 1);
std::unordered_map<Value*, Value*> value_map;
for (auto& gv : module->GetGlobals()) {
value_map[gv.get()] = gv.get();
}
for (auto& other_func : module->GetFunctions()) {
value_map[other_func.get()] = other_func.get();
}
for (auto& arg : caller->GetParams()) {
value_map[arg.get()] = arg.get();
}
{
auto& blocks = caller->GetBlocks();
for (size_t bi = 0; bi < blocks.size(); ++bi) {
auto& insts = blocks[bi]->GetInstructions();
for (size_t ii = 0; ii < insts.size(); ++ii) {
value_map[insts[ii].get()] = insts[ii].get();
}
}
}
for (size_t i = 0; i < callee->GetParams().size(); ++i) {
auto* formal_arg = callee->GetParams()[i].get();
auto* actual_arg = call->GetArg(i);
value_map[formal_arg] = actual_arg;
}
auto& call_bb_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
call_bb->GetInstructions());
size_t call_idx = 0;
for (size_t i = 0; i < call_bb_insts.size(); ++i) {
if (call_bb_insts[i].get() == call) {
call_idx = i;
break;
}
}
if (is_single_block) {
auto* callee_entry = callee->GetEntry();
Value* return_value = nullptr;
std::vector<std::unique_ptr<Instruction>> cloned_insts;
std::vector<std::unique_ptr<Instruction>> alloca_insts;
for (auto& inst : callee_entry->GetInstructions()) {
if (inst->GetOpcode() == Opcode::Alloca) {
std::vector<std::unique_ptr<Instruction>> tmp;
CloneInstruction(inst.get(), value_map, tmp);
if (!tmp.empty()) {
value_map[inst.get()] = tmp.back().get();
alloca_insts.push_back(std::move(tmp.back()));
}
continue;
}
if (inst->GetOpcode() == Opcode::Ret) {
auto* ret_inst = static_cast<ReturnInst*>(inst.get());
if (ret_inst->HasValue()) {
return_value = MapValue(ret_inst->GetValue(), value_map);
}
continue;
}
std::vector<std::unique_ptr<Instruction>> tmp;
CloneInstruction(inst.get(), value_map, tmp);
if (!tmp.empty()) {
value_map[inst.get()] = tmp.back().get();
cloned_insts.push_back(std::move(tmp.back()));
}
}
if (return_value) {
call->ReplaceAllUsesWith(return_value);
} else if (!call->GetType()->IsVoid()) {
call->ReplaceAllUsesWith(module->GetContext().GetConstInt(0));
}
auto* entry_bb = caller->GetEntry();
auto& entry_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
entry_bb->GetInstructions());
size_t alloca_insert_pos = 0;
for (size_t i = 0; i < entry_insts.size(); ++i) {
if (entry_insts[i]->GetOpcode() == Opcode::Alloca) {
alloca_insert_pos = i + 1;
} else {
break;
}
}
for (auto& alloca : alloca_insts) {
alloca->SetParent(entry_bb);
entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca));
alloca_insert_pos++;
}
size_t insert_pos = call_idx;
for (auto& cloned : cloned_insts) {
cloned->SetParent(call_bb);
call_bb_insts.insert(call_bb_insts.begin() + insert_pos, std::move(cloned));
insert_pos++;
}
for (size_t i = 0; i < call_bb_insts.size(); ++i) {
if (call_bb_insts[i].get() == call) {
for (size_t oi = 0; oi < call->GetNumOperands(); ++oi) {
auto* op = call->GetOperand(oi);
if (auto* op_inst = dynamic_cast<Instruction*>(op)) {
op_inst->RemoveUse(call, oi);
}
}
call_bb_insts.erase(call_bb_insts.begin() + i);
break;
}
}
return true;
}
// === Multi-block inlining ===
// 1. Create after_bb: move instructions after call from call_bb to after_bb
BasicBlock* after_bb = caller->CreateBlock(call_bb->GetName() + ".after");
std::vector<std::unique_ptr<Instruction>> after_insts;
for (size_t i = call_idx + 1; i < call_bb_insts.size(); ++i) {
after_insts.push_back(std::move(call_bb_insts[i]));
}
call_bb_insts.resize(call_idx + 1);
for (auto& inst : after_insts) {
inst->SetParent(after_bb);
after_bb->GetMutablePredecessors();
}
auto& after_bb_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
after_bb->GetInstructions());
for (auto& inst : after_insts) {
after_bb_insts.push_back(std::move(inst));
}
// 1b. Fix phi nodes: any phi that had call_bb as predecessor should now use after_bb
for (auto& bb : caller->GetBlocks()) {
for (auto& inst : bb->GetInstructions()) {
if (inst->GetOpcode() != Opcode::Phi) break;
auto* phi = static_cast<PhiInst*>(inst.get());
size_t num_ops = phi->GetNumOperands();
for (size_t i = 0; i + 1 < num_ops; i += 2) {
auto* bb_ptr = dynamic_cast<BasicBlock*>(phi->GetOperand(i + 1));
if (bb_ptr == call_bb) {
phi->SetOperand(i + 1, after_bb);
}
}
}
}
// 2. Create cloned blocks for callee
std::unordered_map<BasicBlock*, BasicBlock*> bb_map;
std::vector<BasicBlock*> cloned_bbs;
for (auto& bb : callee->GetBlocks()) {
BasicBlock* cloned_bb = caller->CreateBlock(bb->GetName() + ".inl");
bb_map[bb.get()] = cloned_bb;
cloned_bbs.push_back(cloned_bb);
}
BasicBlock* cloned_entry = bb_map[callee->GetEntry()];
// 2b. Reorder blocks: move cloned blocks and after_bb right after call_bb
// IMPORTANT: after_bb must come AFTER all cloned blocks, because
// after_bb may use values defined in the cloned blocks (e.g., call results
// from nested inlines). The lowering processes blocks in order, so values
// must be defined before they are used.
{
auto& blocks = const_cast<std::vector<std::unique_ptr<BasicBlock>>&>(caller->GetBlocks());
std::vector<size_t> move_indices;
for (auto* cb : cloned_bbs) {
for (size_t i = 0; i < blocks.size(); ++i) {
if (blocks[i].get() == cb) { move_indices.push_back(i); break; }
}
}
for (size_t i = 0; i < blocks.size(); ++i) {
if (blocks[i].get() == after_bb) { move_indices.push_back(i); break; }
}
size_t call_bb_idx = 0;
for (size_t i = 0; i < blocks.size(); ++i) {
if (blocks[i].get() == call_bb) { call_bb_idx = i; break; }
}
std::vector<std::unique_ptr<BasicBlock>> extracted;
for (auto idx : move_indices) {
extracted.push_back(std::move(blocks[idx]));
}
size_t insert_pos = call_bb_idx + 1;
for (auto& b : extracted) {
blocks.insert(blocks.begin() + insert_pos, std::move(b));
insert_pos++;
}
blocks.erase(std::remove_if(blocks.begin(), blocks.end(),
[](const std::unique_ptr<BasicBlock>& b) { return b == nullptr; }),
blocks.end());
}
// 4. Create alloca for return value (if non-void)
AllocaInst* ret_alloca = nullptr;
bool has_return = !call->GetType()->IsVoid();
if (has_return) {
auto* entry_bb = caller->GetEntry();
auto& entry_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
entry_bb->GetInstructions());
auto alloca = std::make_unique<AllocaInst>(call->GetType(), "__ret.inl");
alloca->SetParent(entry_bb);
ret_alloca = static_cast<AllocaInst*>(alloca.get());
size_t alloca_insert_pos = 0;
for (size_t i = 0; i < entry_insts.size(); ++i) {
if (entry_insts[i]->GetOpcode() == Opcode::Alloca) {
alloca_insert_pos = i + 1;
} else {
break;
}
}
entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca));
}
// 5. Clone all instructions from callee blocks into cloned blocks
// Pass 1: Create cloned instructions with original operands, build value_map
std::vector<std::unique_ptr<Instruction>> alloca_insts;
std::vector<std::pair<Instruction*, Instruction*>> remap_list;
for (auto& bb : callee->GetBlocks()) {
BasicBlock* cloned_bb = bb_map[bb.get()];
auto& cloned_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
cloned_bb->GetInstructions());
for (auto& inst : bb->GetInstructions()) {
if (inst->GetOpcode() == Opcode::Alloca) {
std::vector<std::unique_ptr<Instruction>> tmp;
CloneInstruction(inst.get(), value_map, tmp);
if (!tmp.empty()) {
value_map[inst.get()] = tmp.back().get();
alloca_insts.push_back(std::move(tmp.back()));
}
continue;
}
if (inst->GetOpcode() == Opcode::Phi) {
auto* phi = static_cast<PhiInst*>(inst.get());
auto new_phi = std::make_unique<PhiInst>(phi->GetType(), phi->GetName() + ".inl");
new_phi->SetParent(cloned_bb);
value_map[inst.get()] = new_phi.get();
cloned_insts.push_back(std::move(new_phi));
continue;
}
if (inst->IsTerminator()) continue;
std::vector<std::unique_ptr<Instruction>> tmp;
CloneInstruction(inst.get(), value_map, tmp);
if (!tmp.empty()) {
tmp.back()->SetParent(cloned_bb);
value_map[inst.get()] = tmp.back().get();
remap_list.push_back({inst.get(), tmp.back().get()});
cloned_insts.push_back(std::move(tmp.back()));
}
}
}
// Pass 1b: Remap operands of cloned instructions now that value_map is complete
for (auto& [orig, cloned] : remap_list) {
for (size_t i = 0; i < orig->GetNumOperands(); ++i) {
Value* orig_op = orig->GetOperand(i);
Value* mapped = MapValue(orig_op, value_map);
if (mapped != orig_op) {
cloned->SetOperand(i, mapped);
}
}
}
// Pass 2: fill phi operands and handle terminators
for (auto& bb : callee->GetBlocks()) {
BasicBlock* cloned_bb = bb_map[bb.get()];
auto& cloned_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
cloned_bb->GetInstructions());
for (auto& inst : bb->GetInstructions()) {
if (inst->GetOpcode() == Opcode::Phi) {
auto* orig_phi = static_cast<PhiInst*>(inst.get());
auto* cloned_phi = static_cast<PhiInst*>(value_map[orig_phi]);
if (!cloned_phi) continue;
for (size_t i = 0; i < orig_phi->GetNumOperands(); i += 2) {
Value* val = MapValue(orig_phi->GetOperand(i), value_map);
auto* orig_pred = static_cast<BasicBlock*>(orig_phi->GetOperand(i + 1));
auto pred_it = bb_map.find(orig_pred);
BasicBlock* pred = (pred_it != bb_map.end()) ? pred_it->second : orig_pred;
cloned_phi->AddOperand(val);
cloned_phi->AddOperand(pred);
}
continue;
}
if (inst->GetOpcode() == Opcode::Ret) {
auto* ret_inst = static_cast<ReturnInst*>(inst.get());
if (ret_inst->HasValue() && has_return) {
Value* ret_val = MapValue(ret_inst->GetValue(), value_map);
auto store = std::make_unique<StoreInst>(
Type::GetVoidType(), ret_val, ret_alloca);
store->SetParent(cloned_bb);
cloned_insts.push_back(std::move(store));
}
auto br = std::make_unique<BranchInst>(Type::GetVoidType(), after_bb);
br->SetParent(cloned_bb);
cloned_insts.push_back(std::move(br));
continue;
}
if (inst->GetOpcode() == Opcode::Br) {
auto* br = static_cast<BranchInst*>(inst.get());
auto it = bb_map.find(br->GetTarget());
BasicBlock* target = (it != bb_map.end()) ? it->second : br->GetTarget();
auto new_br = std::make_unique<BranchInst>(Type::GetVoidType(), target);
new_br->SetParent(cloned_bb);
cloned_insts.push_back(std::move(new_br));
continue;
}
if (inst->GetOpcode() == Opcode::CondBr) {
auto* cbr = static_cast<CondBranchInst*>(inst.get());
Value* cond = MapValue(cbr->GetCond(), value_map);
auto true_it = bb_map.find(cbr->GetTrueTarget());
BasicBlock* true_target = (true_it != bb_map.end()) ? true_it->second : cbr->GetTrueTarget();
auto false_it = bb_map.find(cbr->GetFalseTarget());
BasicBlock* false_target = (false_it != bb_map.end()) ? false_it->second : cbr->GetFalseTarget();
auto new_cbr = std::make_unique<CondBranchInst>(
Type::GetVoidType(), cond, true_target, false_target);
new_cbr->SetParent(cloned_bb);
cloned_insts.push_back(std::move(new_cbr));
continue;
}
}
}
// 7. Insert alloca_insts into caller entry
{
auto* entry_bb = caller->GetEntry();
auto& entry_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
entry_bb->GetInstructions());
size_t alloca_insert_pos = 0;
for (size_t i = 0; i < entry_insts.size(); ++i) {
if (entry_insts[i]->GetOpcode() == Opcode::Alloca) {
alloca_insert_pos = i + 1;
} else {
break;
}
}
for (auto& alloca : alloca_insts) {
alloca->SetParent(entry_bb);
entry_insts.insert(entry_insts.begin() + alloca_insert_pos, std::move(alloca));
alloca_insert_pos++;
}
}
// 8-9. Handle return value and remove call
auto call_type = call->GetType();
if (has_return) {
auto load_ret = std::make_unique<LoadInst>(
call_type, ret_alloca, "__ret.load.inl");
load_ret->SetParent(after_bb);
Value* ret_val = load_ret.get();
after_bb_insts.insert(after_bb_insts.begin(), std::move(load_ret));
call->ReplaceAllUsesWith(ret_val);
} else {
call->ReplaceAllUsesWith(module->GetContext().GetConstInt(0));
}
// Remove the call and add branch to cloned_entry
for (size_t i = 0; i < call_bb_insts.size(); ++i) {
if (call_bb_insts[i].get() == call) {
for (size_t oi = 0; oi < call->GetNumOperands(); ++oi) {
auto* op = call->GetOperand(oi);
if (auto* op_inst = dynamic_cast<Instruction*>(op)) {
op_inst->RemoveUse(call, oi);
}
}
call_bb_insts.erase(call_bb_insts.begin() + i);
break;
}
}
auto br_to_entry = std::make_unique<BranchInst>(Type::GetVoidType(), cloned_entry);
br_to_entry->SetParent(call_bb);
call_bb_insts.push_back(std::move(br_to_entry));
if (kDebugInline) {
std::cerr << "[Inline] Done inlining " << callee->GetName() << std::endl;
}
return true;
}
} // namespace
void RunInline(Module* module) {
if (!module) return;
std::unordered_map<std::string, int> func_sizes;
std::unordered_set<std::string> recursive_funcs;
for (auto& func : module->GetFunctions()) {
if (func->IsExternal()) continue;
func_sizes[func->GetName()] = CountInstructions(func.get());
if (IsRecursive(func.get())) {
recursive_funcs.insert(func->GetName());
}
}
struct InlineSite {
CallInst* call;
Function* caller;
BasicBlock* call_bb;
};
std::vector<InlineSite> inline_sites;
for (auto& caller : module->GetFunctions()) {
if (caller->IsExternal()) continue;
for (auto& bb : caller->GetBlocks()) {
for (auto& inst : bb->GetInstructions()) {
auto* call = dynamic_cast<CallInst*>(inst.get());
if (!call) continue;
auto* callee = call->GetCallee();
if (!callee) continue;
if (callee->IsExternal()) continue;
if (recursive_funcs.count(callee->GetName())) continue;
auto size_it = func_sizes.find(callee->GetName());
int callee_size = (size_it != func_sizes.end()) ? size_it->second : 9999;
if (callee_size > kMaxInlineSize) continue;
if (callee == caller.get()) continue;
if (callee->GetBlocks().size() > 1 && callee_size > kMaxMultiBlockInlineSize) continue;
inline_sites.push_back({call, caller.get(), bb.get()});
}
}
}
for (auto& site : inline_sites) {
auto* callee = site.call->GetCallee();
if (!callee) continue;
bool still_valid = false;
BasicBlock* actual_bb = nullptr;
for (auto& bb : site.caller->GetBlocks()) {
for (auto& inst : bb->GetInstructions()) {
if (inst.get() == site.call) {
still_valid = true;
actual_bb = bb.get();
break;
}
}
if (still_valid) break;
}
if (!still_valid) continue;
InlineCall(site.call, callee, site.caller, actual_bb, module);
}
}
} // namespace ir

@ -738,7 +738,7 @@ void RunMem2Reg(Module& module) {
// PHI 节点在 llc -O0 下会生成 StoreStack 操作,可能导致性能下降
// 阈值设置:基本块数量的 1/4最小 10最大 30
int block_count = func->GetBlocks().size();
int phi_threshold = std::max(50, block_count);
int phi_threshold = std::max(2000, block_count*20);
if (total_phi_count > phi_threshold) {
if (kDebugMem2Reg) {
std::cerr << "[Mem2Reg] Skipping function " << func->GetName()

Loading…
Cancel
Save