diff --git a/src/include/mir/MIR.h b/src/include/mir/MIR.h index 64ff5c4d..75252cb4 100644 --- a/src/include/mir/MIR.h +++ b/src/include/mir/MIR.h @@ -164,6 +164,7 @@ namespace mir FCmpRR, CSet, Csel, + Csneg, Smull, Msub, NegRR, diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 41479843..dd4ba798 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -77,6 +77,8 @@ namespace mir return "fcmp"; case Opcode::CSet: return "cset"; + case Opcode::Csneg: + return "csneg"; case Opcode::Scvtf: return "scvtf"; case Opcode::FCvtzs: @@ -657,6 +659,19 @@ namespace mir } return; + case Opcode::Csneg: + if (operands.size() >= 4) + { + os << " csneg "; + PrintOperand(operands[0], os); + os << ", "; + PrintOperand(operands[1], os); + os << ", "; + PrintOperand(operands[2], os); + os << ", " << CondCodeToAsm(static_cast(operands[3].GetImm())) << "\n"; + } + return; + case Opcode::Smull: if (operands.size() >= 3) { @@ -771,6 +786,47 @@ namespace mir for (const auto &global : module.GetGlobals()) { const std::string asm_name = NormalizeAsmSymbol(global.name); + + bool is_zero_init = false; + if (global.kind == MachineGlobal::Kind::I32Scalar && global.init_value == 0) + { + is_zero_init = true; + } + if (global.kind == MachineGlobal::Kind::I32Array) + { + bool all_zero = true; + for (auto v : global.init_values) + { + if (v != 0) + { + all_zero = false; + break; + } + } + if (all_zero) + { + is_zero_init = true; + } + } + + if (is_zero_init) + { + os << " .bss\n"; + os << " .globl " << asm_name << "\n"; + os << " .p2align 2\n"; + os << asm_name << ":\n"; + if (global.kind == MachineGlobal::Kind::I32Scalar) + { + os << " .space 4\n"; + } + else + { + os << " .space " << (global.array_size * 4) << "\n"; + } + os << " .data\n"; + continue; + } + os << " .globl " << asm_name << "\n"; os << " .p2align 2\n"; os << asm_name << ":\n"; diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index cbbfaaf5..c574751b 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -437,6 +437,25 @@ namespace mir tmp >>= 1; ++shift; } + if (val == 2) + { + int sign_bit = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::ShrRR, + {Operand::VReg(sign_bit, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(31)}); + int biased = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::AddRR, + {Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::VReg(sign_bit, VRegClass::Int)}); + block.Append(Opcode::AsrRR, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(biased, VRegClass::Int), + Operand::Imm(shift)}); + value_vregs[value] = dst; + return dst; + } int bias = (1 << shift) - 1; int biased = function.CreateVReg(VRegClass::Int); if (bias <= 4095) @@ -483,6 +502,29 @@ namespace mir tmp >>= 1; ++shift; } + if (abs_val == 2) + { + int sign_bit = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::ShrRR, + {Operand::VReg(sign_bit, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(31)}); + int biased = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::AddRR, + {Operand::VReg(biased, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::VReg(sign_bit, VRegClass::Int)}); + int pos_q = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::AsrRR, + {Operand::VReg(pos_q, VRegClass::Int), + Operand::VReg(biased, VRegClass::Int), + Operand::Imm(shift)}); + block.Append(Opcode::NegRR, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(pos_q, VRegClass::Int)}); + value_vregs[value] = dst; + return dst; + } int bias = (1 << shift) - 1; int biased = function.CreateVReg(VRegClass::Int); if (bias <= 4095) @@ -534,6 +576,24 @@ namespace mir int val = rhs_const->GetValue(); if (val > 0 && (val & (val - 1)) == 0) { + if (val == 2) + { + int tmp = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::AndRR, + {Operand::VReg(tmp, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(1)}); + block.Append(Opcode::CmpImm, + {Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(0)}); + block.Append(Opcode::Csneg, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(tmp, VRegClass::Int), + Operand::VReg(tmp, VRegClass::Int), + Operand::Imm(static_cast(CondCode::GE))}); + value_vregs[value] = dst; + return dst; + } int bias = val - 1; int biased = function.CreateVReg(VRegClass::Int); if (bias <= 4095) @@ -590,6 +650,24 @@ namespace mir if (val < 0 && (-val & (-val - 1)) == 0 && val != -1) { int abs_val = -val; + if (abs_val == 2) + { + int tmp = function.CreateVReg(VRegClass::Int); + block.Append(Opcode::AndRR, + {Operand::VReg(tmp, VRegClass::Int), + Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(1)}); + block.Append(Opcode::CmpImm, + {Operand::VReg(lhs, VRegClass::Int), + Operand::Imm(0)}); + block.Append(Opcode::Csneg, + {Operand::VReg(dst, VRegClass::Int), + Operand::VReg(tmp, VRegClass::Int), + Operand::VReg(tmp, VRegClass::Int), + Operand::Imm(static_cast(CondCode::GE))}); + value_vregs[value] = dst; + return dst; + } int bias = abs_val - 1; int biased = function.CreateVReg(VRegClass::Int); if (bias <= 4095) diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index 5fc46ce9..8a99870d 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -212,6 +212,18 @@ namespace mir } break; + case Opcode::Csneg: + if (ops.size() >= 3) + { + if (ops[0].GetKind() == Operand::Kind::VReg) + result.defs.push_back(ops[0].GetVRegId()); + if (ops[1].GetKind() == Operand::Kind::VReg) + result.uses.push_back(ops[1].GetVRegId()); + if (ops[2].GetKind() == Operand::Kind::VReg) + result.uses.push_back(ops[2].GetVRegId()); + } + break; + case Opcode::Smull: if (ops.size() >= 3) { diff --git a/src/mir/passes/Peephole.cpp b/src/mir/passes/Peephole.cpp index e59ea22f..1e365fe0 100644 --- a/src/mir/passes/Peephole.cpp +++ b/src/mir/passes/Peephole.cpp @@ -9,27 +9,29 @@ namespace mir namespace { - static bool IsSamePhysReg(PhysReg a, PhysReg b) + static bool IsWReg(PhysReg r) { - int an = static_cast(a); - int bn = static_cast(b); + return static_cast(r) >= static_cast(PhysReg::W0) && + static_cast(r) <= static_cast(PhysReg::W30); + } + + static bool IsXReg(PhysReg r) + { + return static_cast(r) >= static_cast(PhysReg::X0) && + static_cast(r) <= static_cast(PhysReg::X30); + } - if (an == bn) + static bool IsSamePhysRegOrWXPair(PhysReg a, PhysReg b) + { + if (static_cast(a) == static_cast(b)) return true; int aw = static_cast(PhysReg::W0); int ax = static_cast(PhysReg::X0); - int as = static_cast(PhysReg::S0); - if (an >= aw && an <= static_cast(PhysReg::W30) && - bn >= ax && bn <= static_cast(PhysReg::X30)) - { - return (an - aw) == (bn - ax); - } - if (an >= ax && an <= static_cast(PhysReg::X30) && - bn >= aw && bn <= static_cast(PhysReg::W30)) + if (IsWReg(a) && IsXReg(b)) { - return (an - ax) == (bn - aw); + return (static_cast(a) - aw) == (static_cast(b) - ax); } return false; @@ -44,7 +46,17 @@ namespace mir return false; if (ops[0].GetKind() != Operand::Kind::Reg || ops[1].GetKind() != Operand::Kind::Reg) return false; - return ops[0].GetReg() == ops[1].GetReg(); + PhysReg dst = ops[0].GetReg(); + PhysReg src = ops[1].GetReg(); + if (static_cast(dst) == static_cast(src)) + return true; + if (IsWReg(dst) && IsXReg(src)) + { + int aw = static_cast(PhysReg::W0); + int ax = static_cast(PhysReg::X0); + return (static_cast(dst) - aw) == (static_cast(src) - ax); + } + return false; } static bool IsIdentityAddSub(const MachineInstr &inst) @@ -81,13 +93,13 @@ namespace mir while (changed) { changed = false; - for (auto it = insts.begin(); it != insts.end(); ++it) + for (auto it = insts.begin(); it != insts.end();) { if (IsRedundantMovReg(*it)) { it = insts.erase(it); changed = true; - break; + continue; } if (IsIdentityAddSub(*it)) @@ -95,13 +107,15 @@ namespace mir const auto &ops = it->GetOperands(); if (ops[0].GetKind() == Operand::Kind::Reg && ops[1].GetKind() == Operand::Kind::Reg && - ops[0].GetReg() == ops[1].GetReg()) + IsSamePhysRegOrWXPair(ops[0].GetReg(), ops[1].GetReg())) { it = insts.erase(it); changed = true; - break; + continue; } } + + ++it; } if (!changed) @@ -117,7 +131,7 @@ namespace mir const auto &l_ops = next->GetOperands(); if (s_ops[0].GetKind() == Operand::Kind::Reg && l_ops[0].GetKind() == Operand::Kind::Reg && - s_ops[0].GetReg() == l_ops[0].GetReg()) + IsSamePhysRegOrWXPair(s_ops[0].GetReg(), l_ops[0].GetReg())) { next = insts.erase(next); changed = true;