diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 4bcb494d..a6ea0660 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -410,6 +410,52 @@ namespace mir int dst = function.CreateVReg(VRegClass::Int); + // AddShift 折叠:x + (x * 2^n) → add x, x, lsl #n + if (opcode == Opcode::AddRR) + { + auto *mul_rhs2 = dynamic_cast(bin->GetRhs()); + if (mul_rhs2 && mul_rhs2->GetOpcode() == ir::Opcode::Mul) + { + int shift_val = 0; + if (TryGetConstantInt(mul_rhs2->GetRhs(), shift_val) && + shift_val > 0 && (shift_val & (shift_val - 1)) == 0) + { + // rhs is x * 2^n, lhs should match x + int mul_lhs = EmitIntValue(mul_rhs2->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int add_lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int sh = 0; while (shift_val > 1) { shift_val >>= 1; ++sh; } + block.Append(Opcode::AddShiftRR, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(add_lhs, VRegClass::Int), + Operand::VReg(mul_lhs, VRegClass::Int), + Operand::Imm(sh)}); + value_vregs[value] = dst; + value_vregs[mul_rhs2] = dst; + return dst; + } + if (TryGetConstantInt(mul_rhs2->GetLhs(), shift_val) && + shift_val > 0 && (shift_val & (shift_val - 1)) == 0) + { + // rhs is 2^n * x, lhs should match x + int mul_rhs = EmitIntValue(mul_rhs2->GetRhs(), function, value_vregs, + scalar_slots, array_slots, block); + int add_lhs = EmitIntValue(bin->GetLhs(), function, value_vregs, + scalar_slots, array_slots, block); + int sh = 0; while (shift_val > 1) { shift_val >>= 1; ++sh; } + block.Append(Opcode::AddShiftRR, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(add_lhs, VRegClass::Int), + Operand::VReg(mul_rhs, VRegClass::Int), + Operand::Imm(sh)}); + value_vregs[value] = dst; + value_vregs[mul_rhs2] = dst; + return dst; + } + } + } + // Madd 折叠:sum + (a * b) → madd sum, a, b, sum(必须在 EmitIntValue 之前) if (opcode == Opcode::AddRR) {