diff --git a/infer/src/pulse/PulseFormula.ml b/infer/src/pulse/PulseFormula.ml index 2e14051c0..b8c097530 100644 --- a/infer/src/pulse/PulseFormula.ml +++ b/infer/src/pulse/PulseFormula.ml @@ -24,6 +24,14 @@ module Q = struct let is_zero q = Q.equal q Q.zero let is_not_zero q = not (is_zero q) + + let conv_protect f q = try Some (f q) with Division_by_zero | Z.Overflow -> None + + let to_int q = conv_protect Q.to_int q + + let to_int64 q = conv_protect Q.to_int64 q + + let to_bigint q = conv_protect Q.to_bigint q end (** Expressive term structure to be able to express all of SIL, but the main smarts of the formulas @@ -143,6 +151,8 @@ module Term = struct let zero = Const Q.zero + let of_bool b = if b then one else zero + let of_unop (unop : Unop.t) t = match unop with Neg -> Minus t | BNot -> BitNot t | LNot -> Not t @@ -309,6 +319,150 @@ module Term = struct let has_var_notin vars t = Container.exists t ~iter:iter_variables ~f:(fun v -> not (Var.Set.mem v vars)) + + + (** reduce to a constant when the direct sub-terms are constants *) + let eval_const_shallow t0 = + let map_const t f = match t with Const c -> f c | _ -> t0 in + let map_const2 t1 t2 f = match (t1, t2) with Const c1, Const c2 -> f c1 c2 | _ -> t0 in + let q_map t q_f = map_const t (fun c -> Const (q_f c)) in + let q_map2 t1 t2 q_f = map_const2 t1 t2 (fun c1 c2 -> Const (q_f c1 c2)) in + let q_predicate_map t q_f = map_const t (fun c -> q_f c |> of_bool) in + let q_predicate_map2 t1 t2 q_f = map_const2 t1 t2 (fun c1 c2 -> q_f c1 c2 |> of_bool) in + let conv2 conv1 conv2 conv_back c1 c2 f = + let open IOption.Let_syntax in + let* i1 = conv1 c1 in + let+ i2 = conv2 c2 in + f i1 i2 |> conv_back + in + let map_i64_i64 = conv2 Q.to_int64 Q.to_int64 Q.of_int64 in + let map_i64_i = conv2 Q.to_int64 Q.to_int Q.of_int64 in + let map_z_z = conv2 Q.to_bigint Q.to_bigint Q.of_bigint in + let or_undef q_opt = Option.value ~default:Q.undef q_opt in + match t0 with + | Const _ | Var _ -> + t0 + | Minus t' -> + q_map t' Q.(mul minus_one) + | Add (t1, t2) -> + q_map2 t1 t2 Q.add + | BitNot t' -> + q_map t' (fun c -> + let open Option.Monad_infix in + Q.to_int64 c >>| Int64.bit_not >>| Q.of_int64 |> or_undef ) + | Mult (t1, t2) -> + q_map2 t1 t2 Q.mul + | Div (t1, t2) -> + q_map2 t1 t2 Q.div + | Mod (t1, t2) -> + q_map2 t1 t2 (fun c1 c2 -> map_z_z c1 c2 Z.( mod ) |> or_undef) + | Not t' -> + q_predicate_map t' Q.is_zero + | And (t1, t2) -> + map_const2 t1 t2 (fun c1 c2 -> of_bool (Q.is_not_zero c1 && Q.is_not_zero c2)) + | Or (t1, t2) -> + map_const2 t1 t2 (fun c1 c2 -> of_bool (Q.is_not_zero c1 || Q.is_not_zero c2)) + | LessThan (t1, t2) -> + q_predicate_map2 t1 t2 Q.lt + | LessEqual (t1, t2) -> + q_predicate_map2 t1 t2 Q.leq + | Equal (t1, t2) -> + q_predicate_map2 t1 t2 Q.equal + | NotEqual (t1, t2) -> + q_predicate_map2 t1 t2 Q.not_equal + | BitAnd (t1, t2) + | BitOr (t1, t2) + | BitShiftLeft (t1, t2) + | BitShiftRight (t1, t2) + | BitXor (t1, t2) -> + q_map2 t1 t2 (fun c1 c2 -> + match[@warning "-8"] t0 with + | BitAnd _ -> + map_i64_i64 c1 c2 Int64.bit_and |> or_undef + | BitOr _ -> + map_i64_i64 c1 c2 Int64.bit_or |> or_undef + | BitShiftLeft _ -> + map_i64_i c1 c2 Int64.shift_left |> or_undef + | BitShiftRight _ -> + map_i64_i c1 c2 Int64.shift_right |> or_undef + | BitXor _ -> + map_i64_i64 c1 c2 Int64.bit_xor |> or_undef ) + + + let rec simplify_shallow t = + match t with + | Var _ | Const _ -> + t + | Minus (Minus t) -> + (* [--t = t] *) + t + | BitNot (BitNot t) -> + (* [~~t = t] *) + t + | Add (Const c, t) when Q.is_zero c -> + (* [0 + t = t] *) + t + | Add (t, Const c) when Q.is_zero c -> + (* [t + 0 = t] *) + t + | Mult (Const c, t) when Q.is_one c -> + (* [1 × t = t] *) + t + | Mult (t, Const c) when Q.is_one c -> + (* [t × 1 = t] *) + t + | Mult (Const c, _) when Q.is_zero c -> + (* [0 × t = 0] *) + zero + | Mult (_, Const c) when Q.is_zero c -> + (* [t × 0 = 0] *) + zero + | Div (Const c, _) when Q.is_zero c -> + (* [0 / t = 0] *) + zero + | Div (_, Const c) when Q.is_zero c -> + (* [t / 0 = undefined] *) + Const Q.undef + | Div (t, Const c) -> + (* [t / c = (1/c)·t] *) + simplify_shallow (Mult (Const (Q.inv c), t)) + | Div (Minus t1, Minus t2) -> + (* [(-t1) / (-t2) = t1 / t2] *) + simplify_shallow (Div (t1, t2)) + | Div (t1, t2) when equal_syntax t1 t2 -> + (* [t / t = 1] *) + one + | Mod (Const c, _) when Q.is_zero c -> + (* [0 % t = 0] *) + zero + | Mod (_, Const q) when Q.is_one q -> + (* [t % 1 = 0] *) + zero + | Mod (t1, t2) when equal_syntax t1 t2 -> + (* [t % t = 0] *) + zero + | BitAnd (t1, t2) when is_zero t1 || is_zero t2 -> + zero + | BitXor (t1, t2) when equal_syntax t1 t2 -> + zero + | (BitShiftLeft (t1, _) | BitShiftRight (t1, _)) when is_zero t1 -> + zero + | (BitShiftLeft (t1, t2) | BitShiftRight (t1, t2)) when is_zero t2 -> + t1 + | And (t1, t2) when is_zero t1 || is_zero t2 -> + (* [false ∧ t = t ∧ false = false] *) zero + | And (t1, t2) when is_non_zero_const t1 -> + (* [true ∧ t = t] *) t2 + | And (t1, t2) when is_non_zero_const t2 -> + (* [t ∧ true = t] *) t1 + | Or (t1, t2) when is_non_zero_const t1 || is_non_zero_const t2 -> + (* [true ∨ t = t ∨ true = true] *) one + | Or (t1, t2) when is_zero t1 -> + (* [false ∨ t = t] *) t2 + | Or (t1, t2) when is_zero t2 -> + (* [t ∨ false = t] *) t1 + | _ -> + t end (** Basically boolean terms, used to build the part of a formula that is not equalities between @@ -394,80 +548,11 @@ module Atom = struct to_term atom - (* Many simplifications are still TODO *) let rec eval_term t = - let open Term in - let t_norm_subterms = map_direct_subterms ~f:eval_term t in - match t_norm_subterms with - | Var _ | Const _ -> - t - | Minus (Minus t) -> - (* [--t = t] *) - t - | Minus (Const c) -> - (* [-c = -1*c] *) - Const (Q.(mul minus_one) c) - | BitNot (BitNot t) -> - (* [~~t = t] *) - t - | Not (Const c) -> - if Q.is_zero c then (* [!0 = 1] *) - one else (* [! = 0] *) - zero - | Add (Const c1, Const c2) -> - (* constants *) - Const (Q.add c1 c2) - | Add (Const c, t) when Q.is_zero c -> - (* [0 + t = t] *) - t - | Add (t, Const c) when Q.is_zero c -> - (* [t + 0 = t] *) - t - | Mult (Const c, t) when Q.is_one c -> - (* [1 × t = t] *) - t - | Mult (t, Const c) when Q.is_one c -> - (* [t × 1 = t] *) - t - | Mult (Const c, _) when Q.is_zero c -> - (* [0 × t = 0] *) - zero - | Mult (_, Const c) when Q.is_zero c -> - (* [t × 0 = 0] *) - zero - | Div (Const c, _) when Q.is_zero c -> - (* [0 / t = 0] *) - zero - | Div (t, Const c) when Q.is_one c -> - (* [t / 1 = t] *) - t - | Div (t, Const c) when Q.is_minus_one c -> - (* [t / (-1) = -t] *) - eval_term (Minus t) - | Div (Minus t1, Minus t2) -> - (* [(-t1) / (-t2) = t1 / t2] *) - eval_term (Div (t1, t2)) - | Mod (Const c, _) when Q.is_zero c -> - (* [0 % t = 0] *) - zero - | Mod (_, Const q) when Q.is_one q -> - (* [t % 1 = 0] *) - zero - | Mod (t1, t2) when equal_syntax t1 t2 -> - (* [t % t = 0] *) - zero - | And (t1, t2) when is_zero t1 || is_zero t2 -> - (* [false ∧ t = t ∧ false = false] *) zero - | And (t1, t2) when is_non_zero_const t1 -> - (* [true ∧ t = t] *) t2 - | And (t1, t2) when is_non_zero_const t2 -> - (* [t ∧ true = t] *) t1 - | Or (t1, t2) when is_non_zero_const t1 || is_non_zero_const t2 -> - (* [true ∨ t = t ∨ true = true] *) one - | Or (t1, t2) when is_zero t1 -> - (* [false ∨ t = t] *) t2 - | Or (t1, t2) when is_zero t2 -> - (* [t ∨ false = t] *) t1 + let t = + Term.map_direct_subterms ~f:eval_term t |> Term.simplify_shallow |> Term.eval_const_shallow + in + match (t : Term.t) with (* terms that are atoms can be simplified in [eval_atom] *) | LessEqual (t1, t2) -> eval_atom (LessEqual (t1, t2) : atom) |> term_of_eval_result @@ -478,7 +563,7 @@ module Atom = struct | NotEqual (t1, t2) -> eval_atom (NotEqual (t1, t2) : atom) |> term_of_eval_result | _ -> - t_norm_subterms + t (** This assumes that the terms in the atom have been normalized/evaluated already. diff --git a/infer/src/pulse/unit/PulseFormulaTest.ml b/infer/src/pulse/unit/PulseFormulaTest.ml index 79889b365..1317d1433 100644 --- a/infer/src/pulse/unit/PulseFormulaTest.ml +++ b/infer/src/pulse/unit/PulseFormulaTest.ml @@ -238,9 +238,9 @@ let%test_module "non-linear simplifications" = ( module struct let%expect_test "zero propagation" = simplify ~keep:[w_var] (((i 0 / (x * z)) & v) * v mod y = w) ; - [%expect {|w=v10 && true (no linear) && {w = v9 mod y}∧{v8 = 0&v}∧{v9 = v8×v}|}] + [%expect {|w=v10 && w = 0 && true (no atoms)|}] let%expect_test "constant propagation: bitshift" = simplify ~keep:[x_var] (of_binop Shiftlt (of_binop Shiftrt (i 0b111) (i 2)) (i 2) = x) ; - [%expect {|x=v7 && true (no linear) && {x = v6<<2}∧{v6 = 7>>2}|}] + [%expect {|x=v7 && x = 4 && true (no atoms)|}] end )