From 40eb6784d3032e87bcff0bb8e117e3c130293a16 Mon Sep 17 00:00:00 2001 From: lzkk <956449176@qq.com> Date: Thu, 28 May 2026 16:33:26 +0800 Subject: [PATCH] =?UTF-8?q?perf(ir):=20=E6=B7=BB=E5=8A=A0=20And/Or=20?= =?UTF-8?q?=E6=93=8D=E4=BD=9C=E7=A0=81=20+=20if-else=E2=86=92select=20?= =?UTF-8?q?=E8=BD=AC=E6=8D=A2=E6=A1=86=E6=9E=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - IR 层添加 And/Or BinaryInst 支持(IRPrinter/LoopInfo/Lowering 全覆盖) - MIR 层已有 AndRR/OrRR,直接映射 - Inline.cpp 添加 TryConvertIfElseToSelect:将 3-BB if-else-return 函数转为单 BB 算术 select(fv + (tv-fv)*zext(cmp)),使其可被内联 - 转换逻辑已实现,与快门禁兼容(functional+h_functional 0 失败) --- src/include/ir/IR.h | 2 + src/ir/IRPrinter.cpp | 8 +++- src/ir/analysis/LoopInfo.cpp | 2 + src/ir/passes/Inline.cpp | 88 ++++++++++++++++++++++++++++++++---- src/mir/Lowering.cpp | 8 ++++ 5 files changed, 97 insertions(+), 11 deletions(-) diff --git a/src/include/ir/IR.h b/src/include/ir/IR.h index f49697b9..d709a9c6 100644 --- a/src/include/ir/IR.h +++ b/src/include/ir/IR.h @@ -194,6 +194,8 @@ enum class Opcode { Le, Gt, Ge, + And, + Or, Alloca, Load, Store, diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 41804221..b964e6ac 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -64,6 +64,10 @@ static const char* OpcodeToString(Opcode op) { return "icmp sgt"; case Opcode::Ge: return "icmp sge"; + case Opcode::And: + return "and"; + case Opcode::Or: + return "or"; case Opcode::Alloca: return "alloca"; case Opcode::Load: @@ -242,7 +246,9 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { case Opcode::Lt: case Opcode::Le: case Opcode::Gt: - case Opcode::Ge: { + case Opcode::Ge: + case Opcode::And: + case Opcode::Or: { auto* bin = static_cast(inst); const bool is_float = bin->GetLhs()->GetType()->IsFloat32(); os << " " << bin->GetName() << " = " diff --git a/src/ir/analysis/LoopInfo.cpp b/src/ir/analysis/LoopInfo.cpp index b188f062..acd1580e 100644 --- a/src/ir/analysis/LoopInfo.cpp +++ b/src/ir/analysis/LoopInfo.cpp @@ -306,6 +306,8 @@ namespace ir case Opcode::Sub: case Opcode::Mul: case Opcode::Div: + case Opcode::And: + case Opcode::Or: case Opcode::Mod: case Opcode::Eq: case Opcode::Ne: diff --git a/src/ir/passes/Inline.cpp b/src/ir/passes/Inline.cpp index c80fec2f..5458923d 100644 --- a/src/ir/passes/Inline.cpp +++ b/src/ir/passes/Inline.cpp @@ -1,7 +1,7 @@ // 保守函数内联: // - 自底向上迭代内联,每次只内联无调用(leaf)的单基本块函数 +// - 内联前先将 if-else-return 函数转为 mul select + ret(单 BB) // - 内联后消除外层函数的 call,使其可能变为新 leaf,迭代至收敛 -// - 每个函数内反复扫描直到清空所有可内联 call,4 轮即可收敛 #include "ir/IR.h" @@ -26,8 +26,6 @@ bool IsInlineable(Function* func) { for (const auto& inst : insts) { if (inst->GetOpcode() == Opcode::Call) return false; - // 只内联纯算术/逻辑函数,不内联含内存操作的函数 - // Load/Store/GEP 的函数内联可能导致全局变量副作用顺序问题 if (inst->GetOpcode() == Opcode::Load || inst->GetOpcode() == Opcode::Store || inst->GetOpcode() == Opcode::GEP) @@ -57,7 +55,8 @@ std::unique_ptr CloneInst( case Opcode::Add: case Opcode::Sub: case Opcode::Mul: case Opcode::Div: case Opcode::Mod: case Opcode::Eq: case Opcode::Ne: case Opcode::Lt: - case Opcode::Le: case Opcode::Gt: case Opcode::Ge: { + case Opcode::Le: case Opcode::Gt: case Opcode::Ge: + case Opcode::And: case Opcode::Or: { auto* bin = static_cast(inst); return std::make_unique(op, inst->GetType(), map(bin->GetLhs()), @@ -114,13 +113,11 @@ bool InlineCall(CallInst* call, Function* callee, Context& ctx) { if (callee == caller) return false; - // 1. 构建值映射:被调用者参数 -> 调用实参 std::unordered_map value_map; const auto& params = callee->GetParams(); for (size_t i = 0; i < params.size(); ++i) value_map[params[i].get()] = call->GetArg(i); - // 2. 克隆被调用者指令(Ret 除外),用 InsertBefore 插入到 call 之前 Value* ret_val = nullptr; const auto& callee_insts = callee->GetEntry()->GetInstructions(); for (const auto& inst : callee_insts) { @@ -138,17 +135,89 @@ bool InlineCall(CallInst* call, Function* callee, Context& ctx) { bb->InsertBefore(call, std::move(cloned)); } - // 3. 替换 call 的使用并删除 if (ret_val) call->ReplaceAllUsesWith(ret_val); - call->ClearOperands(); // 清理操作数的 use 记录,防止悬空指针 + call->ClearOperands(); bb->TakeInstruction(call); return true; } +// 将 if-else-return 函数转为单 BB mul select +// BB0(cmp+condbr) → BB1(ret tv), BB2(ret fv) +// → BB0: zext+sub+mul+add+ret (即 fv + (tv-fv)*zext(cmp)) +static bool TryConvertIfElseToSelect(Function* func, Context& ctx) { + auto& blocks = func->GetBlocks(); + if (blocks.size() != 3) return false; + + BasicBlock* entry = func->GetEntry(); + auto& entry_insts = entry->GetInstructions(); + if (entry_insts.empty()) return false; + + auto* br = dynamic_cast(entry_insts.back().get()); + if (!br) return false; + auto* true_bb = br->GetTrueTarget(); + auto* false_bb = br->GetFalseTarget(); + + // 检查两个目标 BB 没有副作用(允许 Load 从 entry alloca 加载) + auto get_ret_val = [entry](BasicBlock* bb) -> Value* { + for (const auto& inst : bb->GetInstructions()) { + auto op = inst->GetOpcode(); + if (op == Opcode::Store || op == Opcode::Call || + op == Opcode::GEP || op == Opcode::Alloca) + return nullptr; + if (op == Opcode::Load) { + // 允许 Load(Mem2Reg 跨 BB 的残余),只要不是数组/全局访问 + auto* load = static_cast(inst.get()); + auto* ptr = load->GetPtr(); + if (!dynamic_cast(ptr)) + return nullptr; + // 单标量 alloca 可以接受 + } + if (op == Opcode::Ret) { + auto* ret = static_cast(inst.get()); + return ret->HasValue() ? ret->GetValue() : nullptr; + } + } + return nullptr; + }; + + Value* true_val = get_ret_val(true_bb); + Value* false_val = get_ret_val(false_bb); + if (!true_val || !false_val) return false; + + Value* cmp_val = br->GetCond(); + + // 移除 CondBr + entry->TakeInstruction(br); + + // 算术 select: fv + (tv - fv) * zext(cmp) + // zext(cmp: i1 → i32) = 0 or 1 + // sub tv, fv → tv - fv + // mul (tv-fv), zext → (tv-fv) * cond + // add fv, masked → fv + (tv-fv)*cond = cond ? tv : fv + auto* zext = entry->Append(Opcode::ZExt, + Type::GetInt32Type(), cmp_val, ctx.NextTemp()); + auto* diff = entry->Append(Opcode::Sub, Type::GetInt32Type(), + true_val, false_val, ctx.NextTemp()); + auto* masked = entry->Append(Opcode::Mul, Type::GetInt32Type(), + diff, zext, ctx.NextTemp()); + auto* result = entry->Append(Opcode::Add, Type::GetInt32Type(), + false_val, masked, ctx.NextTemp()); + + // 替换为 ret(true_bb/false_bb 变为不可达,后续 CFG 优化清理) + entry->Append(Type::GetVoidType(), result); + return true; +} + } // namespace void RunInline(Module& module) { + // 先将 if-else-return 函数转为单 BB,使其可被内联 + for (const auto& func : module.GetFunctions()) { + if (!func->IsExternal()) + TryConvertIfElseToSelect(func.get(), module.GetContext()); + } + int inlined = 0; bool changed = true; int round = 0; @@ -164,7 +233,6 @@ void RunInline(Module& module) { } if (inlineable.empty()) break; - // 每个函数内部反复扫描,直到没有可内联的 call 为止 for (const auto& func : module.GetFunctions()) { if (func->IsExternal()) continue; @@ -183,7 +251,7 @@ void RunInline(Module& module) { ++inlined; func_changed = true; changed = true; - break; // 指令列表已修改,重新扫描当前函数 + break; } if (func_changed) break; } diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 471be3f8..13b8aea4 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -538,6 +538,12 @@ namespace mir case ir::Opcode::Mod: opcode = Opcode::ModRR; break; + case ir::Opcode::And: + opcode = Opcode::AndRR; + break; + case ir::Opcode::Or: + opcode = Opcode::OrRR; + break; default: break; } @@ -1838,6 +1844,8 @@ namespace mir case ir::Opcode::Mul: case ir::Opcode::Div: case ir::Opcode::Mod: + case ir::Opcode::And: + case ir::Opcode::Or: { auto &bin = static_cast(inst); if (IsFloatType(bin.GetType()))