You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/ir/passes/ConstFold.cpp

187 lines
5.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

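// IR constant folding:
// - Folds binary and cast instructions whose operands are all constants.
// - Constant control-flow branch simplification is limited by implementation
//   scope and is currently not done: PHI nodes and terminators are skipped.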
#include "ir/IR.h"
#include <algorithm>
#include <climits>
#include <cmath>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
// Folds `op` applied to two integer constant values `a` and `b`.
// Add/Sub/Mul wrap modulo 2^32; they are computed in unsigned arithmetic so
// no signed-overflow UB occurs. Div/Mod return nullptr when the C++ operation
// would be undefined (division by zero, INT_MIN / -1, INT_MIN % -1).
// Comparisons produce ctx.GetConstBool(1 or 0). Any other opcode -> nullptr.
ConstantValue* FoldIntBinary(Opcode op, int a, int b, Context& ctx) {
  const auto ua = static_cast<unsigned int>(a);
  const auto ub = static_cast<unsigned int>(b);
  const bool division_undefined = (b == 0) || (a == INT_MIN && b == -1);
  switch (op) {
    case Opcode::Add:
      return ctx.GetConstInt(static_cast<int>(ua + ub));
    case Opcode::Sub:
      return ctx.GetConstInt(static_cast<int>(ua - ub));
    case Opcode::Mul:
      return ctx.GetConstInt(static_cast<int>(ua * ub));
    case Opcode::Div:
      if (division_undefined) return nullptr;
      return ctx.GetConstInt(a / b);
    case Opcode::Mod:
      if (division_undefined) return nullptr;
      return ctx.GetConstInt(a % b);
    case Opcode::Eq:
      return ctx.GetConstBool(a == b ? 1 : 0);
    case Opcode::Ne:
      return ctx.GetConstBool(a != b ? 1 : 0);
    case Opcode::Lt:
      return ctx.GetConstBool(a < b ? 1 : 0);
    case Opcode::Le:
      return ctx.GetConstBool(a <= b ? 1 : 0);
    case Opcode::Gt:
      return ctx.GetConstBool(a > b ? 1 : 0);
    case Opcode::Ge:
      return ctx.GetConstBool(a >= b ? 1 : 0);
    default:
      return nullptr;
  }
}

// Folds `op` applied to two floating-point constant values `x` and `y`
// (evaluated in double). Division by zero is deliberately left unfolded;
// Mod and any other opcode also return nullptr.
ConstantValue* FoldFloatBinary(Opcode op, double x, double y, Context& ctx) {
  switch (op) {
    case Opcode::Add:
      return ctx.GetConstFloat(x + y);
    case Opcode::Sub:
      return ctx.GetConstFloat(x - y);
    case Opcode::Mul:
      return ctx.GetConstFloat(x * y);
    case Opcode::Div:
      if (y == 0.0) return nullptr;
      return ctx.GetConstFloat(x / y);
    case Opcode::Eq:
      return ctx.GetConstBool(x == y ? 1 : 0);
    case Opcode::Ne:
      return ctx.GetConstBool(x != y ? 1 : 0);
    case Opcode::Lt:
      return ctx.GetConstBool(x < y ? 1 : 0);
    case Opcode::Le:
      return ctx.GetConstBool(x <= y ? 1 : 0);
    case Opcode::Gt:
      return ctx.GetConstBool(x > y ? 1 : 0);
    case Opcode::Ge:
      return ctx.GetConstBool(x >= y ? 1 : 0);
    default:
      return nullptr;
  }
}

// Tries to constant-fold the binary operation `op` on `lhs` and `rhs`.
// Both operands must be ConstantInt, or both must be ConstantFloat; mixed or
// non-constant operands are not folded.
// Returns the folded constant, or nullptr if the expression cannot be folded.
ConstantValue* TryFoldBinary(Opcode op, Value* lhs, Value* rhs, Context& ctx) {
  auto* int_l = dynamic_cast<ConstantInt*>(lhs);
  auto* int_r = dynamic_cast<ConstantInt*>(rhs);
  if (int_l != nullptr && int_r != nullptr) {
    return FoldIntBinary(op, int_l->GetValue(), int_r->GetValue(), ctx);
  }
  auto* fp_l = dynamic_cast<ConstantFloat*>(lhs);
  auto* fp_r = dynamic_cast<ConstantFloat*>(rhs);
  if (fp_l != nullptr && fp_r != nullptr) {
    return FoldFloatBinary(op, fp_l->GetValue(), fp_r->GetValue(), ctx);
  }
  return nullptr;
}
// Tries to constant-fold the cast `op` applied to `operand`.
//
// Returns the folded constant, or nullptr when the opcode is not handled, the
// operand is not the expected constant kind, or the conversion is undefined:
//   - SIToFP: ConstantInt -> ctx.GetConstFloat(value converted to double).
//   - FPToSI: ConstantFloat -> ctx.GetConstInt(value truncated toward zero),
//     only when the truncated value fits in `int`. NaN, infinities and
//     out-of-range values are left unfolded.
//   - ZExt and every other opcode: not folded.
ConstantValue* TryFoldCast(Opcode op, Value* operand, Context& ctx) {
  switch (op) {
    case Opcode::SIToFP:
      if (auto* cint = dynamic_cast<ConstantInt*>(operand)) {
        return ctx.GetConstFloat(static_cast<double>(cint->GetValue()));
      }
      return nullptr;
    case Opcode::FPToSI:
      if (auto* cfloat = dynamic_cast<ConstantFloat*>(operand)) {
        const double val = cfloat->GetValue();
        // static_cast<int> truncates toward zero and is well defined exactly
        // when the truncated value is representable, i.e. for
        // INT_MIN - 1 < val < INT_MAX + 1. Both bounds are exact in double.
        // This accepts e.g. 2147483647.0 and -2147483648.5, which are valid.
        constexpr double kLowerExclusive = static_cast<double>(INT_MIN) - 1.0;
        constexpr double kUpperExclusive = static_cast<double>(INT_MAX) + 1.0;
        if (std::isnan(val) || val <= kLowerExclusive ||
            val >= kUpperExclusive) {
          return nullptr;
        }
        return ctx.GetConstInt(static_cast<int>(val));
      }
      return nullptr;
    case Opcode::ZExt:
      // ZExt (i1 -> i32) is intentionally not folded. The original rationale
      // is that folding would break IR type correctness (i1 operand vs. i32
      // folded constant).
      // NOTE(review): the zext result itself is i32, so an i32 constant looks
      // type-correct; confirm the actual concern before enabling this fold.
      return nullptr;
    default:
      return nullptr;
  }
}
// Returns true when `v` is a known constant, i.e. its dynamic type derives
// from ConstantValue. A null `v` yields false.
bool IsConstantValue(Value* v) {
  const auto* as_constant = dynamic_cast<const ConstantValue*>(v);
  return as_constant != nullptr;
}
} // namespace
// Runs constant folding over every non-external function in `module`.
//
// Each instruction that is a BinaryInst or CastInst is folded when all of its
// operands are constants; PHI nodes and terminators are skipped. Folding uses
// TryFoldBinary / TryFoldCast. All uses of a folded instruction are rewritten
// to the resulting constant, and the instruction is erased from its block once
// it has no remaining uses.
//
// Uses are rewritten as soon as an instruction folds, so instructions visited
// later in the same sweep already see constant operands. Sweeps repeat until
// one erases nothing. This lets chains like `%t = mul 2, 3; %r = add 1, %t`
// collapse fully; a single collect-then-replace pass would leave `add 1, 6`.
void RunConstFold(Module& module) {
  auto& ctx = module.GetContext();
  for (auto& func_ptr : module.GetFunctions()) {
    auto* func = func_ptr.get();
    if (func->IsExternal()) continue;
    // Every sweep that sets `changed` erased at least one instruction, so the
    // function shrinks each time and the loop terminates.
    bool changed = true;
    while (changed) {
      changed = false;
      // Instructions folded in this sweep -> their replacement constant.
      std::unordered_map<Instruction*, ConstantValue*> folded;
      for (auto& bb : func->GetBlocks()) {
        for (auto& inst_ptr : bb->GetInstructions()) {
          auto* inst = inst_ptr.get();
          // PHI nodes and terminators are never folded here.
          if (dynamic_cast<PhiInst*>(inst)) continue;
          if (inst->IsTerminator()) continue;
          ConstantValue* result = nullptr;
          if (auto* bin = dynamic_cast<BinaryInst*>(inst)) {
            auto* lhs = bin->GetLhs();
            auto* rhs = bin->GetRhs();
            if (IsConstantValue(lhs) && IsConstantValue(rhs)) {
              result = TryFoldBinary(bin->GetOpcode(), lhs, rhs, ctx);
            }
          } else if (auto* cast = dynamic_cast<CastInst*>(inst)) {
            auto* operand = cast->GetOperandValue();
            if (IsConstantValue(operand)) {
              result = TryFoldCast(cast->GetOpcode(), operand, ctx);
            }
          }
          if (result == nullptr) continue;
          // Rewrite users now. This only changes other instructions' operands,
          // not the instruction list being iterated.
          inst->ReplaceAllUsesWith(result);
          folded[inst] = result;
        }
      }
      if (folded.empty()) break;
      // Erase folded instructions that no longer have any uses.
      for (auto& bb : func->GetBlocks()) {
        // Erasure goes through a const_cast of GetInstructions().
        // NOTE(review): a mutable accessor on the block would be cleaner.
        auto& insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(
            bb->GetInstructions());
        for (auto it = insts.begin(); it != insts.end();) {
          auto* inst = it->get();
          if (folded.count(inst) && inst->GetUses().empty()) {
            // NOTE(review): confirm that destroying an Instruction also removes
            // it from its (constant) operands' use lists.
            it = insts.erase(it);
            changed = true;
          } else {
            ++it;
          }
        }
      }
    }
  }
}
} // namespace ir