You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
470 lines
13 KiB
470 lines
13 KiB
#include "ir/PassManager.h"
|
|
|
|
#include "ir/IR.h"
|
|
#include "PassUtils.h"
|
|
|
|
#include <cmath>
|
|
#include <cstdint>
|
|
#include <limits>
|
|
#include <vector>
|
|
|
|
namespace ir {
|
|
namespace {
|
|
|
|
// Returns the i32 constant that `ctx` provides for `value`.
Value* GetInt32Const(Context& ctx, std::int32_t value) {
  const int as_int = static_cast<int>(value);
  return ctx.GetConstInt(as_int);
}
|
|
|
|
// Returns the i1 constant that `ctx` provides for `value`.
Value* GetBoolConst(Context& ctx, bool value) {
  return ctx.GetConstBool(value);
}
|
|
|
|
// Creates a float constant holding `value`.
// NOTE(review): unlike the i32/i1 helpers, this does not go through the Context; it
// allocates a fresh ConstantFloat with raw `new`. Ownership and deduplication of the
// result are not visible from this file — confirm who releases it.
Value* GetFloatConst(float value) {
  Value* constant = new ConstantFloat(Type::GetFloatType(), value);
  return constant;
}
|
|
|
|
// If `value` is a ConstantInt, stores its value (as int32) in `out` and returns true.
// Otherwise returns false and leaves `out` unchanged.
bool TryGetInt32(Value* value, std::int32_t& out) {
  auto* constant = dyncast<ConstantInt>(value);
  if (!constant) {
    return false;
  }
  out = static_cast<std::int32_t>(constant->GetValue());
  return true;
}
|
|
|
|
// If `value` is a ConstantI1, copies its boolean into `out` and returns true.
// Otherwise returns false and leaves `out` unchanged.
bool TryGetBool(Value* value, bool& out) {
  auto* constant = dyncast<ConstantI1>(value);
  if (!constant) {
    return false;
  }
  out = constant->GetValue();
  return true;
}
|
|
|
|
// If `value` is a ConstantFloat, copies its value into `out` and returns true.
// Otherwise returns false and leaves `out` unchanged.
bool TryGetFloat(Value* value, float& out) {
  auto* constant = dyncast<ConstantFloat>(value);
  if (!constant) {
    return false;
  }
  out = constant->GetValue();
  return true;
}
|
|
|
|
// True when `value` is a constant zero of any supported kind: i32 0, i1 false, or a
// float whose bit pattern (per passutils::FloatBits) is 0. That presumably means +0.0
// only, so -0.0 does not count.
bool IsZeroValue(Value* value) {
  std::int32_t int_val = 0;
  if (TryGetInt32(value, int_val) && int_val == 0) {
    return true;
  }
  bool bool_val = false;
  if (TryGetBool(value, bool_val) && !bool_val) {
    return true;
  }
  float float_val = 0.0f;
  if (!TryGetFloat(value, float_val)) {
    return false;
  }
  return passutils::FloatBits(float_val) == 0;
}
|
|
|
|
// True when `value` is a constant one of any supported kind: i32 1, i1 true, or a float
// whose bit pattern (per passutils::FloatBits) equals that of 1.0f.
bool IsOneValue(Value* value) {
  std::int32_t int_val = 0;
  if (TryGetInt32(value, int_val) && int_val == 1) {
    return true;
  }
  bool bool_val = false;
  if (TryGetBool(value, bool_val) && bool_val) {
    return true;
  }
  float float_val = 0.0f;
  if (!TryGetFloat(value, float_val)) {
    return false;
  }
  return passutils::FloatBits(float_val) == passutils::FloatBits(1.0f);
}
|
|
|
|
// True when `value` is the i32 constant -1, i.e. all 32 bits set.
bool IsAllOnesInt(Value* value) {
  std::int32_t bits = 0;
  if (!TryGetInt32(value, bits)) {
    return false;
  }
  return bits == -1;
}
|
|
|
|
// Reinterprets a 32-bit unsigned result as two's-complement int32 (modulo 2^32).
// Values with the sign bit set are shifted down by 2^31 in the signed domain. This keeps
// every conversion in range, so the mapping does not rely on implementation-defined
// unsigned-to-signed narrowing.
std::int32_t WrapInt32(std::uint32_t value) {
  constexpr std::uint32_t kSignBit = 0x80000000u;
  if (value < kSignBit) {
    return static_cast<std::int32_t>(value);
  }
  return static_cast<std::int32_t>(value - kSignBit) +
         std::numeric_limits<std::int32_t>::min();
}
|
|
|
|
// Tries to simplify the binary instruction `inst`.
//
// Returns the value that every use of `inst` may be replaced with, or nullptr when no
// simplification applies. Two strategies are tried in order:
//   1. Constant folding, when both operands are constants of the same kind (i32, i1 or
//      float). i32 arithmetic wraps modulo 2^32. Operations whose result is undefined
//      (division/remainder by zero, INT_MIN / -1, INT_MIN % -1, shift amounts outside
//      [0, 31]) are deliberately left unfolded.
//   2. Algebraic identities with one constant operand (x + 0, x * 1, x & -1, x ^ x, ...).
//      These may return one of the original operands rather than a new constant.
Value* FoldBinary(Context& ctx, BinaryInst* inst) {
  const auto opcode = inst->GetOpcode();
  auto* lhs = inst->GetLhs();
  auto* rhs = inst->GetRhs();

  std::int32_t lhs_i32 = 0;
  std::int32_t rhs_i32 = 0;
  bool lhs_i1 = false;
  bool rhs_i1 = false;
  float lhs_f32 = 0.0f;
  float rhs_f32 = 0.0f;

  const bool has_lhs_i32 = TryGetInt32(lhs, lhs_i32);
  const bool has_rhs_i32 = TryGetInt32(rhs, rhs_i32);
  const bool has_lhs_i1 = TryGetBool(lhs, lhs_i1);
  const bool has_rhs_i1 = TryGetBool(rhs, rhs_i1);
  const bool has_lhs_f32 = TryGetFloat(lhs, lhs_f32);
  const bool has_rhs_f32 = TryGetFloat(rhs, rhs_f32);

  // Zero constant of the instruction's own result type. Identities such as x * 0 or
  // x % 1 on an i1-typed instruction must yield an i1 constant, not an i32 one.
  const auto typed_zero = [&]() -> Value* {
    return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false) : GetInt32Const(ctx, 0);
  };

