You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

142 lines
3.2 KiB

#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
namespace {
/// Lookup key used by local CSE: an instruction's opcode together with the
/// addresses of its operand Values, in operand order. The pass treats two
/// supported instructions with equal keys as computing the same value.
///
/// NOTE(review): the result type is not part of the key; see BuildExprKey.
struct ExprKey {
  /// Opcode of the keyed instruction (`Add` is only a placeholder default).
  Opcode opcode{Opcode::Add};
  /// Operand Value pointers, stored as integers so they can be hashed/ordered.
  std::vector<std::uintptr_t> operands;

  bool operator==(const ExprKey& other) const {
    if (opcode != other.opcode) {
      return false;
    }
    return operands == other.operands;
  }
};
/// Hash functor for ExprKey. It seeds with the opcode value, then folds in each
/// operand's hash in order using the classic hash_combine mixing step
/// (seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2)).
struct ExprKeyHash {
  std::size_t operator()(const ExprKey& key) const {
    const std::hash<std::uintptr_t> operand_hasher{};
    std::size_t seed = static_cast<std::size_t>(key.opcode);
    for (std::size_t idx = 0; idx < key.operands.size(); ++idx) {
      const std::size_t mix = operand_hasher(key.operands[idx]) + 0x9e3779b9 +
                              (seed << 6) + (seed >> 2);
      seed ^= mix;
    }
    return seed;
  }
};
/// Returns true when `inst` is a non-null, non-void instruction whose opcode
/// is on the CSE allow-list below. Any other instruction is never merged.
bool IsSupportedCSEInstruction(Instruction* inst) {
  if (inst == nullptr || inst->IsVoid()) {
    return false;
  }
  // Allow-list of opcodes eligible for local CSE.
  static constexpr Opcode kCSEOpcodes[] = {
      // Integer and floating-point arithmetic.
      Opcode::Add, Opcode::Sub, Opcode::Mul, Opcode::Div, Opcode::Rem,
      Opcode::FAdd, Opcode::FSub, Opcode::FMul, Opcode::FDiv, Opcode::FRem,
      // Bitwise logic and shifts.
      Opcode::And, Opcode::Or, Opcode::Xor,
      Opcode::Shl, Opcode::AShr, Opcode::LShr,
      // Integer comparisons.
      Opcode::ICmpEQ, Opcode::ICmpNE, Opcode::ICmpLT,
      Opcode::ICmpGT, Opcode::ICmpLE, Opcode::ICmpGE,
      // Floating-point comparisons.
      Opcode::FCmpEQ, Opcode::FCmpNE, Opcode::FCmpLT,
      Opcode::FCmpGT, Opcode::FCmpLE, Opcode::FCmpGE,
      // Unary operations.
      Opcode::Neg, Opcode::Not, Opcode::FNeg,
      // Conversions.
      Opcode::FtoI, Opcode::IToF, Opcode::Zext,
  };
  const Opcode opcode = inst->GetOpcode();
  return std::find(std::begin(kCSEOpcodes), std::end(kCSEOpcodes), opcode) !=
         std::end(kCSEOpcodes);
}
/// Builds the CSE lookup key for `inst`: its opcode followed by the addresses
/// of its operands, in operand order. For two-operand instructions whose opcode
/// passutils::IsCommutativeOpcode reports as commutative, the two addresses are
/// put in ascending order so `a op b` and `b op a` map to the same key.
///
/// NOTE(review): the key omits the instruction's result type. If the IR lets
/// one operand be converted to several result types (e.g. Zext/FtoI/IToF to
/// different widths), those instructions would share a key and be merged.
/// TODO confirm against the type system.
ExprKey BuildExprKey(Instruction* inst) {
  const size_t operand_count = inst->GetNumOperands();
  ExprKey result;
  result.opcode = inst->GetOpcode();
  result.operands.resize(operand_count);
  for (size_t idx = 0; idx != operand_count; ++idx) {
    result.operands[idx] =
        reinterpret_cast<std::uintptr_t>(inst->GetOperand(idx));
  }
  // Canonical operand order for commutative binary ops (address order is
  // arbitrary but consistent within a run).
  const bool canonicalize =
      operand_count == 2 && passutils::IsCommutativeOpcode(result.opcode);
  if (canonicalize && result.operands[0] > result.operands[1]) {
    std::swap(result.operands[0], result.operands[1]);
  }
  return result;
}
/// Performs local (per-basic-block) common subexpression elimination on
/// `function`.
///
/// Within each block, instructions accepted by IsSupportedCSEInstruction are
/// keyed with BuildExprKey. The first instruction seen for a key is recorded as
/// available. Every later instruction with the same key has all of its uses
/// redirected to that earlier instruction, and is erased once the block has
/// been scanned. Uses are rewritten before the next instruction is keyed, so
/// (assuming ReplaceAllUsesWith updates users' operands in place) chains of
/// redundant expressions collapse in a single pass.
///
/// @param function  function to transform; external functions are skipped.
/// @return true if at least one instruction was replaced and erased.
bool RunCSEOnFunction(Function& function) {
  if (function.IsExternal()) {
    return false;
  }
  bool changed = false;
  // Reused across blocks (cleared after each block) to avoid reallocating.
  std::vector<Instruction*> to_remove;
  for (const auto& block_ptr : function.GetBlocks()) {
    // Availability is block-local: a fresh table per block means only an
    // earlier instruction of the same block can serve as a replacement.
    std::unordered_map<ExprKey, Value*, ExprKeyHash> available_exprs;
    for (const auto& inst_ptr : block_ptr->GetInstructions()) {
      auto* inst = inst_ptr.get();
      if (!IsSupportedCSEInstruction(inst)) {
        continue;
      }
      // Deliberately non-const so the key (and its operand vector) is moved
      // into the table rather than copied.
      ExprKey key = BuildExprKey(inst);
      const auto it = available_exprs.find(key);
      if (it == available_exprs.end()) {
        available_exprs.emplace(std::move(key), inst);
        continue;
      }
      inst->ReplaceAllUsesWith(it->second);
      to_remove.push_back(inst);
      changed = true;
    }
    // Erase only after the scan so the instruction list is not mutated while
    // it is being iterated.
    for (auto* dead : to_remove) {
      block_ptr->EraseInstruction(dead);
    }
    to_remove.clear();
  }
  return changed;
}
} // namespace
/// Runs local common subexpression elimination over every non-null function
/// in `module`.
///
/// @return true if the pass changed at least one function.
bool RunCSE(Module& module) {
  bool any_changed = false;
  for (const auto& fn : module.GetFunctions()) {
    if (!fn) {
      continue;
    }
    // Always run the per-function pass; never short-circuit on earlier changes.
    if (RunCSEOnFunction(*fn)) {
      any_changed = true;
    }
  }
  return any_changed;
}
} // namespace ir