From 422f1848fc7360d7bcaf32d6a316f1079328731a Mon Sep 17 00:00:00 2001 From: lzkk <956449176@qq.com> Date: Thu, 28 May 2026 14:26:21 +0800 Subject: [PATCH] =?UTF-8?q?perf(mir):=20Lowering=20=E7=9B=B4=E6=8E=A5?= =?UTF-8?q?=E7=94=9F=E6=88=90=20AddShiftRR=E2=80=94=E2=80=94x+(x*2^n)=20?= =?UTF-8?q?=E2=86=92=20add=20x,x,lsl#n?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在 Lowering 阶段检测 Add 的某个操作数是 Mul by 幂次方常量, 直接发射 AddShiftRR 而非分开发射 ShlRR+AddRR。 配合 Peephole 残余合并,共生成 5 处 add lsl。 --- src/mir/Lowering.cpp | 46 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 4bcb494d..a6ea0660 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -410,6 +410,52 @@ namespace mir int dst = function.CreateVReg(VRegClass::Int); + // AddShift 折叠:x + (x * 2^n) → add x, x, lsl #n + if (opcode == Opcode::AddRR) + { + auto *mul_rhs2 = dynamic_cast(bin->GetRhs()); + if (mul_rhs2 && mul_rhs2->GetOpcode() == ir::Opcode::Mul) + { + int shift_val = 0; + if (TryGetConstantInt(mul_rhs2->GetRhs(), shift_val) && + shift_val > 0 && (shift_val & (shift_val - 1)) == 0) + { + // rhs is x * 2^n, lhs should match x + int mul_lhs = EmitIntValue(mul_rhs2->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int add_lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int sh = 0; while (shift_val > 1) { shift_val >>= 1; ++sh; } + block.Append(Opcode::AddShiftRR, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(add_lhs, VRegClass::Int), + Operand::VReg(mul_lhs, VRegClass::Int), + Operand::Imm(sh)}); + value_vregs[value] = dst; + value_vregs[mul_rhs2] = dst; + return dst; + } + if (TryGetConstantInt(mul_rhs2->GetLhs(), shift_val) && + shift_val > 0 && (shift_val & (shift_val - 1)) == 0) + { + // rhs is 2^n * x, lhs should match x + int mul_rhs = EmitIntValue(mul_rhs2->GetRhs(), function, value_vregs, + scalar_slots, array_slots, block); + int add_lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int sh = 0; while (shift_val > 1) { shift_val >>= 1; ++sh; } + block.Append(Opcode::AddShiftRR, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(add_lhs, VRegClass::Int), + Operand::VReg(mul_rhs, VRegClass::Int), + Operand::Imm(sh)}); + value_vregs[value] = dst; + value_vregs[mul_rhs2] = dst; + return dst; + } + } + } + // Madd 折叠:sum + (a * b) → madd sum, a, b, sum(必须在 EmitIntValue 之前) if (opcode == Opcode::AddRR) {