forked from NUDT-compiler/nudt-compiler-cpp
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
7.4 KiB
255 lines
7.4 KiB
#include "mir/MIR.h"
|
|
|
|
#include <unordered_map>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
namespace mir {
|
|
namespace {
|
|
|
|
struct RematDef {
|
|
enum class Kind { Invalid, ImmCopy, Lea };
|
|
|
|
Kind kind = Kind::Invalid;
|
|
ValueType type = ValueType::Void;
|
|
MachineOperand source;
|
|
AddressExpr address;
|
|
};
|
|
|
|
bool IsCheapRematerializableDef(const MachineInstr& inst, RematDef& def) {
|
|
const auto defs = inst.GetDefs();
|
|
if (defs.size() != 1) {
|
|
return false;
|
|
}
|
|
|
|
if (inst.GetOpcode() == MachineInstr::Opcode::Copy) {
|
|
const auto& operands = inst.GetOperands();
|
|
if (operands.size() < 2 || operands[1].GetKind() != OperandKind::Imm) {
|
|
return false;
|
|
}
|
|
def.kind = RematDef::Kind::ImmCopy;
|
|
def.type = inst.GetValueType();
|
|
def.source = operands[1];
|
|
return true;
|
|
}
|
|
|
|
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress()) {
|
|
return false;
|
|
}
|
|
|
|
const auto& address = inst.GetAddress();
|
|
if (address.base_kind == AddrBaseKind::VReg || !address.scaled_vregs.empty()) {
|
|
return false;
|
|
}
|
|
def.kind = RematDef::Kind::Lea;
|
|
def.type = ValueType::Ptr;
|
|
def.address = address;
|
|
return true;
|
|
}
|
|
|
|
// Materialises the recipe `def` as a fresh instruction defining `dst_vreg`:
//   ImmCopy -> Copy dst_vreg, <def.source>   (value type def.type)
//   Lea     -> Lea dst_vreg, <def.address>   (value type Ptr)
// A recipe of kind Invalid yields an operand-less Unreachable instruction.
MachineInstr BuildRematInstr(int dst_vreg, const RematDef& def) {
  if (def.kind == RematDef::Kind::ImmCopy) {
    MachineInstr copy(MachineInstr::Opcode::Copy,
                      {MachineOperand::VReg(dst_vreg), def.source});
    copy.SetValueType(def.type);
    return copy;
  }

  if (def.kind == RematDef::Kind::Lea) {
    MachineInstr lea(MachineInstr::Opcode::Lea, {MachineOperand::VReg(dst_vreg)});
    lea.SetAddress(def.address);
    lea.SetValueType(ValueType::Ptr);
    return lea;
  }

  return MachineInstr(MachineInstr::Opcode::Unreachable, {});
}
|
|
|
|
// If `operand` is a vreg that `rename_map` sends to a *different* vreg,
// replaces it with that vreg and returns true. Non-vreg operands, unmapped
// vregs, and identity mappings are left alone (returns false).
bool RewriteMappedOperand(MachineOperand& operand,
                          const std::unordered_map<int, int>& rename_map) {
  if (operand.GetKind() == OperandKind::VReg) {
    const int current = operand.GetVReg();
    const auto mapped = rename_map.find(current);
    if (mapped != rename_map.end() && mapped->second != current) {
      operand = MachineOperand::VReg(mapped->second);
      return true;
    }
  }
  return false;
}
|
|
|
|
// Renames the vregs referenced by `address` according to `rename_map`:
// the base register (only when the base kind is VReg and the index is
// non-negative), then the register of every scaled term, in order.
// Identity mappings are ignored. Returns true if any register was replaced.
bool RewriteMappedAddress(AddressExpr& address,
                          const std::unordered_map<int, int>& rename_map) {
  // Overwrites `reg` in place when the map sends it to a different vreg.
  const auto remap = [&rename_map](auto& reg) -> bool {
    const auto hit = rename_map.find(reg);
    if (hit == rename_map.end() || hit->second == reg) {
      return false;
    }
    reg = hit->second;
    return true;
  };

  bool replaced = false;

  const bool base_is_vreg =
      address.base_kind == AddrBaseKind::VReg && address.base_index >= 0;
  if (base_is_vreg && remap(address.base_index)) {
    replaced = true;
  }

  for (auto& term : address.scaled_vregs) {
    if (remap(term.first)) {
      replaced = true;
    }
  }

  return replaced;
}
|
|
|
|
// Renames the vreg *source* operands of `inst` according to `rename_map`
// (original vreg -> replacement vreg). The operand slots treated as sources
// depend on the opcode:
//   - Copy, ZExt, ItoF, FtoI, FSqrt, FNeg:                slot 1
//   - integer/float arithmetic, logic, shifts, ICmp, FCmp: slots 1 and 2
//   - Store, CondBr, Ret:                                  slot 0
//   - Memset:                                              slots 0 and 1
//   - Call: every slot, skipping slot 0 when the return type is not Void
//   - Arg, Load, Lea, Br, Unreachable:                     none
// Slots past the end of the operand list are skipped. Vregs inside an
// attached address expression are renamed as well. Returns true if anything
// was replaced.
bool RewriteUses(MachineInstr& inst, const std::unordered_map<int, int>& rename_map) {
  using Op = MachineInstr::Opcode;

  auto& ops = inst.GetOperands();
  bool rewritten = false;

  // Renames operand `slot`, but only if the instruction has that many operands.
  const auto rename_slot = [&](size_t slot) {
    if (slot < ops.size() && RewriteMappedOperand(ops[slot], rename_map)) {
      rewritten = true;
    }
  };

  switch (inst.GetOpcode()) {
    case Op::Copy:
    case Op::ZExt:
    case Op::ItoF:
    case Op::FtoI:
    case Op::FSqrt:
    case Op::FNeg:
      rename_slot(1);
      break;

    case Op::Add:
    case Op::Sub:
    case Op::Mul:
    case Op::Div:
    case Op::Rem:
    case Op::And:
    case Op::Or:
    case Op::Xor:
    case Op::Shl:
    case Op::AShr:
    case Op::LShr:
    case Op::FAdd:
    case Op::FSub:
    case Op::FMul:
    case Op::FDiv:
    case Op::ICmp:
    case Op::FCmp:
      rename_slot(1);
      rename_slot(2);
      break;

    case Op::Store:
    case Op::CondBr:
    case Op::Ret:
      rename_slot(0);
      break;

    case Op::Memset:
      rename_slot(0);
      rename_slot(1);
      break;

    case Op::Call: {
      // A non-Void call keeps slot 0 out of the rename (presumably it holds
      // the call's result rather than an argument).
      const size_t first_arg = inst.GetCallReturnType() == ValueType::Void ? 0 : 1;
      for (size_t slot = first_arg; slot < ops.size(); ++slot) {
        rename_slot(slot);
      }
      break;
    }

    case Op::Arg:
    case Op::Load:
    case Op::Lea:
    case Op::Br:
    case Op::Unreachable:
      break;
  }

  if (inst.HasAddress() && RewriteMappedAddress(inst.GetAddress(), rename_map)) {
    rewritten = true;
  }
  return rewritten;
}
|
|
|
|
// Per-block rematerialisation of cheap values across call-like instructions.
//
// Each basic block of `function` is walked in order. The pass remembers every
// vreg whose latest definition in the block is cheap to recompute (see
// IsCheapRematerializableDef). After a Call or Memset, the first later use of
// a vreg that was cheap at that call is handled as follows:
//   - a recomputation defining a fresh vreg is inserted directly before the
//     using instruction;
//   - that use is renamed to the fresh vreg;
//   - later uses in the block are also renamed, until the next call-like
//     instruction or a redefinition of the original vreg.
// The intent, going by the pass name, is to avoid keeping such values live
// (and spilled) across calls.
//
// Returns true if at least one recomputation was inserted.
bool RunSpillReductionOnFunction(MachineFunction& function) {
  bool inserted_any = false;

  for (auto& block : function.GetBlocks()) {
    auto& insts = block->GetInstructions();

    // vreg -> recipe, for vregs whose latest in-block definition is cheap.
    std::unordered_map<int, RematDef> cheap_now;
    // Snapshot of `cheap_now` taken at the most recent call-like instruction.
    std::unordered_map<int, RematDef> cheap_at_call;
    // Original vreg -> fresh recomputed vreg; reset at every call-like instruction.
    std::unordered_map<int, int> renamed;
    bool past_call = false;

    size_t pos = 0;
    while (pos < insts.size()) {
      if (past_call) {
        // Work on a private copy of the use list: the insertions below may
        // invalidate references into `insts`.
        const auto uses = insts[pos].GetUses();
        for (const int vreg : uses) {
          if (renamed.count(vreg) != 0) {
            continue;
          }
          const auto snapshot_it = cheap_at_call.find(vreg);
          if (snapshot_it == cheap_at_call.end()) {
            continue;
          }

          const int fresh = function.NewVReg(function.GetVRegInfo(vreg).type);
          insts.insert(insts.begin() + static_cast<long long>(pos),
                       BuildRematInstr(fresh, snapshot_it->second));
          // Step over the inserted instruction so `pos` still names the user.
          ++pos;
          renamed[vreg] = fresh;
          cheap_now[fresh] = snapshot_it->second;
          inserted_any = true;
        }
        RewriteUses(insts[pos], renamed);
      }

      // Any (re)definition invalidates recipes and renamings for that vreg.
      const auto defs = insts[pos].GetDefs();
      for (const int vreg : defs) {
        cheap_now.erase(vreg);
        cheap_at_call.erase(vreg);
        renamed.erase(vreg);
      }

      RematDef recipe;
      if (IsCheapRematerializableDef(insts[pos], recipe)) {
        for (const int vreg : defs) {
          cheap_now[vreg] = recipe;
        }
      }

      const auto opcode = insts[pos].GetOpcode();
      if (opcode == MachineInstr::Opcode::Call ||
          opcode == MachineInstr::Opcode::Memset) {
        cheap_at_call = cheap_now;
        renamed.clear();
        past_call = true;
      }

      ++pos;
    }
  }

  return inserted_any;
}
|
|
|
|
} // namespace
|
|
|
|
// Runs the per-block spill-reduction rewrite on every non-null function in
// `module`. Every function is visited, even after an earlier one changed.
// Returns true if any function was modified.
bool RunSpillReduction(MachineModule& module) {
  bool any_function_changed = false;
  for (auto& fn : module.GetFunctions()) {
    if (!fn) {
      continue;
    }
    if (RunSpillReductionOnFunction(*fn)) {
      any_function_changed = true;
    }
  }
  return any_function_changed;
}
|
|
|
|
} // namespace mir
|