You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nudt-compiler-cpp/src/mir/passes/SpillReduction.cpp

255 lines
7.4 KiB

#include "mir/MIR.h"
#include <unordered_map>
#include <utility>
#include <vector>
namespace mir {
namespace {
// Recipe for recomputing a vreg's value at a different program point.
// `kind` selects which of the remaining fields are meaningful:
//   ImmCopy - recompute with `Copy dst, source`, typed as `type`;
//   Lea     - recompute with `Lea dst, <address>` (type is Ptr);
//   Invalid - default state, no recipe recorded.
struct RematDef {
  enum class Kind { Invalid, ImmCopy, Lea };
  Kind kind{Kind::Invalid};
  ValueType type{ValueType::Void};
  MachineOperand source;  // ImmCopy only: the immediate operand to copy.
  AddressExpr address;    // Lea only: the address expression to recompute.
};
// Decides whether `inst` can be recomputed cheaply elsewhere and, if so,
// fills `def` with the recipe.
//
// Accepted forms (the instruction must define exactly one vreg):
//   * Copy whose source operand (operands[1]) is an immediate;
//   * Lea carrying an address whose base is not a vreg and which has no
//     scaled vreg terms.
// Every other instruction yields false, and `def` is then left untouched.
bool IsCheapRematerializableDef(const MachineInstr& inst, RematDef& def) {
  if (inst.GetDefs().size() != 1) {
    return false;
  }
  const auto opcode = inst.GetOpcode();
  if (opcode == MachineInstr::Opcode::Copy) {
    const auto& ops = inst.GetOperands();
    const bool copies_immediate =
        ops.size() >= 2 && ops[1].GetKind() == OperandKind::Imm;
    if (!copies_immediate) {
      return false;
    }
    def.kind = RematDef::Kind::ImmCopy;
    def.type = inst.GetValueType();
    def.source = ops[1];
    return true;
  }
  const bool is_address_lea =
      opcode == MachineInstr::Opcode::Lea && inst.HasAddress();
  if (!is_address_lea) {
    return false;
  }
  const auto& addr = inst.GetAddress();
  const bool reads_vreg =
      addr.base_kind == AddrBaseKind::VReg || !addr.scaled_vregs.empty();
  if (reads_vreg) {
    return false;
  }
  def.kind = RematDef::Kind::Lea;
  def.type = ValueType::Ptr;
  def.address = addr;
  return true;
}
// Creates a new instruction that recomputes `def` into `dst_vreg`.
//   ImmCopy -> `Copy dst_vreg, def.source`, value type `def.type`;
//   Lea     -> `Lea dst_vreg` with `def.address` attached, value type Ptr.
// Any other kind produces an `Unreachable` placeholder instruction.
MachineInstr BuildRematInstr(int dst_vreg, const RematDef& def) {
  if (def.kind == RematDef::Kind::ImmCopy) {
    MachineInstr copy(MachineInstr::Opcode::Copy,
                      {MachineOperand::VReg(dst_vreg), def.source});
    copy.SetValueType(def.type);
    return copy;
  }
  if (def.kind == RematDef::Kind::Lea) {
    MachineInstr lea(MachineInstr::Opcode::Lea,
                     {MachineOperand::VReg(dst_vreg)});
    lea.SetAddress(def.address);
    lea.SetValueType(ValueType::Ptr);
    return lea;
  }
  return MachineInstr(MachineInstr::Opcode::Unreachable, {});
}
// Replaces a vreg operand with its image under `rename_map`.
// Non-vreg operands, unmapped vregs and identity mappings are left alone.
// Returns true only when `operand` was actually replaced.
bool RewriteMappedOperand(MachineOperand& operand,
                          const std::unordered_map<int, int>& rename_map) {
  if (operand.GetKind() == OperandKind::VReg) {
    const int current = operand.GetVReg();
    const auto mapped = rename_map.find(current);
    if (mapped != rename_map.end() && mapped->second != current) {
      operand = MachineOperand::VReg(mapped->second);
      return true;
    }
  }
  return false;
}
// Renames the vregs referenced by an address expression according to
// `rename_map`. Two kinds of component are renamed:
//   * the base register, but only when the base is a vreg with a
//     non-negative index;
//   * the register of every scaled term.
// Identity mappings are ignored.
//
// Returns true if any component was replaced.
bool RewriteMappedAddress(AddressExpr& address,
                          const std::unordered_map<int, int>& rename_map) {
  bool rewrote_any = false;
  const bool base_is_vreg =
      address.base_kind == AddrBaseKind::VReg && address.base_index >= 0;
  if (base_is_vreg) {
    const auto mapped = rename_map.find(address.base_index);
    const bool renames_base =
        mapped != rename_map.end() && mapped->second != address.base_index;
    if (renames_base) {
      address.base_index = mapped->second;
      rewrote_any = true;
    }
  }
  for (auto& scaled : address.scaled_vregs) {
    const auto mapped = rename_map.find(scaled.first);
    if (mapped == rename_map.end() || mapped->second == scaled.first) {
      continue;
    }
    scaled.first = mapped->second;
    rewrote_any = true;
  }
  return rewrote_any;
}
// Renames the vregs read by `inst` according to `rename_map`.
//
// Which operand slots count as reads depends on the opcode:
//   * Copy and the unary/conversion ops read slot 1.
//   * Binary arithmetic and compare ops read slots 1 and 2.
//   * Store, CondBr and Ret read slot 0.
//   * Memset reads slots 0 and 1.
//   * Call reads every slot from 0, or from 1 when it returns a non-Void
//     value (slot 0 presumably holds the result).
//   * Arg, Load, Lea, Br and Unreachable read no plain operand slots.
// Slots beyond the operand count are skipped. Whatever the opcode, an
// attached address expression is rewritten as well.
//
// Returns true if at least one operand or address component changed.
bool RewriteUses(MachineInstr& inst, const std::unordered_map<int, int>& rename_map) {
  auto& operands = inst.GetOperands();
  bool changed = false;
  // Renames operands[slot] when the instruction has that many operands.
  auto rewrite_slot = [&](size_t slot) {
    if (slot < operands.size()) {
      changed |= RewriteMappedOperand(operands[slot], rename_map);
    }
  };
  switch (inst.GetOpcode()) {
    case MachineInstr::Opcode::Store:
    case MachineInstr::Opcode::CondBr:
    case MachineInstr::Opcode::Ret:
      rewrite_slot(0);
      break;
    case MachineInstr::Opcode::Copy:
    case MachineInstr::Opcode::ZExt:
    case MachineInstr::Opcode::ItoF:
    case MachineInstr::Opcode::FtoI:
    case MachineInstr::Opcode::FSqrt:
    case MachineInstr::Opcode::FNeg:
      rewrite_slot(1);
      break;
    case MachineInstr::Opcode::Add:
    case MachineInstr::Opcode::Sub:
    case MachineInstr::Opcode::Mul:
    case MachineInstr::Opcode::Div:
    case MachineInstr::Opcode::Rem:
    case MachineInstr::Opcode::And:
    case MachineInstr::Opcode::Or:
    case MachineInstr::Opcode::Xor:
    case MachineInstr::Opcode::Shl:
    case MachineInstr::Opcode::AShr:
    case MachineInstr::Opcode::LShr:
    case MachineInstr::Opcode::FAdd:
    case MachineInstr::Opcode::FSub:
    case MachineInstr::Opcode::FMul:
    case MachineInstr::Opcode::FDiv:
    case MachineInstr::Opcode::ICmp:
    case MachineInstr::Opcode::FCmp:
      rewrite_slot(1);
      rewrite_slot(2);
      break;
    case MachineInstr::Opcode::Memset:
      rewrite_slot(0);
      rewrite_slot(1);
      break;
    case MachineInstr::Opcode::Call: {
      const size_t first_arg =
          inst.GetCallReturnType() == ValueType::Void ? 0 : 1;
      for (size_t slot = first_arg; slot < operands.size(); ++slot) {
        rewrite_slot(slot);
      }
      break;
    }
    case MachineInstr::Opcode::Arg:
    case MachineInstr::Opcode::Load:
    case MachineInstr::Opcode::Lea:
    case MachineInstr::Opcode::Br:
    case MachineInstr::Opcode::Unreachable:
      break;
  }
  if (inst.HasAddress()) {
    changed |= RewriteMappedAddress(inst.GetAddress(), rename_map);
  }
  return changed;
}
// Block-local rematerialization of cheap values around calls.
//
// Each basic block is walked in order. The walk remembers which vregs were
// defined by a cheaply recomputable instruction (see
// IsCheapRematerializableDef).
//
// Once a Call or Memset has been seen, a later use in the same block of a
// vreg that was cheap at that call point is handled specially. A fresh copy
// of its defining instruction, writing a new vreg, is inserted immediately
// before the using instruction. That use, and any later uses in the block,
// are then renamed to the new vreg. The renaming holds until the next
// Call/Memset or a redefinition of the original vreg.
//
// The original definitions are left in place, and uses in other blocks are
// not touched. The presumed goal (per the pass name) is to keep such values
// from being live across the call, where they might otherwise be spilled —
// TODO confirm against the register allocator.
//
// Returns true if any instruction was inserted or rewritten.
bool RunSpillReductionOnFunction(MachineFunction& function) {
  bool changed = false;
  for (auto& block_ptr : function.GetBlocks()) {
    auto& instructions = block_ptr->GetInstructions();
    // Cheap recipes for vregs defined so far in this block (original and
    // rematerialized vregs alike).
    std::unordered_map<int, RematDef> available_defs;
    // Snapshot of available_defs taken at the most recent Call/Memset.
    std::unordered_map<int, RematDef> after_call_defs;
    // Original vreg -> rematerialized vreg, valid since the most recent call.
    std::unordered_map<int, int> rename_map;
    bool after_call = false;
    for (size_t i = 0; i < instructions.size(); ++i) {
      if (after_call) {
        // Take a copy of the use list: the insertions below modify
        // `instructions` and may invalidate references into it.
        const auto uses = instructions[i].GetUses();
        for (int use : uses) {
          if (rename_map.count(use) != 0) {
            continue;  // Already rematerialized since the last call.
          }
          const auto recipe = after_call_defs.find(use);
          if (recipe == after_call_defs.end()) {
            continue;
          }
          // Copy the type out before calling NewVReg. If NewVReg takes its
          // argument by reference and grows the vreg table, reading through
          // the GetVRegInfo() result inside the call could dangle.
          const auto vreg_type = function.GetVRegInfo(use).type;
          const int new_vreg = function.NewVReg(vreg_type);
          instructions.insert(instructions.begin() + static_cast<long long>(i),
                              BuildRematInstr(new_vreg, recipe->second));
          ++i;  // Keep `i` on the instruction whose uses are being fixed.
          rename_map[use] = new_vreg;
          available_defs[new_vreg] = recipe->second;
          changed = true;
        }
        changed |= RewriteUses(instructions[i], rename_map);
      }
      // Uses were handled above; now account for what this instruction
      // defines. A redefinition invalidates any recipe or rename for the vreg.
      const auto defs = instructions[i].GetDefs();
      for (int def_vreg : defs) {
        available_defs.erase(def_vreg);
        after_call_defs.erase(def_vreg);
        rename_map.erase(def_vreg);
      }
      RematDef remat;
      if (IsCheapRematerializableDef(instructions[i], remat)) {
        for (int def_vreg : defs) {
          available_defs[def_vreg] = remat;
        }
      }
      const auto opcode = instructions[i].GetOpcode();
      if (opcode == MachineInstr::Opcode::Call ||
          opcode == MachineInstr::Opcode::Memset) {
        // Start a new window: any cheap value available here may be used
        // after the call, so later uses get a fresh rematerialization.
        after_call_defs = available_defs;
        rename_map.clear();
        after_call = true;
      }
    }
  }
  return changed;
}
} // namespace
// Module entry point for the spill-reduction pass. Runs the per-function
// rematerialization on every non-null function in `module`.
// Returns true if any function was modified.
bool RunSpillReduction(MachineModule& module) {
  bool any_changed = false;
  for (auto& fn : module.GetFunctions()) {
    if (!fn) {
      continue;
    }
    if (RunSpillReductionOnFunction(*fn)) {
      any_changed = true;
    }
  }
  return any_changed;
}
} // namespace mir