perf(mir): 添加 Madd(乘加)指令——sum+a*b → madd sum,a,b,sum

限制:仅折合非常量乘数(常量有 shift 优化更优)。
SSA 单使用 Mul 被跳过,由 Add 处统一发射 Madd。
lzk
lzkk 3 days ago
parent 035f83b209
commit 39e4dada13

@ -181,6 +181,7 @@ namespace mir
Ret,
LoadAddr,
MovReg,
Madd,
};
enum class CondCode

@ -59,6 +59,8 @@ namespace mir
return "fdiv";
case Opcode::ModRR:
return "msub";
case Opcode::Madd:
return "madd";
case Opcode::AndRR:
return "and";
case Opcode::OrRR:
@ -713,9 +715,10 @@ namespace mir
return;
case Opcode::Msub:
case Opcode::Madd:
if (operands.size() >= 4)
{
os << " msub ";
os << " " << (instr.GetOpcode() == Opcode::Madd ? "madd" : "msub") << " ";
PrintOperand(operands[0], os);
os << ", ";
PrintOperand(operands[1], os);

@ -408,11 +408,51 @@ namespace mir
break;
}
int dst = function.CreateVReg(VRegClass::Int);
// Madd 折叠sum + (a * b) → madd sum, a, b, sum必须在 EmitIntValue 之前)
if (opcode == Opcode::AddRR)
{
auto *mul_rhs = dynamic_cast<const ir::BinaryInst *>(bin->GetRhs());
auto *mul_lhs = dynamic_cast<const ir::BinaryInst *>(bin->GetLhs());
const ir::BinaryInst *mul_op = nullptr;
bool mul_is_rhs = true;
if (mul_rhs && mul_rhs->GetOpcode() == ir::Opcode::Mul)
mul_op = mul_rhs;
else if (mul_lhs && mul_lhs->GetOpcode() == ir::Opcode::Mul)
{ mul_op = mul_lhs; mul_is_rhs = false; }
if (mul_op)
{
// 仅当 Mul 两端都不是常量时才折叠(常量有 shift/强度削减更优)
int dummy;
if (!TryGetConstantInt(mul_op->GetRhs(), dummy) &&
!TryGetConstantInt(mul_op->GetLhs(), dummy))
{
int a = EmitIntValue(mul_op->GetLhs(), function, value_vregs,
scalar_slots, array_slots, block);
int b = EmitIntValue(mul_op->GetRhs(), function, value_vregs,
scalar_slots, array_slots, block);
int acc = EmitIntValue(mul_is_rhs ? bin->GetLhs() : bin->GetRhs(),
function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::Madd,
{Operand::VReg(dst, VRegClass::Int),
Operand::VReg(a, VRegClass::Int),
Operand::VReg(b, VRegClass::Int),
Operand::VReg(acc, VRegClass::Int)});
value_vregs[value] = dst;
value_vregs[mul_op] = dst;
return dst;
}
}
}
int lhs = EmitIntValue(bin->GetLhs(), function, value_vregs,
scalar_slots, array_slots, block);
int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs,
scalar_slots, array_slots, block);
int dst = function.CreateVReg(VRegClass::Int);
if (opcode == Opcode::MulRR)
{

@ -243,6 +243,7 @@ namespace mir
break;
case Opcode::Msub:
case Opcode::Madd:
if (ops.size() >= 4)
{
if (ops[0].GetKind() == Operand::Kind::VReg)

Loading…
Cancel
Save