// IR 常量折叠: // - 折叠可判定的常量表达式 // - 简化常量控制流分支(按实现范围裁剪) #include "ir/IR.h" #include #include #include #include #include #include #include namespace ir { namespace { // 尝试对二元指令进行常量折叠 // 返回折叠后的常量,如果无法折叠则返回 nullptr ConstantValue* TryFoldBinary(Opcode op, Value* lhs, Value* rhs, Context& ctx) { auto* lhs_const = dynamic_cast(lhs); auto* rhs_const = dynamic_cast(rhs); if (lhs_const && rhs_const) { // 整数常量折叠 int lv = lhs_const->GetValue(); int rv = rhs_const->GetValue(); int result = 0; switch (op) { case Opcode::Add: result = static_cast(static_cast(lv) + static_cast(rv)); break; case Opcode::Sub: result = static_cast(static_cast(lv) - static_cast(rv)); break; case Opcode::Mul: result = static_cast(static_cast(lv) * static_cast(rv)); break; case Opcode::Div: if (rv == 0) return nullptr; if (lv == INT_MIN && rv == -1) return nullptr; result = lv / rv; break; case Opcode::Mod: if (rv == 0) return nullptr; if (lv == INT_MIN && rv == -1) return nullptr; result = lv % rv; break; case Opcode::Eq: return ctx.GetConstBool(lv == rv ? 1 : 0); case Opcode::Ne: return ctx.GetConstBool(lv != rv ? 1 : 0); case Opcode::Lt: return ctx.GetConstBool(lv < rv ? 1 : 0); case Opcode::Le: return ctx.GetConstBool(lv <= rv ? 1 : 0); case Opcode::Gt: return ctx.GetConstBool(lv > rv ? 1 : 0); case Opcode::Ge: return ctx.GetConstBool(lv >= rv ? 1 : 0); default: return nullptr; } return ctx.GetConstInt(result); } // 浮点常量折叠 auto* lhs_float = dynamic_cast(lhs); auto* rhs_float = dynamic_cast(rhs); if (lhs_float && rhs_float) { double lv = lhs_float->GetValue(); double rv = rhs_float->GetValue(); switch (op) { case Opcode::Add: return ctx.GetConstFloat(lv + rv); case Opcode::Sub: return ctx.GetConstFloat(lv - rv); case Opcode::Mul: return ctx.GetConstFloat(lv * rv); case Opcode::Div: if (rv == 0.0) return nullptr; return ctx.GetConstFloat(lv / rv); case Opcode::Eq: return ctx.GetConstBool(lv == rv ? 1 : 0); case Opcode::Ne: return ctx.GetConstBool(lv != rv ? 1 : 0); case Opcode::Lt: return ctx.GetConstBool(lv < rv ? 1 : 0); case Opcode::Le: return ctx.GetConstBool(lv <= rv ? 1 : 0); case Opcode::Gt: return ctx.GetConstBool(lv > rv ? 1 : 0); case Opcode::Ge: return ctx.GetConstBool(lv >= rv ? 1 : 0); default: return nullptr; } } return nullptr; } // 尝试对类型转换指令进行常量折叠 ConstantValue* TryFoldCast(Opcode op, Value* operand, Context& ctx) { // SIToFP: int -> float if (op == Opcode::SIToFP) { if (auto* cint = dynamic_cast(operand)) { return ctx.GetConstFloat(static_cast(cint->GetValue())); } } // FPToSI: float -> int if (op == Opcode::FPToSI) { if (auto* cfloat = dynamic_cast(operand)) { double val = cfloat->GetValue(); if (val < static_cast(INT_MIN) || val >= static_cast(INT_MAX) || std::isnan(val)) return nullptr; return ctx.GetConstInt(static_cast(val)); } } // ZExt: i1 -> i32 // 不要折叠 zext,因为折叠后类型从 i1 变成 i32,会破坏 IR 的类型正确性 // 原操作数是 i1 类型,但折叠后的常量是 i32 类型 if (op == Opcode::ZExt) { return nullptr; } return nullptr; } // 检查一个值是否是已知常量 bool IsConstantValue(Value* v) { return dynamic_cast(v) != nullptr; } } // namespace void RunConstFold(Module& module) { auto& ctx = module.GetContext(); for (auto& func_ptr : module.GetFunctions()) { auto* func = func_ptr.get(); if (func->IsExternal()) continue; // 收集所有需要替换的指令及其常量结果 std::unordered_map to_replace; for (auto& bb : func->GetBlocks()) { for (auto& inst_ptr : bb->GetInstructions()) { auto* inst = inst_ptr.get(); // 跳过 PHI 节点和终止指令 if (dynamic_cast(inst)) continue; if (inst->IsTerminator()) continue; // 尝试折叠二元指令 if (auto* bin = dynamic_cast(inst)) { auto* lhs = bin->GetLhs(); auto* rhs = bin->GetRhs(); if (IsConstantValue(lhs) && IsConstantValue(rhs)) { if (auto* result = TryFoldBinary(bin->GetOpcode(), lhs, rhs, ctx)) { to_replace[inst] = result; } } } // 尝试折叠类型转换指令 if (auto* cast = dynamic_cast(inst)) { auto* operand = cast->GetOperandValue(); if (IsConstantValue(operand)) { if (auto* result = TryFoldCast(cast->GetOpcode(), operand, ctx)) { to_replace[inst] = result; } } } } } // 执行替换 for (auto& [inst, const_val] : to_replace) { if (inst && const_val) { inst->ReplaceAllUsesWith(const_val); } } // 删除已被替换的指令(没有剩余 use 的) for (auto& bb : func->GetBlocks()) { auto& insts = const_cast>&>(bb->GetInstructions()); for (auto it = insts.begin(); it != insts.end();) { auto* inst = it->get(); if (to_replace.count(inst) && inst->GetUses().empty()) { it = insts.erase(it); } else { ++it; } } } } } } // namespace ir