  // True when `value` is the float constant -0.0. Under IEEE 754 only -0.0 is an exact
  // additive identity: x + (-0.0) == x for every x, whereas (-0.0) + (+0.0) == +0.0.
  const auto is_neg_zero_float = [](Value* value) {
    float f32 = 0.0f;
    return TryGetFloat(value, f32) &&
           passutils::FloatBits(f32) == passutils::FloatBits(-0.0f);
  };

  if (has_lhs_i32 && has_rhs_i32) {
    switch (opcode) {
      // Add/Sub/Mul are computed in uint32 so that overflow wraps instead of being UB.
      case Opcode::Add:
        return GetInt32Const(
            ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) +
                           static_cast<std::uint32_t>(rhs_i32)));
      case Opcode::Sub:
        return GetInt32Const(
            ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) -
                           static_cast<std::uint32_t>(rhs_i32)));
      case Opcode::Mul:
        return GetInt32Const(
            ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) *
                           static_cast<std::uint32_t>(rhs_i32)));
      case Opcode::Div:
        // Division by zero and INT_MIN / -1 are undefined; keep them for run time.
        if (rhs_i32 == 0 ||
            (lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
          return nullptr;
        }
        return GetInt32Const(ctx, lhs_i32 / rhs_i32);
      case Opcode::Rem:
        if (rhs_i32 == 0 ||
            (lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
          return nullptr;
        }
        return GetInt32Const(ctx, lhs_i32 % rhs_i32);
      case Opcode::And:
        return GetInt32Const(ctx, lhs_i32 & rhs_i32);
      case Opcode::Or:
        return GetInt32Const(ctx, lhs_i32 | rhs_i32);
      case Opcode::Xor:
        return GetInt32Const(ctx, lhs_i32 ^ rhs_i32);
      case Opcode::Shl:
        if (rhs_i32 < 0 || rhs_i32 >= 32) {
          return nullptr;
        }
        // Shift as unsigned so that shifting into the sign bit is well defined.
        return GetInt32Const(
            ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) << rhs_i32));
      case Opcode::AShr:
        if (rhs_i32 < 0 || rhs_i32 >= 32) {
          return nullptr;
        }
        return GetInt32Const(ctx, lhs_i32 >> rhs_i32);
      case Opcode::LShr:
        if (rhs_i32 < 0 || rhs_i32 >= 32) {
          return nullptr;
        }
        return GetInt32Const(
            ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) >> rhs_i32));
      case Opcode::ICmpEQ:
        return GetBoolConst(ctx, lhs_i32 == rhs_i32);
      case Opcode::ICmpNE:
        return GetBoolConst(ctx, lhs_i32 != rhs_i32);
      case Opcode::ICmpLT:
        return GetBoolConst(ctx, lhs_i32 < rhs_i32);
      case Opcode::ICmpGT:
        return GetBoolConst(ctx, lhs_i32 > rhs_i32);
      case Opcode::ICmpLE:
        return GetBoolConst(ctx, lhs_i32 <= rhs_i32);
      case Opcode::ICmpGE:
        return GetBoolConst(ctx, lhs_i32 >= rhs_i32);
      default:
        break;
    }
  }

  if (has_lhs_i1 && has_rhs_i1) {
    switch (opcode) {
      case Opcode::And:
        return GetBoolConst(ctx, lhs_i1 && rhs_i1);
      case Opcode::Or:
        return GetBoolConst(ctx, lhs_i1 || rhs_i1);
      case Opcode::Xor:
        return GetBoolConst(ctx, lhs_i1 != rhs_i1);
      case Opcode::ICmpEQ:
        return GetBoolConst(ctx, lhs_i1 == rhs_i1);
      case Opcode::ICmpNE:
        return GetBoolConst(ctx, lhs_i1 != rhs_i1);
      // Ordering comparisons treat false/true as 0/1.
      case Opcode::ICmpLT:
        return GetBoolConst(ctx, static_cast<int>(lhs_i1) < static_cast<int>(rhs_i1));
      case Opcode::ICmpGT:
        return GetBoolConst(ctx, static_cast<int>(lhs_i1) > static_cast<int>(rhs_i1));
      case Opcode::ICmpLE:
        return GetBoolConst(ctx, static_cast<int>(lhs_i1) <= static_cast<int>(rhs_i1));
      case Opcode::ICmpGE:
        return GetBoolConst(ctx, static_cast<int>(lhs_i1) >= static_cast<int>(rhs_i1));
      default:
        break;
    }
  }

  if (has_lhs_f32 && has_rhs_f32) {
    switch (opcode) {
      case Opcode::FAdd:
        return GetFloatConst(lhs_f32 + rhs_f32);
      case Opcode::FSub:
        return GetFloatConst(lhs_f32 - rhs_f32);
      case Opcode::FMul:
        return GetFloatConst(lhs_f32 * rhs_f32);
      case Opcode::FDiv:
        return GetFloatConst(lhs_f32 / rhs_f32);
      case Opcode::FRem:
        return GetFloatConst(std::fmod(lhs_f32, rhs_f32));
      // The ordered comparisons below are false whenever either operand is NaN.
      // C++'s built-in operators already behave that way.
      case Opcode::FCmpEQ:
        return GetBoolConst(ctx, lhs_f32 == rhs_f32);
      case Opcode::FCmpLT:
        return GetBoolConst(ctx, lhs_f32 < rhs_f32);
      case Opcode::FCmpGT:
        return GetBoolConst(ctx, lhs_f32 > rhs_f32);
      case Opcode::FCmpLE:
        return GetBoolConst(ctx, lhs_f32 <= rhs_f32);
      case Opcode::FCmpGE:
        return GetBoolConst(ctx, lhs_f32 >= rhs_f32);
      // `!=` is the one comparison that is *true* for NaN operands under C / IEEE 754
      // semantics (so `x != x` detects NaN). Folding it as an ordered comparison would
      // disagree with the unfolded code.
      // NOTE(review): assumes FCmpNE models C's unordered `!=`; confirm against the
      // backend lowering.
      case Opcode::FCmpNE:
        return GetBoolConst(ctx, lhs_f32 != rhs_f32);
      default:
        break;
    }
  }

  switch (opcode) {
    case Opcode::Add:
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      if (IsZeroValue(lhs)) {
        return rhs;
      }
      break;
    case Opcode::Sub:
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      break;
    case Opcode::Mul:
      if (IsOneValue(rhs)) {
        return lhs;
      }
      if (IsOneValue(lhs)) {
        return rhs;
      }
      if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
        return typed_zero();
      }
      break;
    case Opcode::Div:
      if (IsOneValue(rhs)) {
        return lhs;
      }
      // 0 / x -> 0 for a non-constant x: this assumes x != 0 at run time, which is
      // fine because division by zero is undefined.
      if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
        return typed_zero();
      }
      break;
    case Opcode::Rem:
      // x % 1 == x % -1 == 0 (INT_MIN % -1 is undefined, so 0 is acceptable there too).
      if ((has_rhs_i32 && (rhs_i32 == 1 || rhs_i32 == -1)) || (has_rhs_i1 && rhs_i1)) {
        return typed_zero();
      }
      if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
        return typed_zero();
      }
      break;
    case Opcode::And:
      if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
        return typed_zero();
      }
      if (has_lhs_i1 && lhs_i1) {
        return rhs;
      }
      if (has_rhs_i1 && rhs_i1) {
        return lhs;
      }
      if (IsAllOnesInt(lhs)) {
        return rhs;
      }
      if (IsAllOnesInt(rhs)) {
        return lhs;
      }
      if (passutils::AreEquivalentValues(lhs, rhs)) {
        return lhs;
      }
      break;
    case Opcode::Or:
      if (IsZeroValue(lhs)) {
        return rhs;
      }
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      if ((has_lhs_i1 && lhs_i1) || (has_rhs_i1 && rhs_i1)) {
        return GetBoolConst(ctx, true);
      }
      if (passutils::AreEquivalentValues(lhs, rhs)) {
        return lhs;
      }
      break;
    case Opcode::Xor:
      if (IsZeroValue(lhs)) {
        return rhs;
      }
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      if (passutils::AreEquivalentValues(lhs, rhs)) {
        return typed_zero();
      }
      break;
    case Opcode::Shl:
    case Opcode::AShr:
    case Opcode::LShr:
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      if (IsZeroValue(lhs)) {
        return typed_zero();
      }
      break;
    case Opcode::FAdd:
      // Only -0.0 is a safe identity. Folding x + (+0.0) to x would be wrong for
      // x == -0.0, because -0.0 + +0.0 == +0.0.
      if (is_neg_zero_float(lhs)) {
        return rhs;
      }
      if (is_neg_zero_float(rhs)) {
        return lhs;
      }
      break;
    case Opcode::FSub:
      // x - (+0.0) == x holds for every x (including -0.0), so this identity is exact.
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      break;
    case Opcode::FMul:
      if (IsOneValue(lhs)) {
        return rhs;
      }
      if (IsOneValue(rhs)) {
        return lhs;
      }
      break;
    case Opcode::FDiv:
      if (IsOneValue(rhs)) {
        return lhs;
      }
      break;
    default:
      break;
  }

  return nullptr;
}
|
|
|
|
// Tries to fold the unary instruction `inst` when its operand is a constant.
// Returns the folded constant, or nullptr when the operand is not a suitable constant
// or the operation cannot be evaluated safely at compile time.
Value* FoldUnary(Context& ctx, UnaryInst* inst) {
  auto* operand = inst->GetOprd();
  std::int32_t i32 = 0;
  bool i1 = false;
  float f32 = 0.0f;

  switch (inst->GetOpcode()) {
    case Opcode::Neg:
      if (TryGetInt32(operand, i32)) {
        // Negate in unsigned arithmetic so that -INT_MIN wraps instead of overflowing.
        return GetInt32Const(ctx, WrapInt32(0u - static_cast<std::uint32_t>(i32)));
      }
      break;
    case Opcode::Not:
      if (TryGetBool(operand, i1)) {
        return GetBoolConst(ctx, !i1);
      }
      if (TryGetInt32(operand, i32)) {
        // Flips only the lowest bit, i.e. logical not for 0/1 operands.
        // NOTE(review): for other i32 values this is neither `!x` nor `~x`. Presumably
        // Not is only applied to 0/1 integers (or lowered as xor 1) — confirm.
        return GetInt32Const(ctx, i32 ^ 1);
      }
      break;
    case Opcode::FNeg:
      if (TryGetFloat(operand, f32)) {
        return GetFloatConst(-f32);
      }
      break;
    case Opcode::FtoI:
      if (TryGetFloat(operand, f32)) {
        // A float -> int32 conversion is undefined behaviour in C++ when the value is NaN
        // or its truncation does not fit in int32. Leave such conversions to run time
        // instead of invoking UB inside the compiler. [-2^31, 2^31) is exactly the set
        // of floats whose truncation fits.
        constexpr float kTwoPow31 = 2147483648.0f;
        if (std::isnan(f32) || f32 < -kTwoPow31 || f32 >= kTwoPow31) {
          return nullptr;
        }
        return GetInt32Const(ctx, static_cast<std::int32_t>(f32));
      }
      break;
    case Opcode::IToF:
      if (TryGetInt32(operand, i32)) {
        return GetFloatConst(static_cast<float>(i32));
      }
      if (TryGetBool(operand, i1)) {
        return GetFloatConst(i1 ? 1.0f : 0.0f);
      }
      break;
    default:
      break;
  }
  return nullptr;
}
|
|
|
|
// Folds a zext whose operand is an i1 or i32 constant into a constant of the result type.
//   * i1 result:  a bool passes through; an i32 becomes (value != 0).
//   * i32 result: a bool becomes 0/1; an i32 passes through unchanged.
// Returns nullptr when the operand is not such a constant or the result type is neither.
Value* FoldZext(Context& ctx, ZextInst* inst) {
  auto* source = inst->GetValue();
  bool as_bool = false;
  std::int32_t as_int = 0;

  const bool to_i1 = inst->GetType()->IsInt1();
  if (to_i1 && TryGetBool(source, as_bool)) {
    return GetBoolConst(ctx, as_bool);
  }
  if (to_i1 && TryGetInt32(source, as_int)) {
    return GetBoolConst(ctx, as_int != 0);
  }

  const bool to_i32 = inst->GetType()->IsInt32();
  if (to_i32 && TryGetBool(source, as_bool)) {
    return GetInt32Const(ctx, as_bool ? 1 : 0);
  }
  if (to_i32 && TryGetInt32(source, as_int)) {
    return GetInt32Const(ctx, as_int);
  }

  return nullptr;
}
|
|
|
|
// Runs the binary / unary / zext folders over every instruction of `function`.
// When an instruction folds to a different value, all of its uses are redirected to
// that value and the instruction is queued for removal. Removal happens only after the
// whole block has been scanned, so the instruction list is not modified while it is
// being iterated. External functions are skipped.
// Returns true if at least one instruction was replaced.
bool FoldFunction(Function& function, Context& ctx) {
  if (function.IsExternal()) {
    return false;
  }

  bool any_folded = false;
  for (const auto& block : function.GetBlocks()) {
    std::vector<Instruction*> dead;

    for (const auto& owned : block->GetInstructions()) {
      auto* current = owned.get();

      Value* folded = nullptr;
      if (auto* bin = dyncast<BinaryInst>(current)) {
        folded = FoldBinary(ctx, bin);
      } else if (auto* un = dyncast<UnaryInst>(current)) {
        folded = FoldUnary(ctx, un);
      } else if (auto* ext = dyncast<ZextInst>(current)) {
        folded = FoldZext(ctx, ext);
      }

      const bool should_replace = folded != nullptr && folded != current;
      if (should_replace) {
        current->ReplaceAllUsesWith(folded);
        dead.push_back(current);
        any_folded = true;
      }
    }

    for (auto* victim : dead) {
      block->EraseInstruction(victim);
    }
  }
  return any_folded;
}
|
|
|
|
} // namespace
|
|
|
|
// Pass entry point: constant-folds every non-null function of `module`.
// Returns true if any function was changed.
bool RunConstFold(Module& module) {
  auto& ctx = module.GetContext();
  bool changed = false;
  for (const auto& function : module.GetFunctions()) {
    if (!function) {
      continue;
    }
    if (FoldFunction(*function, ctx)) {
      changed = true;
    }
  }
  return changed;
}
|
|
|
|
} // namespace ir
|