diff --git a/infer/src/pulse/PulseFormula.ml b/infer/src/pulse/PulseFormula.ml index ea8526de7..ea17244b1 100644 --- a/infer/src/pulse/PulseFormula.ml +++ b/infer/src/pulse/PulseFormula.ml @@ -660,46 +660,55 @@ module Term = struct (** more or less syntactic attempt at detecting when an arbitrary term is a linear formula; call {!Atom.eval_term} first for best results *) - let rec to_lin_arith t = - (* NOTE: don't duplicate simplifications between here and {!Atom.eval_term} *) - let open IOption.Let_syntax in - match t with - | Var v -> - Some (LinArith.of_var v) - | Const c -> - Some (LinArith.of_q c) - | Linear l -> - Some l - | Minus t -> - let+ l = to_lin_arith t in - LinArith.minus l - | Add (t1, t2) -> - let* l1 = to_lin_arith t1 in - let+ l2 = to_lin_arith t2 in - LinArith.add l1 l2 - | Mult (Const c, t) | Mult (t, Const c) -> - let+ l = to_lin_arith t in - LinArith.mult c l - | Div (t, Const c) when Q.is_not_zero c -> - let+ l = to_lin_arith t in - LinArith.mult (Q.inv c) l - | Mult _ - | Div _ - | Mod _ - | BitNot _ - | BitAnd _ - | BitOr _ - | BitShiftLeft _ - | BitShiftRight _ - | BitXor _ - | Not _ - | And _ - | Or _ - | LessThan _ - | LessEqual _ - | Equal _ - | NotEqual _ -> - None + let linearize t = + let rec aux_linearize t = + let open IOption.Let_syntax in + match t with + | Var v -> + Some (LinArith.of_var v) + | Const c -> + Some (LinArith.of_q c) + | Linear l -> + Some l + | Minus t -> + let+ l = aux_linearize t in + LinArith.minus l + | Add (t1, t2) -> + let* l1 = aux_linearize t1 in + let+ l2 = aux_linearize t2 in + LinArith.add l1 l2 + | Mult (Const c, t) | Mult (t, Const c) -> + let+ l = aux_linearize t in + LinArith.mult c l + | Div (t, Const c) when Q.is_not_zero c -> + let+ l = aux_linearize t in + LinArith.mult (Q.inv c) l + | Mult _ + | Div _ + | Mod _ + | BitNot _ + | BitAnd _ + | BitOr _ + | BitShiftLeft _ + | BitShiftRight _ + | BitXor _ + | Not _ + | And _ + | Or _ + | LessThan _ + | LessEqual _ + | Equal _ + | NotEqual _ -> + None + in + match aux_linearize t with None -> t | Some l -> Linear l + + + let simplify_linear = function + | Linear l -> ( + match LinArith.get_as_const l with Some c -> Const c | None -> Linear l ) + | t -> + t end (** Basically boolean terms, used to build the part of a formula that is not equalities between @@ -731,8 +740,6 @@ module Atom = struct F.fprintf fmt "%a ≠ %a" pp_term t1 pp_term t2 - let pp = pp_with_pp_var Var.pp - let get_terms atom = let (LessEqual (t1, t2) | LessThan (t1, t2) | Equal (t1, t2) | NotEqual (t1, t2)) = atom in (t1, t2) @@ -787,7 +794,8 @@ module Atom = struct let rec eval_term t = let t = - Term.map_direct_subterms ~f:eval_term t |> Term.simplify_shallow |> Term.eval_const_shallow + Term.map_direct_subterms ~f:eval_term t + |> Term.simplify_shallow |> Term.eval_const_shallow |> Term.linearize |> Term.simplify_linear in match (t : Term.t) with (* terms that are atoms can be simplified in [eval_atom] *) @@ -803,11 +811,7 @@ module Atom = struct t - (** This assumes that the terms in the atom have been normalized/evaluated already. - - TODO: probably a better way to implement this would be to rely entirely on - [eval_term (term_of_atom (atom))], possibly implementing it as something about observing the - sign/zero-ness of [t1 - t2]. *) + (** This assumes that the terms in the atom have been normalized/evaluated already. *) and eval_atom (atom : t) = let t1, t2 = get_terms atom in match (t1, t2) with @@ -821,6 +825,19 @@ module Atom = struct eval_result_of_bool (Q.leq c1 c2) | LessThan _ -> eval_result_of_bool (Q.lt c1 c2) ) + | Linear l1, Linear l2 -> + let l = LinArith.subtract l1 l2 in + let t = Term.simplify_linear (Linear l) in + eval_atom + ( match atom with + | Equal _ -> + Equal (t, Term.zero) + | NotEqual _ -> + NotEqual (t, Term.zero) + | LessEqual _ -> + LessEqual (t, Term.zero) + | LessThan _ -> + LessThan (t, Term.zero) ) | _ -> if Term.equal_syntax t1 t2 then match atom with @@ -865,6 +882,10 @@ module SatUnsatMonad = struct let ( let* ) phi f = bind_normalized f phi end +let sat_of_eval_result (eval_result : Atom.eval_result) = + match eval_result with True -> Sat None | False -> Unsat | Atom atom -> Sat (Some atom) + + module VarUF = UnionFind.Make (struct @@ -1055,14 +1076,7 @@ end = struct match q_opt with None -> VarSubst v_canon | Some q -> QSubst q ) in let atom' = Atom.map_terms atom ~f:(fun t -> normalize_term phi t) in - match Atom.eval atom' with - | True -> - Sat None - | False -> - L.d_printfln "Contradiction detected! %a ~> %a is False" Atom.pp atom Atom.pp atom' ; - Unsat - | Atom atom -> - Sat (Some atom) + Atom.eval atom' |> sat_of_eval_result let normalize_atoms phi = @@ -1075,15 +1089,12 @@ end = struct >>= function | None -> acc - | Some (Atom.Equal (t1, t2) as atom') -> ( - match Option.both (Term.to_lin_arith t1) (Term.to_lin_arith t2) with - | None -> - Sat (phi, Atom.Set.add atom' facts) - | Some (l1, l2) -> - (* an atom has been found to have become a linear equality, move it to the linear - equalities *) - let+ phi = solve_eq ~fuel:5 l1 l2 phi in - (phi, facts) ) + | Some (Atom.Equal (Linear l, Const c)) | Some (Atom.Equal (Const c, Linear l)) -> + (* NOTE: {!normalize_atom} calls {!Atom.eval}, which normalizes linear equalities so + they end up only on one side, hence only this match case is needed to detect linear + equalities *) + let+ phi = solve_eq ~fuel:5 l (LinArith.of_q c) phi in + (phi, facts) | Some atom' -> Sat (phi, Atom.Set.add atom' facts) ) in diff --git a/infer/src/pulse/unit/PulseFormulaTest.ml b/infer/src/pulse/unit/PulseFormulaTest.ml index 1317d1433..e7abadb75 100644 --- a/infer/src/pulse/unit/PulseFormulaTest.ml +++ b/infer/src/pulse/unit/PulseFormulaTest.ml @@ -130,7 +130,7 @@ let%test_module "normalization" = ( module struct let%expect_test _ = normalize (x < y) ; - [%expect {|true (no var=var) && true (no linear) && {x < y}|}] + [%expect {|true (no var=var) && true (no linear) && {[x + -y] < 0}|}] let%expect_test _ = normalize (x + i 1 - i 1 < x) ; @@ -166,6 +166,10 @@ let%test_module "normalization" = normalize (x = i 0 && x < i 0) ; [%expect {|unsat|}] + let%expect_test _ = + normalize (x + y < x + y) ; + [%expect {|true (no var=var) && v6 = x + y ∧ v7 = x + y && {[v6 + -v7] < 0}|}] + let%expect_test "nonlinear arithmetic" = normalize (z * (x + (v * y) + i 1) / w = i 0) ; [%expect @@ -174,7 +178,7 @@ let%test_module "normalization" = && v7 = x + v6 ∧ v8 = x + v6 +1 ∧ v10 = 0 && - {0 = v9÷w}∧{v6 = v×y}∧{v9 = z×v8} |}] + {0 = [v9]÷[w]}∧{[v6] = [v]×[y]}∧{[v9] = [z]×[v8]} |}] (* check that this becomes all linear equalities *) let%expect_test _ =