You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/mir/passes/Peephole.cpp

384 lines
13 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "mir/MIR.h"
#include <vector>
#include "utils/Log.h"
namespace mir
{
namespace
{
// Returns the logical negation of a condition code: EQ<->NE, LT<->GE and
// LE<->GT. Any value outside those six is handed back unchanged, so callers
// that depend on a real inversion must compare the result with the input.
static CondCode InvertCondCode(CondCode cc)
{
    CondCode inverted = cc;
    switch (cc)
    {
    case CondCode::EQ:
        inverted = CondCode::NE;
        break;
    case CondCode::NE:
        inverted = CondCode::EQ;
        break;
    case CondCode::LT:
        inverted = CondCode::GE;
        break;
    case CondCode::GE:
        inverted = CondCode::LT;
        break;
    case CondCode::LE:
        inverted = CondCode::GT;
        break;
    case CondCode::GT:
        inverted = CondCode::LE;
        break;
    }
    return inverted;
}
// Reports whether `a` and `b` name the same register, treating Wn and Xn
// (n in 0..30) as the same register.
//
// Identical enumerators always compare equal. Otherwise the two values match
// only when one lies in [W0, W30], the other in [X0, X30], and both sit at the
// same offset from the start of their range.
// NOTE(review): this assumes W0..W30 and X0..X30 are each declared as a
// contiguous, ascending run in PhysReg — confirm against the enum definition.
static bool IsSamePhysReg(PhysReg a, PhysReg b)
{
    const int a_id = static_cast<int>(a);
    const int b_id = static_cast<int>(b);
    if (a_id == b_id)
        return true;

    const int w_first = static_cast<int>(PhysReg::W0);
    const int w_last = static_cast<int>(PhysReg::W30);
    const int x_first = static_cast<int>(PhysReg::X0);
    const int x_last = static_cast<int>(PhysReg::X30);

    const bool a_is_w = a_id >= w_first && a_id <= w_last;
    const bool a_is_x = a_id >= x_first && a_id <= x_last;
    const bool b_is_w = b_id >= w_first && b_id <= w_last;
    const bool b_is_x = b_id >= x_first && b_id <= x_last;

    if (a_is_w && b_is_x)
        return (a_id - w_first) == (b_id - x_first);
    if (a_is_x && b_is_w)
        return (a_id - x_first) == (b_id - w_first);
    return false;
}
// True when `inst` is a MovReg whose first two operands are both registers
// and refer to exactly the same register (compared with operator==).
static bool IsRedundantMovReg(const MachineInstr &inst)
{
    if (inst.GetOpcode() == Opcode::MovReg)
    {
        const auto &operands = inst.GetOperands();
        if (operands.size() >= 2)
        {
            const auto &dst = operands[0];
            const auto &src = operands[1];
            return dst.GetKind() == Operand::Kind::Reg &&
                   src.GetKind() == Operand::Kind::Reg &&
                   dst.GetReg() == src.GetReg();
        }
    }
    return false;
}
// True when `inst` is an AddRR or SubRR with at least three operands whose
// third operand is the immediate 0. The destination and first source are not
// inspected here; the caller checks that they are the same register.
// NOTE(review): the opcodes are the register-register forms, yet the test
// looks for an immediate third operand — confirm that such instructions are
// actually produced.
static bool IsIdentityAddSub(const MachineInstr &inst)
{
    const auto opcode = inst.GetOpcode();
    const bool is_add_or_sub = opcode == Opcode::AddRR || opcode == Opcode::SubRR;
    if (!is_add_or_sub)
        return false;

    const auto &operands = inst.GetOperands();
    if (operands.size() < 3)
        return false;

    const auto &rhs = operands[2];
    return rhs.GetKind() == Operand::Kind::Imm && rhs.GetImm() == 0;
}
// True when `store` is a StoreStack and `load` is a LoadStack that address
// the same frame index (operand 1 of each). The registers involved are not
// compared here; the caller decides whether the reload is really redundant.
static bool IsRedundantStoreLoad(const MachineInstr &store,
                                 const MachineInstr &load)
{
    const bool is_store_then_load = store.GetOpcode() == Opcode::StoreStack &&
                                    load.GetOpcode() == Opcode::LoadStack;
    if (!is_store_then_load)
        return false;

    const auto &store_ops = store.GetOperands();
    const auto &load_ops = load.GetOperands();
    if (store_ops.size() < 2 || load_ops.size() < 2)
        return false;

    const auto &store_slot = store_ops[1];
    const auto &load_slot = load_ops[1];
    const bool both_frame_slots =
        store_slot.GetKind() == Operand::Kind::FrameIndex &&
        load_slot.GetKind() == Operand::Kind::FrameIndex;
    return both_frame_slots &&
           store_slot.GetFrameIndex() == load_slot.GetFrameIndex();
}
// True when `store` is a StoreStack and `load` is a LoadStack of the same
// frame index, both with a register as operand 0, and those two registers
// differ. In that case the loaded value can be taken from the stored register.
static bool IsForwardableStoreLoad(const MachineInstr &store,
                                   const MachineInstr &load)
{
    if (store.GetOpcode() == Opcode::StoreStack &&
        load.GetOpcode() == Opcode::LoadStack)
    {
        const auto &store_ops = store.GetOperands();
        const auto &load_ops = load.GetOperands();
        if (store_ops.size() >= 2 && load_ops.size() >= 2)
        {
            const bool regs_present =
                store_ops[0].GetKind() == Operand::Kind::Reg &&
                load_ops[0].GetKind() == Operand::Kind::Reg;
            const bool slots_present =
                store_ops[1].GetKind() == Operand::Kind::FrameIndex &&
                load_ops[1].GetKind() == Operand::Kind::FrameIndex;
            if (regs_present && slots_present)
            {
                const bool same_slot =
                    store_ops[1].GetFrameIndex() == load_ops[1].GetFrameIndex();
                return same_slot &&
                       store_ops[0].GetReg() != load_ops[0].GetReg();
            }
        }
    }
    return false;
}
// Global variables: a StoreGlobal immediately followed by a LoadGlobal of the
// same symbol, so the LoadGlobal can be replaced by a MovReg.
// True when `a` is a StoreGlobal, `b` is a LoadGlobal, and operand 1 of both
// is the same symbol. Operand 0 of either instruction is not inspected.
static bool IsGlobalFwdStoreLoad(const MachineInstr &a, const MachineInstr &b)
{
    const bool is_store_then_load = a.GetOpcode() == Opcode::StoreGlobal &&
                                    b.GetOpcode() == Opcode::LoadGlobal;
    if (!is_store_then_load)
        return false;

    const auto &store_ops = a.GetOperands();
    const auto &load_ops = b.GetOperands();
    const bool enough_operands = store_ops.size() >= 2 && load_ops.size() >= 2;
    if (!enough_operands)
        return false;

    const auto &stored_sym = store_ops[1];
    const auto &loaded_sym = load_ops[1];
    if (stored_sym.GetKind() == Operand::Kind::Symbol &&
        loaded_sym.GetKind() == Operand::Kind::Symbol)
    {
        return stored_sym.GetSymbol() == loaded_sym.GetSymbol();
    }
    return false;
}
// Global variables: a LoadGlobal immediately followed by another LoadGlobal
// of the same symbol, so the second one can be replaced by a MovReg.
// True when both `a` and `b` are LoadGlobal and operand 1 of both is the same
// symbol. Operand 0 of either instruction is not inspected.
static bool IsGlobalRedundantLoad(const MachineInstr &a, const MachineInstr &b)
{
    const bool both_loads = a.GetOpcode() == Opcode::LoadGlobal &&
                            b.GetOpcode() == Opcode::LoadGlobal;
    if (!both_loads)
        return false;

    const auto &first_ops = a.GetOperands();
    const auto &second_ops = b.GetOperands();
    const bool enough_operands = first_ops.size() >= 2 && second_ops.size() >= 2;
    if (!enough_operands)
        return false;

    const auto &first_sym = first_ops[1];
    const auto &second_sym = second_ops[1];
    if (first_sym.GetKind() == Operand::Kind::Symbol &&
        second_sym.GetKind() == Operand::Kind::Symbol)
    {
        return first_sym.GetSymbol() == second_sym.GetSymbol();
    }
    return false;
}
// Tries to fuse two StoreStack instructions that both store WZR, where the
// second targets the frame index immediately after the first. On success
// `first` is overwritten with a single StoreStack of XZR to the first frame
// index and true is returned; `second` is never modified, and removing it is
// left to the caller. Returns false, changing nothing, if the pattern does
// not match.
// NOTE(review): this presumes frame indices fi and fi + 1 denote adjacent
// 32-bit slots that one 64-bit store can cover — confirm against the frame
// layout code.
static bool TryMergeZeroStores(MachineInstr &first, MachineInstr &second)
{
    const bool both_stack_stores = first.GetOpcode() == Opcode::StoreStack &&
                                   second.GetOpcode() == Opcode::StoreStack;
    if (!both_stack_stores)
        return false;

    const auto &lo_ops = first.GetOperands();
    const auto &hi_ops = second.GetOperands();
    if (lo_ops.size() < 2 || hi_ops.size() < 2)
        return false;

    const bool values_are_regs = lo_ops[0].GetKind() == Operand::Kind::Reg &&
                                 hi_ops[0].GetKind() == Operand::Kind::Reg;
    if (!values_are_regs)
        return false;

    const bool both_zero = lo_ops[0].GetReg() == PhysReg::WZR &&
                           hi_ops[0].GetReg() == PhysReg::WZR;
    if (!both_zero)
        return false;

    const bool targets_are_slots =
        lo_ops[1].GetKind() == Operand::Kind::FrameIndex &&
        hi_ops[1].GetKind() == Operand::Kind::FrameIndex;
    if (!targets_are_slots)
        return false;

    // Copy the index out before `first` (which owns lo_ops) is overwritten.
    const int lo_slot = lo_ops[1].GetFrameIndex();
    const int hi_slot = hi_ops[1].GetFrameIndex();
    if (hi_slot != lo_slot + 1)
        return false;

    first = MachineInstr(Opcode::StoreStack,
                         {Operand::Reg(PhysReg::XZR),
                          Operand::FrameIndex(lo_slot)});
    return true;
}
static void RunPeepholeOnBlock(MachineBasicBlock &block,
const MachineFunction &function)
{
auto &insts = block.GetInstructions();
bool changed = true;
while (changed)
{
changed = false;
for (auto it = insts.begin(); it != insts.end(); ++it)
{
if (IsRedundantMovReg(*it))
{
it = insts.erase(it);
changed = true;
break;
}
if (IsIdentityAddSub(*it))
{
const auto &ops = it->GetOperands();
if (ops[0].GetKind() == Operand::Kind::Reg &&
ops[1].GetKind() == Operand::Kind::Reg &&
ops[0].GetReg() == ops[1].GetReg())
{
it = insts.erase(it);
changed = true;
break;
}
}
}
if (!changed)
{
for (auto it = insts.begin(); it != insts.end(); ++it)
{
auto next = std::next(it);
if (next != insts.end() && TryMergeZeroStores(*it, *next))
{
next = insts.erase(next);
changed = true;
break;
}
}
}
if (!changed)
{
for (auto it = insts.begin(); it != insts.end(); ++it)
{
if (it->GetOpcode() == Opcode::StoreStack)
{
auto next = std::next(it);
if (next != insts.end() && IsForwardableStoreLoad(*it, *next))
{
const auto &s_ops = it->GetOperands();
const auto &l_ops = next->GetOperands();
*next = MachineInstr(Opcode::MovReg,
{Operand::Reg(l_ops[0].GetReg()),
Operand::Reg(s_ops[0].GetReg())});
changed = true;
break;
}
}
}
}
if (!changed)
{
for (auto it = insts.begin(); it != insts.end(); ++it)
{
if (it->GetOpcode() == Opcode::StoreStack)
{
auto next = std::next(it);
if (next != insts.end() && IsRedundantStoreLoad(*it, *next))
{
const auto &s_ops = it->GetOperands();
const auto &l_ops = next->GetOperands();
if (s_ops[0].GetKind() == Operand::Kind::Reg &&
l_ops[0].GetKind() == Operand::Kind::Reg &&
s_ops[0].GetReg() == l_ops[0].GetReg())
{
next = insts.erase(next);
changed = true;
break;
}
}
}
}
}
}
// 全局变量 StoreGlobal → LoadGlobal 同一符号转发
if (!changed)
{
for (auto it = insts.begin(); it != insts.end(); ++it)
{
if (it->GetOpcode() == Opcode::StoreGlobal)
{
auto next = std::next(it);
if (next != insts.end() && IsGlobalFwdStoreLoad(*it, *next))
{
const auto &s_ops = it->GetOperands();
const auto &l_ops = next->GetOperands();
// 若已是同一寄存器则直接删除 load否则用 MovReg 替代
if (s_ops[0].GetKind() == l_ops[0].GetKind() &&
s_ops[0].GetKind() == Operand::Kind::Reg &&
s_ops[0].GetReg() == l_ops[0].GetReg())
next = insts.erase(next);
else
*next = MachineInstr(Opcode::MovReg, {l_ops[0], s_ops[0]});
changed = true;
break;
}
}
}
}
// 全局变量 LoadGlobal → LoadGlobal 同一符号消除
if (!changed)
{
for (auto it = insts.begin(); it != insts.end(); ++it)
{
if (it->GetOpcode() == Opcode::LoadGlobal)
{
auto next = std::next(it);
if (next != insts.end() && IsGlobalRedundantLoad(*it, *next))
{
const auto &first_ops = it->GetOperands();
const auto &second_ops = next->GetOperands();
if (first_ops[0].GetKind() == second_ops[0].GetKind() &&
first_ops[0].GetKind() == Operand::Kind::Reg &&
first_ops[0].GetReg() == second_ops[0].GetReg())
next = insts.erase(next);
else
*next = MachineInstr(Opcode::MovReg, {second_ops[0], first_ops[0]});
changed = true;
break;
}
}
}
}
// 分支 fallthrough: 末尾 Br 的目标是紧邻下一个块 → 删除 Br
// CondBr + Br 模式CondBr 条件反转使 fallthrough 对齐
if (!insts.empty())
{
const auto &blocks = function.GetBlocks();
int my_idx = -1;
for (size_t bi = 0; bi < blocks.size(); ++bi)
{
if (blocks[bi].get() == &block) { my_idx = static_cast<int>(bi); break; }
}
int next_label = (my_idx >= 0 && my_idx + 1 < static_cast<int>(blocks.size()))
? blocks[my_idx + 1]->GetLabelId()
: -1;
if (next_label >= 0)
{
// CondBr + Br 模式
if (insts.size() >= 2)
{
auto br_it = insts.end() - 1;
auto cond_it = insts.end() - 2;
if (br_it->GetOpcode() == Opcode::Br &&
cond_it->GetOpcode() == Opcode::CondBr &&
br_it->GetOperands().size() >= 1 &&
br_it->GetOperands()[0].GetKind() == Operand::Kind::Label &&
cond_it->GetOperands().size() >= 2 &&
cond_it->GetOperands()[1].GetKind() == Operand::Kind::Label)
{
int cond_target = cond_it->GetOperands()[1].GetLabel();
int br_target = br_it->GetOperands()[0].GetLabel();
if (cond_target == next_label)
{
// CondBr 目标已是 fallthrough → 反转条件,交换目标
CondCode old_cc = static_cast<CondCode>(cond_it->GetOperands()[0].GetImm());
CondCode new_cc = InvertCondCode(old_cc);
const_cast<Operand &>(cond_it->GetOperands()[0]) = Operand::Imm(static_cast<int>(new_cc));
const_cast<Operand &>(cond_it->GetOperands()[1]) =
Operand::Label(br_target);
insts.pop_back(); // 删除 Br
}
else if (br_target == next_label)
{
// Br 目标已是 fallthrough → 直接删除 Br
insts.pop_back();
}
}
}
}
}
}
} // namespace
// Applies the block-level peephole rewrites to every non-null block of
// `function`, in the order the blocks are stored.
void RunPeephole(MachineFunction &function)
{
    for (auto &bb : function.GetBlocks())
    {
        if (!bb)
            continue;
        RunPeepholeOnBlock(*bb, function);
    }
}
// Applies the peephole pass to every non-null function of `module`.
void RunPeephole(MachineModule &module)
{
    for (auto &fn : module.GetFunctions())
    {
        if (!fn)
            continue;
        RunPeephole(*fn);
    }
}
} // namespace mir