You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

470 lines
13 KiB

#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>
namespace ir {
namespace {
// Returns the constant that `ctx` hands out for the 32-bit integer `value`.
Value* GetInt32Const(Context& ctx, std::int32_t value) {
  const int as_int = static_cast<int>(value);
  return ctx.GetConstInt(as_int);
}
// Returns the constant that `ctx` hands out for the boolean `value`.
Value* GetBoolConst(Context& ctx, bool value) {
  return ctx.GetConstBool(value);
}
// Creates a fresh float constant holding `value`.
// NOTE(review): unlike the int/bool helpers this does not go through Context;
// the object is allocated with `new` and no owner is visible here — confirm
// that something else in the IR takes ownership of it.
Value* GetFloatConst(float value) {
  Value* constant = new ConstantFloat(Type::GetFloatType(), value);
  return constant;
}
// If `value` is a ConstantInt, stores its value (converted to int32) in `out`
// and returns true. Otherwise returns false and leaves `out` untouched.
bool TryGetInt32(Value* value, std::int32_t& out) {
  auto* constant = dyncast<ConstantInt>(value);
  if (!constant) {
    return false;
  }
  out = static_cast<std::int32_t>(constant->GetValue());
  return true;
}
// If `value` is a ConstantI1, stores its value in `out` and returns true.
// Otherwise returns false and leaves `out` untouched.
bool TryGetBool(Value* value, bool& out) {
  auto* constant = dyncast<ConstantI1>(value);
  if (!constant) {
    return false;
  }
  out = constant->GetValue();
  return true;
}
// If `value` is a ConstantFloat, stores its value in `out` and returns true.
// Otherwise returns false and leaves `out` untouched.
bool TryGetFloat(Value* value, float& out) {
  auto* constant = dyncast<ConstantFloat>(value);
  if (!constant) {
    return false;
  }
  out = constant->GetValue();
  return true;
}
// True when `value` is a constant zero: integer 0, boolean false, or a float
// whose bit pattern (per passutils::FloatBits) is all zeros.
bool IsZeroValue(Value* value) {
  std::int32_t int_value = 0;
  if (TryGetInt32(value, int_value) && int_value == 0) {
    return true;
  }
  bool bool_value = false;
  if (TryGetBool(value, bool_value) && !bool_value) {
    return true;
  }
  float float_value = 0.0f;
  if (!TryGetFloat(value, float_value)) {
    return false;
  }
  // Compared by bit pattern rather than with `== 0.0f`.
  return passutils::FloatBits(float_value) == 0;
}
// True when `value` is a constant one: integer 1, boolean true, or a float
// whose bit pattern equals that of 1.0f.
bool IsOneValue(Value* value) {
  std::int32_t int_value = 0;
  if (TryGetInt32(value, int_value) && int_value == 1) {
    return true;
  }
  bool bool_value = false;
  if (TryGetBool(value, bool_value) && bool_value) {
    return true;
  }
  float float_value = 0.0f;
  if (!TryGetFloat(value, float_value)) {
    return false;
  }
  return passutils::FloatBits(float_value) == passutils::FloatBits(1.0f);
}
// True when `value` is an integer constant equal to -1 (all 32 bits set).
bool IsAllOnesInt(Value* value) {
  std::int32_t bits = 0;
  if (!TryGetInt32(value, bits)) {
    return false;
  }
  return bits == -1;
}
// Converts an unsigned 32-bit value to std::int32_t with a plain static_cast.
// Used to bring results computed in uint32_t arithmetic back to the signed
// type.
std::int32_t WrapInt32(std::uint32_t value) {
  const auto as_signed = static_cast<std::int32_t>(value);
  return as_signed;
}
// Tries to simplify a binary instruction.
//
// Returns the value that should replace `inst`, or nullptr when the
// instruction has to be kept. Two kinds of simplification are attempted:
//   1. Constant folding, when both operands are constants of the same kind
//      (int32, i1 or float).
//   2. Algebraic identities, when the result is one of the operands or a
//      known constant (x + 0, x * 1, x & x, x ^ x, ...).
//
// Folding is refused (nullptr) for integer division/remainder by zero, for
// INT32_MIN divided by / modulo -1, and for shift amounts outside [0, 31].
Value* FoldBinary(Context& ctx, BinaryInst* inst) {
  const auto opcode = inst->GetOpcode();
  auto* lhs = inst->GetLhs();
  auto* rhs = inst->GetRhs();

  std::int32_t lhs_i32 = 0;
  std::int32_t rhs_i32 = 0;
  bool lhs_i1 = false;
  bool rhs_i1 = false;
  float lhs_f32 = 0.0f;
  float rhs_f32 = 0.0f;
  const bool has_lhs_i32 = TryGetInt32(lhs, lhs_i32);
  const bool has_rhs_i32 = TryGetInt32(rhs, rhs_i32);
  const bool has_lhs_i1 = TryGetBool(lhs, lhs_i1);
  const bool has_rhs_i1 = TryGetBool(rhs, rhs_i1);
  const bool has_lhs_f32 = TryGetFloat(lhs, lhs_f32);
  const bool has_rhs_f32 = TryGetFloat(rhs, rhs_f32);

  // --- Both operands are int32 constants. ---
  if (has_lhs_i32 && has_rhs_i32) {
    switch (opcode) {
      // Add/Sub/Mul are evaluated in uint32_t so that overflow wraps around
      // instead of being signed-overflow undefined behaviour in the compiler.
      case Opcode::Add:
        return GetInt32Const(
            ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) +
                           static_cast<std::uint32_t>(rhs_i32)));
      case Opcode::Sub:
        return GetInt32Const(
            ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) -
                           static_cast<std::uint32_t>(rhs_i32)));
      case Opcode::Mul:
        return GetInt32Const(
            ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) *
                           static_cast<std::uint32_t>(rhs_i32)));
      case Opcode::Div:
        // Division by zero and INT32_MIN / -1 cannot be evaluated here.
        if (rhs_i32 == 0 ||
            (lhs_i32 == std::numeric_limits<std::int32_t>::min() &&
             rhs_i32 == -1)) {
          return nullptr;
        }
        return GetInt32Const(ctx, lhs_i32 / rhs_i32);
      case Opcode::Rem:
        // Same restrictions as Div.
        if (rhs_i32 == 0 ||
            (lhs_i32 == std::numeric_limits<std::int32_t>::min() &&
             rhs_i32 == -1)) {
          return nullptr;
        }
        return GetInt32Const(ctx, lhs_i32 % rhs_i32);
      case Opcode::And:
        return GetInt32Const(ctx, lhs_i32 & rhs_i32);
      case Opcode::Or:
        return GetInt32Const(ctx, lhs_i32 | rhs_i32);
      case Opcode::Xor:
        return GetInt32Const(ctx, lhs_i32 ^ rhs_i32);
      case Opcode::Shl:
        // Shift amounts outside [0, 31] are not folded.
        if (rhs_i32 < 0 || rhs_i32 >= 32) {
          return nullptr;
        }
        return GetInt32Const(
            ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) << rhs_i32));
      case Opcode::AShr:
        if (rhs_i32 < 0 || rhs_i32 >= 32) {
          return nullptr;
        }
        // Signed right shift of the host compiler.
        return GetInt32Const(ctx, lhs_i32 >> rhs_i32);
      case Opcode::LShr:
        if (rhs_i32 < 0 || rhs_i32 >= 32) {
          return nullptr;
        }
        // Shift the unsigned representation so zeros are shifted in.
        return GetInt32Const(
            ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) >> rhs_i32));
      case Opcode::ICmpEQ:
        return GetBoolConst(ctx, lhs_i32 == rhs_i32);
      case Opcode::ICmpNE:
        return GetBoolConst(ctx, lhs_i32 != rhs_i32);
      case Opcode::ICmpLT:
        return GetBoolConst(ctx, lhs_i32 < rhs_i32);
      case Opcode::ICmpGT:
        return GetBoolConst(ctx, lhs_i32 > rhs_i32);
      case Opcode::ICmpLE:
        return GetBoolConst(ctx, lhs_i32 <= rhs_i32);
      case Opcode::ICmpGE:
        return GetBoolConst(ctx, lhs_i32 >= rhs_i32);
      default:
        break;
    }
  }

  // --- Both operands are i1 constants. ---
  if (has_lhs_i1 && has_rhs_i1) {
    switch (opcode) {
      case Opcode::And:
        return GetBoolConst(ctx, lhs_i1 && rhs_i1);
      case Opcode::Or:
        return GetBoolConst(ctx, lhs_i1 || rhs_i1);
      case Opcode::Xor:
        return GetBoolConst(ctx, lhs_i1 != rhs_i1);
      case Opcode::ICmpEQ:
        return GetBoolConst(ctx, lhs_i1 == rhs_i1);
      case Opcode::ICmpNE:
        return GetBoolConst(ctx, lhs_i1 != rhs_i1);
      // Ordering comparisons treat false as 0 and true as 1.
      case Opcode::ICmpLT:
        return GetBoolConst(
            ctx, static_cast<int>(lhs_i1) < static_cast<int>(rhs_i1));
      case Opcode::ICmpGT:
        return GetBoolConst(
            ctx, static_cast<int>(lhs_i1) > static_cast<int>(rhs_i1));
      case Opcode::ICmpLE:
        return GetBoolConst(
            ctx, static_cast<int>(lhs_i1) <= static_cast<int>(rhs_i1));
      case Opcode::ICmpGE:
        return GetBoolConst(
            ctx, static_cast<int>(lhs_i1) >= static_cast<int>(rhs_i1));
      default:
        break;
    }
  }

  // --- Both operands are float constants. ---
  if (has_lhs_f32 && has_rhs_f32) {
    switch (opcode) {
      case Opcode::FAdd:
        return GetFloatConst(lhs_f32 + rhs_f32);
      case Opcode::FSub:
        return GetFloatConst(lhs_f32 - rhs_f32);
      case Opcode::FMul:
        return GetFloatConst(lhs_f32 * rhs_f32);
      case Opcode::FDiv:
        return GetFloatConst(lhs_f32 / rhs_f32);
      case Opcode::FRem:
        return GetFloatConst(std::fmod(lhs_f32, rhs_f32));
      // Every float comparison below folds to false when either operand is
      // NaN.
      case Opcode::FCmpEQ:
        return GetBoolConst(ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) &&
                                     lhs_f32 == rhs_f32);
      case Opcode::FCmpNE:
        // NOTE(review): this is the "ordered" reading of not-equal. C's `!=`
        // is true when an operand is NaN — confirm which one the IR intends.
        return GetBoolConst(ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) &&
                                     lhs_f32 != rhs_f32);
      case Opcode::FCmpLT:
        return GetBoolConst(ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) &&
                                     lhs_f32 < rhs_f32);
      case Opcode::FCmpGT:
        return GetBoolConst(ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) &&
                                     lhs_f32 > rhs_f32);
      case Opcode::FCmpLE:
        return GetBoolConst(ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) &&
                                     lhs_f32 <= rhs_f32);
      case Opcode::FCmpGE:
        return GetBoolConst(ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) &&
                                     lhs_f32 >= rhs_f32);
      default:
        break;
    }
  }

  // --- Algebraic identities (at most one operand needs to be constant). ---
  switch (opcode) {
    case Opcode::Add:
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      if (IsZeroValue(lhs)) {
        return rhs;
      }
      break;
    case Opcode::Sub:
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      break;
    case Opcode::Mul:
      if (IsOneValue(rhs)) {
        return lhs;
      }
      if (IsOneValue(lhs)) {
        return rhs;
      }
      if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
        return GetInt32Const(ctx, 0);
      }
      break;
    case Opcode::Div:
      if (IsOneValue(rhs)) {
        return lhs;
      }
      // 0 / x -> 0, unless x is itself a constant zero.
      if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
        return GetInt32Const(ctx, 0);
      }
      break;
    case Opcode::Rem:
      // x % 1 and x % -1 are always 0.
      if ((has_rhs_i32 && (rhs_i32 == 1 || rhs_i32 == -1)) ||
          (has_rhs_i1 && rhs_i1)) {
        return GetInt32Const(ctx, 0);
      }
      // 0 % x -> 0, unless x is itself a constant zero.
      if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
        return GetInt32Const(ctx, 0);
      }
      break;
    case Opcode::And:
      // x & 0 -> 0, with the zero typed after the instruction's result.
      if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
        return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
                                         : GetInt32Const(ctx, 0);
      }
      // true & x -> x (i1), -1 & x -> x (int32).
      if (has_lhs_i1 && lhs_i1) {
        return rhs;
      }
      if (has_rhs_i1 && rhs_i1) {
        return lhs;
      }
      if (IsAllOnesInt(lhs)) {
        return rhs;
      }
      if (IsAllOnesInt(rhs)) {
        return lhs;
      }
      // x & x -> x.
      if (passutils::AreEquivalentValues(lhs, rhs)) {
        return lhs;
      }
      break;
    case Opcode::Or:
      if (IsZeroValue(lhs)) {
        return rhs;
      }
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      // true | x -> true (i1).
      if (has_lhs_i1 && lhs_i1) {
        return GetBoolConst(ctx, true);
      }
      if (has_rhs_i1 && rhs_i1) {
        return GetBoolConst(ctx, true);
      }
      // x | x -> x.
      if (passutils::AreEquivalentValues(lhs, rhs)) {
        return lhs;
      }
      break;
    case Opcode::Xor:
      if (IsZeroValue(lhs)) {
        return rhs;
      }
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      // x ^ x -> 0, with the zero typed after the instruction's result.
      if (passutils::AreEquivalentValues(lhs, rhs)) {
        return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
                                         : GetInt32Const(ctx, 0);
      }
      break;
    case Opcode::Shl:
    case Opcode::AShr:
    case Opcode::LShr:
      // x shift 0 -> x, 0 shift x -> 0.
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      if (IsZeroValue(lhs)) {
        return GetInt32Const(ctx, 0);
      }
      break;
    case Opcode::FAdd: {
      // Only -0.0 is a safe additive identity: x + (-0.0) == x for every x,
      // whereas x + (+0.0) turns x == -0.0 into +0.0. Folding on +0.0 would
      // therefore change the sign of a zero result.
      const auto neg_zero_bits = passutils::FloatBits(-0.0f);
      if (has_lhs_f32 && passutils::FloatBits(lhs_f32) == neg_zero_bits) {
        return rhs;
      }
      if (has_rhs_f32 && passutils::FloatBits(rhs_f32) == neg_zero_bits) {
        return lhs;
      }
      break;
    }
    case Opcode::FSub:
      // x - (+0.0) == x for every x, including x == -0.0.
      if (IsZeroValue(rhs)) {
        return lhs;
      }
      break;
    case Opcode::FMul:
      if (IsOneValue(lhs)) {
        return rhs;
      }
      if (IsOneValue(rhs)) {
        return lhs;
      }
      break;
    case Opcode::FDiv:
      if (IsOneValue(rhs)) {
        return lhs;
      }
      break;
    default:
      break;
  }
  return nullptr;
}
// Tries to constant-fold a unary instruction.
//
// Returns a constant that should replace `inst` when its operand is a
// constant the opcode can be evaluated on, or nullptr when the instruction
// has to be kept.
Value* FoldUnary(Context& ctx, UnaryInst* inst) {
  auto* operand = inst->GetOprd();
  std::int32_t i32 = 0;
  bool i1 = false;
  float f32 = 0.0f;
  switch (inst->GetOpcode()) {
    case Opcode::Neg:
      if (TryGetInt32(operand, i32)) {
        // Negate in uint32_t so that -INT32_MIN wraps instead of overflowing.
        return GetInt32Const(
            ctx, WrapInt32(0u - static_cast<std::uint32_t>(i32)));
      }
      break;
    case Opcode::Not:
      if (TryGetBool(operand, i1)) {
        return GetBoolConst(ctx, !i1);
      }
      if (TryGetInt32(operand, i32)) {
        // NOTE(review): `^ 1` only flips the lowest bit. That is a logical
        // not only for operands that are 0 or 1 — confirm the intended
        // semantics of Not on int32 values.
        return GetInt32Const(ctx, i32 ^ 1);
      }
      break;
    case Opcode::FNeg:
      if (TryGetFloat(operand, f32)) {
        return GetFloatConst(-f32);
      }
      break;
    case Opcode::FtoI:
      if (TryGetFloat(operand, f32)) {
        // Converting NaN, an infinity or a value whose truncation does not
        // fit in int32 is undefined behaviour in C++, so only fold values in
        // [-2^31, 2^31). Both bounds are exactly representable as float, and
        // NaN fails both comparisons.
        constexpr float kLowerInclusive = -2147483648.0f;
        constexpr float kUpperExclusive = 2147483648.0f;
        if (!(f32 >= kLowerInclusive && f32 < kUpperExclusive)) {
          break;
        }
        return GetInt32Const(ctx, static_cast<std::int32_t>(f32));
      }
      break;
    case Opcode::IToF:
      if (TryGetInt32(operand, i32)) {
        return GetFloatConst(static_cast<float>(i32));
      }
      if (TryGetBool(operand, i1)) {
        return GetFloatConst(i1 ? 1.0f : 0.0f);
      }
      break;
    default:
      break;
  }
  return nullptr;
}
// Tries to constant-fold a zext instruction whose source is a constant.
//
// A bool constant takes precedence over an int32 constant. For an i1 result
// the constant is re-expressed as a bool (non-zero ints become true); for an
// int32 result it is re-expressed as an int32 (true becomes 1, false 0).
// Returns nullptr when the source is not a constant or the result type is
// neither i1 nor int32.
Value* FoldZext(Context& ctx, ZextInst* inst) {
  auto* source = inst->GetValue();

  bool flag = false;
  std::int32_t number = 0;
  const bool source_is_bool = TryGetBool(source, flag);
  const bool source_is_int = !source_is_bool && TryGetInt32(source, number);

  if (inst->GetType()->IsInt1()) {
    if (source_is_bool) {
      return GetBoolConst(ctx, flag);
    }
    if (source_is_int) {
      return GetBoolConst(ctx, number != 0);
    }
  }

  if (inst->GetType()->IsInt32()) {
    if (source_is_bool) {
      return GetInt32Const(ctx, flag ? 1 : 0);
    }
    if (source_is_int) {
      return GetInt32Const(ctx, number);
    }
  }

  return nullptr;
}
// Runs the folders over every instruction of `function`.
//
// Each binary, unary or zext instruction that folds to a different value has
// all of its uses redirected to that value and is then erased from its block.
// External functions are skipped. Returns true when at least one instruction
// was replaced.
bool FoldFunction(Function& function, Context& ctx) {
  if (function.IsExternal()) {
    return false;
  }

  bool any_folded = false;
  for (const auto& block : function.GetBlocks()) {
    // Erasure is postponed until the walk over this block is finished so the
    // instruction list is not modified while it is being iterated.
    std::vector<Instruction*> dead;

    for (const auto& owned : block->GetInstructions()) {
      auto* current = owned.get();

      Value* folded = nullptr;
      if (auto* as_binary = dyncast<BinaryInst>(current)) {
        folded = FoldBinary(ctx, as_binary);
      } else if (auto* as_unary = dyncast<UnaryInst>(current)) {
        folded = FoldUnary(ctx, as_unary);
      } else if (auto* as_zext = dyncast<ZextInst>(current)) {
        folded = FoldZext(ctx, as_zext);
      }

      // Act only on a real replacement: non-null and not the instruction
      // itself.
      if (folded && !(folded == current)) {
        current->ReplaceAllUsesWith(folded);
        dead.push_back(current);
        any_folded = true;
      }
    }

    for (auto* victim : dead) {
      block->EraseInstruction(victim);
    }
  }
  return any_folded;
}
} // namespace
// Entry point of the constant-folding pass: folds every non-null function in
// `module` and reports whether any of them was changed.
bool RunConstFold(Module& module) {
  auto& ctx = module.GetContext();
  bool module_changed = false;
  for (const auto& function : module.GetFunctions()) {
    if (!function) {
      continue;
    }
    if (FoldFunction(*function, ctx)) {
      module_changed = true;
    }
  }
  return module_changed;
}
} // namespace ir