|
|
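// IR constant folding:
// - Folds constant expressions whose result can be decided at compile time
//   (integer/float binary arithmetic and comparisons, SIToFP/FPToSI casts).
// - Constant control-flow branch simplification is optional ("trimmed to the
//   implemented scope"); this pass currently skips terminators, so no branches
//   are rewritten.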
|
|
|
|
|
|
#include "ir/IR.h"

#include <algorithm>
#include <climits>
#include <cmath>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
|
|
|
|
|
|
namespace ir {
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
// 尝试对二元指令进行常量折叠
|
|
|
// 返回折叠后的常量,如果无法折叠则返回 nullptr
|
|
|
ConstantValue* TryFoldBinary(Opcode op, Value* lhs, Value* rhs, Context& ctx) {
|
|
|
auto* lhs_const = dynamic_cast<ConstantInt*>(lhs);
|
|
|
auto* rhs_const = dynamic_cast<ConstantInt*>(rhs);
|
|
|
|
|
|
if (lhs_const && rhs_const) {
|
|
|
// 整数常量折叠
|
|
|
int lv = lhs_const->GetValue();
|
|
|
int rv = rhs_const->GetValue();
|
|
|
int result = 0;
|
|
|
|
|
|
switch (op) {
|
|
|
case Opcode::Add: result = static_cast<int>(static_cast<unsigned int>(lv) + static_cast<unsigned int>(rv)); break;
|
|
|
case Opcode::Sub: result = static_cast<int>(static_cast<unsigned int>(lv) - static_cast<unsigned int>(rv)); break;
|
|
|
case Opcode::Mul: result = static_cast<int>(static_cast<unsigned int>(lv) * static_cast<unsigned int>(rv)); break;
|
|
|
case Opcode::Div:
|
|
|
if (rv == 0) return nullptr;
|
|
|
if (lv == INT_MIN && rv == -1) return nullptr;
|
|
|
result = lv / rv;
|
|
|
break;
|
|
|
case Opcode::Mod:
|
|
|
if (rv == 0) return nullptr;
|
|
|
if (lv == INT_MIN && rv == -1) return nullptr;
|
|
|
result = lv % rv;
|
|
|
break;
|
|
|
case Opcode::Eq: return ctx.GetConstBool(lv == rv ? 1 : 0);
|
|
|
case Opcode::Ne: return ctx.GetConstBool(lv != rv ? 1 : 0);
|
|
|
case Opcode::Lt: return ctx.GetConstBool(lv < rv ? 1 : 0);
|
|
|
case Opcode::Le: return ctx.GetConstBool(lv <= rv ? 1 : 0);
|
|
|
case Opcode::Gt: return ctx.GetConstBool(lv > rv ? 1 : 0);
|
|
|
case Opcode::Ge: return ctx.GetConstBool(lv >= rv ? 1 : 0);
|
|
|
default: return nullptr;
|
|
|
}
|
|
|
|
|
|
return ctx.GetConstInt(result);
|
|
|
}
|
|
|
|
|
|
// 浮点常量折叠
|
|
|
auto* lhs_float = dynamic_cast<ConstantFloat*>(lhs);
|
|
|
auto* rhs_float = dynamic_cast<ConstantFloat*>(rhs);
|
|
|
|
|
|
if (lhs_float && rhs_float) {
|
|
|
double lv = lhs_float->GetValue();
|
|
|
double rv = rhs_float->GetValue();
|
|
|
|
|
|
switch (op) {
|
|
|
case Opcode::Add: return ctx.GetConstFloat(lv + rv);
|
|
|
case Opcode::Sub: return ctx.GetConstFloat(lv - rv);
|
|
|
case Opcode::Mul: return ctx.GetConstFloat(lv * rv);
|
|
|
case Opcode::Div:
|
|
|
if (rv == 0.0) return nullptr;
|
|
|
return ctx.GetConstFloat(lv / rv);
|
|
|
case Opcode::Eq: return ctx.GetConstBool(lv == rv ? 1 : 0);
|
|
|
case Opcode::Ne: return ctx.GetConstBool(lv != rv ? 1 : 0);
|
|
|
case Opcode::Lt: return ctx.GetConstBool(lv < rv ? 1 : 0);
|
|
|
case Opcode::Le: return ctx.GetConstBool(lv <= rv ? 1 : 0);
|
|
|
case Opcode::Gt: return ctx.GetConstBool(lv > rv ? 1 : 0);
|
|
|
case Opcode::Ge: return ctx.GetConstBool(lv >= rv ? 1 : 0);
|
|
|
default: return nullptr;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return nullptr;
|
|
|
}
|
|
|
|
|
|
// 尝试对类型转换指令进行常量折叠
|
|
|
ConstantValue* TryFoldCast(Opcode op, Value* operand, Context& ctx) {
|
|
|
// SIToFP: int -> float
|
|
|
if (op == Opcode::SIToFP) {
|
|
|
if (auto* cint = dynamic_cast<ConstantInt*>(operand)) {
|
|
|
return ctx.GetConstFloat(static_cast<double>(cint->GetValue()));
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// FPToSI: float -> int
|
|
|
if (op == Opcode::FPToSI) {
|
|
|
if (auto* cfloat = dynamic_cast<ConstantFloat*>(operand)) {
|
|
|
double val = cfloat->GetValue();
|
|
|
if (val < static_cast<double>(INT_MIN) || val >= static_cast<double>(INT_MAX) || std::isnan(val))
|
|
|
return nullptr;
|
|
|
return ctx.GetConstInt(static_cast<int>(val));
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// ZExt: i1 -> i32
|
|
|
// 不要折叠 zext,因为折叠后类型从 i1 变成 i32,会破坏 IR 的类型正确性
|
|
|
// 原操作数是 i1 类型,但折叠后的常量是 i32 类型
|
|
|
if (op == Opcode::ZExt) {
|
|
|
return nullptr;
|
|
|
}
|
|
|
|
|
|
return nullptr;
|
|
|
}
|
|
|
|
|
|
// 检查一个值是否是已知常量
|
|
|
bool IsConstantValue(Value* v) {
|
|
|
return dynamic_cast<ConstantValue*>(v) != nullptr;
|
|
|
}
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
// Runs constant folding over every defined (non-external) function in `module`.
//
// A binary or cast instruction whose operands are all constants is folded via
// TryFoldBinary / TryFoldCast. Every use of it is redirected to the resulting
// constant (ReplaceAllUsesWith), and it is later erased from its block if it
// has no remaining uses. PHI nodes and terminators are never folded, so no
// branch simplification happens here.
//
// Folding repeats until a full scan folds nothing new, so dependent chains
// such as `%a = add 1, 2` / `%b = mul %a, 3` collapse completely. A single
// scan with deferred replacement would fold only `%a`, because `%b` still
// referenced the non-constant `%a` when it was visited.
void RunConstFold(Module& module) {
  auto& ctx = module.GetContext();

  // Returns the constant `inst` folds to, or nullptr if it is not foldable
  // (wrong kind of instruction, or operands not constants yet).
  auto try_fold = [&ctx](Instruction* inst) -> ConstantValue* {
    if (dynamic_cast<PhiInst*>(inst) || inst->IsTerminator()) return nullptr;
    if (auto* bin = dynamic_cast<BinaryInst*>(inst)) {
      auto* lhs = bin->GetLhs();
      auto* rhs = bin->GetRhs();
      if (!IsConstantValue(lhs) || !IsConstantValue(rhs)) return nullptr;
      return TryFoldBinary(bin->GetOpcode(), lhs, rhs, ctx);
    }
    if (auto* cast = dynamic_cast<CastInst*>(inst)) {
      auto* operand = cast->GetOperandValue();
      if (!IsConstantValue(operand)) return nullptr;
      return TryFoldCast(cast->GetOpcode(), operand, ctx);
    }
    return nullptr;
  };

  for (auto& func_ptr : module.GetFunctions()) {
    auto* func = func_ptr.get();
    if (func->IsExternal()) continue;

    // Instructions already folded (and replaced) in this function. Their
    // operands are still constants, so without this set they would "fold"
    // again on every round and the fixed-point loop would never end. The set
    // also selects what to erase afterwards.
    std::unordered_set<Instruction*> folded;

    bool changed = true;
    while (changed) {
      changed = false;
      for (auto& bb : func->GetBlocks()) {
        for (auto& inst_ptr : bb->GetInstructions()) {
          auto* inst = inst_ptr.get();
          if (folded.count(inst) != 0) continue;
          if (auto* result = try_fold(inst)) {
            // Replacing right away lets later instructions in this same scan
            // already see the constant operand.
            inst->ReplaceAllUsesWith(result);
            folded.insert(inst);
            changed = true;
          }
        }
      }
    }

    if (folded.empty()) continue;

    // Erase folded instructions that no longer have any uses. This is a
    // single linear erase-remove pass per block rather than repeated
    // mid-vector erases.
    for (auto& bb : func->GetBlocks()) {
      // NOTE(review): the const_cast suggests GetInstructions() exposes the
      // list only as a const reference; prefer a block-level erase API if
      // one exists.
      auto& insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(bb->GetInstructions());
      insts.erase(std::remove_if(insts.begin(), insts.end(),
                                 [&folded](const std::unique_ptr<Instruction>& p) {
                                   return folded.count(p.get()) != 0 && p->GetUses().empty();
                                 }),
                  insts.end());
    }
  }
}
|
|
|
|
|
|
} // namespace ir
|