From 39e4dada13f43cd739a4bd0895d3c86c02e9a588 Mon Sep 17 00:00:00 2001 From: lzkk <956449176@qq.com> Date: Thu, 28 May 2026 02:55:51 +0800 Subject: [PATCH] =?UTF-8?q?perf(mir):=20=E6=B7=BB=E5=8A=A0=20Madd=EF=BC=88?= =?UTF-8?q?=E4=B9=98=E5=8A=A0=EF=BC=89=E6=8C=87=E4=BB=A4=E2=80=94=E2=80=94?= =?UTF-8?q?sum+a*b=20=E2=86=92=20madd=20sum,a,b,sum?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 限制:仅折合非常量乘数(常量有 shift 优化更优)。 SSA 单使用 Mul 被跳过,由 Add 处统一发射 Madd。 --- src/include/mir/MIR.h | 1 + src/mir/AsmPrinter.cpp | 5 ++++- src/mir/Lowering.cpp | 42 +++++++++++++++++++++++++++++++++++++++++- src/mir/RegAlloc.cpp | 1 + 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/include/mir/MIR.h b/src/include/mir/MIR.h index 64c949cf..47afad02 100644 --- a/src/include/mir/MIR.h +++ b/src/include/mir/MIR.h @@ -181,6 +181,7 @@ namespace mir Ret, LoadAddr, MovReg, + Madd, }; enum class CondCode diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 27568c8b..6d228441 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -59,6 +59,8 @@ namespace mir return "fdiv"; case Opcode::ModRR: return "msub"; + case Opcode::Madd: + return "madd"; case Opcode::AndRR: return "and"; case Opcode::OrRR: @@ -713,9 +715,10 @@ namespace mir return; case Opcode::Msub: + case Opcode::Madd: if (operands.size() >= 4) { - os << " msub "; + os << " " << (instr.GetOpcode() == Opcode::Madd ?
"madd" : "msub") << " "; PrintOperand(operands[0], os); os << ", "; PrintOperand(operands[1], os); diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 03dfb230..39e4924e 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -408,11 +408,51 @@ namespace mir break; } + int dst = function.CreateVReg(VRegClass::Int); + + // Madd 折叠:sum + (a * b) → madd sum, a, b, sum(必须在 EmitIntValue 之前) + if (opcode == Opcode::AddRR) + { + auto *mul_rhs = dynamic_cast<const ir::BinaryInst *>(bin->GetRhs()); + auto *mul_lhs = dynamic_cast<const ir::BinaryInst *>(bin->GetLhs()); + const ir::BinaryInst *mul_op = nullptr; + bool mul_is_rhs = true; + + if (mul_rhs && mul_rhs->GetOpcode() == ir::Opcode::Mul) + mul_op = mul_rhs; + else if (mul_lhs && mul_lhs->GetOpcode() == ir::Opcode::Mul) + { mul_op = mul_lhs; mul_is_rhs = false; } + + if (mul_op) + { + // 仅当 Mul 两端都不是常量时才折叠(常量有 shift/强度削减更优) + int dummy; + if (!TryGetConstantInt(mul_op->GetRhs(), dummy) && + !TryGetConstantInt(mul_op->GetLhs(), dummy)) + { + int a = EmitIntValue(mul_op->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int b = EmitIntValue(mul_op->GetRhs(), function, value_vregs, + scalar_slots, array_slots, block); + int acc = EmitIntValue(mul_is_rhs ?
bin->GetLhs() : bin->GetRhs(), + function, value_vregs, + scalar_slots, array_slots, block); + block.Append(Opcode::Madd, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(a, VRegClass::Int), + Operand::VReg(b, VRegClass::Int), + Operand::VReg(acc, VRegClass::Int)}); + value_vregs[value] = dst; + value_vregs[mul_op] = dst; + return dst; + } + } + } + int lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, scalar_slots, array_slots, block); int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs, scalar_slots, array_slots, block); - int dst = function.CreateVReg(VRegClass::Int); if (opcode == Opcode::MulRR) { diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index b91e76ee..17d5374f 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -243,6 +243,7 @@ namespace mir break; case Opcode::Msub: + case Opcode::Madd: if (ops.size() >= 4) { if (ops[0].GetKind() == Operand::Kind::VReg)