You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

130 lines
4.0 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// 公共子表达式消除CSE
// - 识别并复用重复计算的等价表达式
// - 局部值编号:在单个基本块内消除重复计算
#include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
namespace {
// Returns a string identifying operand `v` inside a value-numbering key.
//
//   - ConstantInt:   "ci" + decimal value, so equal integer constants map to
//                    the same key even when they are distinct objects.
//   - ConstantFloat: "cf" + the 32-bit IEEE-754 bit pattern of the value.
//                    Keying on the bits (not the numeric value) keeps +0.0
//                    and -0.0 distinct.
//   - anything else: "p" + the object's address in decimal (identity).
std::string ValKey(Value* v) {
  if (auto* ci = dynamic_cast<ConstantInt*>(v))
    return "ci" + std::to_string(ci->GetValue());
  if (auto* cf = dynamic_cast<ConstantFloat*>(v)) {
    // Copy the bits with memcpy: reading a union member other than the one
    // last written (the previous approach) is undefined behaviour in C++.
    const float value = cf->GetValue();
    uint32_t bits = 0;
    static_assert(sizeof(bits) == sizeof(value), "expected a 32-bit float");
    std::memcpy(&bits, &value, sizeof(bits));
    return "cf" + std::to_string(bits);
  }
  // Non-constants are identified by address; same decimal text as streaming
  // the uintptr_t, without constructing an ostringstream per operand.
  return "p" + std::to_string(reinterpret_cast<uintptr_t>(v));
}
// Builds the local-value-numbering key for `inst`, or returns "" when the
// instruction is not a CSE candidate (every opcode not listed below).
//
// Key layout: "<opcode>|[<predicate>|]<operand key>[|<operand key>]", where
// each operand key comes from ValKey(). Equal keys mean the instructions
// compute the same value from the same operands; for Load this additionally
// requires that memory is unchanged in between, which the caller enforces.
//
// For commutative opcodes (Add, Mul, FAdd, FMul) the two operand keys are
// put in a canonical (lexicographic) order, so `a op b` and `b op a` share
// one key and can be merged.
std::string MakeKey(Instruction* inst) {
  const auto opcode = inst->GetOpcode();
  const std::string prefix = std::to_string(static_cast<int>(opcode)) + "|";
  switch (opcode) {
    case Opcode::Add: case Opcode::Mul:
    case Opcode::FAdd: case Opcode::FMul: {
      auto* bin = static_cast<BinaryInst*>(inst);
      std::string lhs = ValKey(bin->GetLhs());
      std::string rhs = ValKey(bin->GetRhs());
      // Canonical operand order for commutative operations.
      if (rhs < lhs) std::swap(lhs, rhs);
      return prefix + lhs + "|" + rhs;
    }
    case Opcode::Sub: case Opcode::Div: case Opcode::Mod:
    case Opcode::FSub: case Opcode::FDiv: {
      // Non-commutative: operand order is significant.
      auto* bin = static_cast<BinaryInst*>(inst);
      return prefix + ValKey(bin->GetLhs()) + "|" + ValKey(bin->GetRhs());
    }
    case Opcode::ICmp: {
      auto* cmp = static_cast<ICmpInst*>(inst);
      return prefix + std::to_string(static_cast<int>(cmp->GetPredicate())) +
             "|" + ValKey(cmp->GetLhs()) + "|" + ValKey(cmp->GetRhs());
    }
    case Opcode::FCmp: {
      auto* cmp = static_cast<FCmpInst*>(inst);
      return prefix + std::to_string(static_cast<int>(cmp->GetPredicate())) +
             "|" + ValKey(cmp->GetLhs()) + "|" + ValKey(cmp->GetRhs());
    }
    case Opcode::Gep: {
      auto* gep = static_cast<GepInst*>(inst);
      return prefix + ValKey(gep->GetBasePtr()) + "|" +
             ValKey(gep->GetIndex());
    }
    case Opcode::Load: {
      auto* ld = static_cast<LoadInst*>(inst);
      return prefix + ValKey(ld->GetPtr());
    }
    // NOTE(review): the cast keys below record only the source operand, not
    // the result type. That is sufficient only if each cast opcode has a
    // single possible destination type in this IR (e.g. ZExt only i1->i32)
    // -- confirm, otherwise include the result type in the key.
    case Opcode::ZExt: {
      auto* ze = static_cast<ZExtInst*>(inst);
      return prefix + ValKey(ze->GetSrc());
    }
    case Opcode::SIToFP: {
      auto* si = static_cast<SIToFPInst*>(inst);
      return prefix + ValKey(si->GetSrc());
    }
    case Opcode::FPToSI: {
      auto* fs = static_cast<FPToSIInst*>(inst);
      return prefix + ValKey(fs->GetSrc());
    }
    default:
      return "";
  }
}
} // namespace
// Block-local common subexpression elimination (local value numbering).
//
// For each basic block of `func`, walks the instructions in order and keys
// every CSE candidate with MakeKey(). When a key was already produced by an
// earlier instruction in the same block, all uses of the later instruction
// are redirected to the earlier one and the later one is removed.
//
// Candidates are tracked in two tables:
//   - pure_exprs: arithmetic, comparisons, Gep and casts. Their result
//     depends only on their operands, so an entry stays valid for the rest
//     of the block regardless of intervening memory operations.
//   - loads: a Load's result also depends on memory contents. A Store or
//     Call may write memory, so this table is conservatively emptied at
//     each one.
//
// Returns true if at least one instruction was replaced.
bool RunCSE(Function& func) {
  bool changed = false;
  // Reused across blocks; clear() keeps the allocated storage.
  std::unordered_map<std::string, Value*> pure_exprs;
  std::unordered_map<std::string, Value*> loads;
  std::vector<Instruction*> to_remove;
  for (auto& bb : func.GetBlocks()) {
    // Value numbering is block-local: start each block with empty tables.
    pure_exprs.clear();
    loads.clear();
    to_remove.clear();
    for (auto& inst : bb->GetInstructions()) {
      auto* ip = inst.get();
      const auto opcode = ip->GetOpcode();
      if (opcode == Opcode::Store || opcode == Opcode::Call) {
        // Memory may have changed: previously seen Loads can no longer be
        // reused. Pure expressions do not read memory and stay available.
        loads.clear();
        continue;
      }
      std::string key = MakeKey(ip);
      if (key.empty()) continue;  // not a CSE candidate
      auto& table = (opcode == Opcode::Load) ? loads : pure_exprs;
      auto it = table.find(key);
      if (it != table.end()) {
        // An equivalent value was computed earlier in this same block, so it
        // is available here; reuse it and schedule `ip` for deletion.
        ip->ReplaceAllUsesWith(it->second);
        to_remove.push_back(ip);
        changed = true;
      } else {
        table.emplace(std::move(key), ip);
      }
    }
    // Deletion is deferred so the instruction list is not modified while it
    // is being iterated above.
    for (auto* dead : to_remove) {
      // Clear every operand slot before unlinking (presumably this drops
      // `dead` from its operands' use lists -- confirm in IR.h).
      for (size_t i = 0; i < dead->GetNumOperands(); ++i)
        dead->SetOperand(i, nullptr);
      dead->RemoveFromParent();
    }
    bb->SweepDeadInstructions();
  }
  return changed;
}
} // namespace ir