From 36af901d79dceae3314f7b776d1c07605dfb2481 Mon Sep 17 00:00:00 2001 From: Jules Villard Date: Tue, 25 Aug 2020 01:53:17 -0700 Subject: [PATCH] [pulse] normalize any linear atom Summary: Linear arithmetic is able to simplify more atoms, eg `x+y <= x+y` becomes `True` by normalising to "lhs - rhs <= 0". This does the first step of normalisation, but to get True in this example we also need to substitute inside atoms according to the linear equalities, which is the next diff (for now we only substitute variables inside atoms for other variables or for constants). Reviewed By: skcho Differential Revision: D23241457 fbshipit-source-id: 0da0b545c --- infer/src/pulse/PulseFormula.ml | 141 ++++++++++++----------- infer/src/pulse/unit/PulseFormulaTest.ml | 8 +- 2 files changed, 82 insertions(+), 67 deletions(-) diff --git a/infer/src/pulse/PulseFormula.ml b/infer/src/pulse/PulseFormula.ml index ea8526de7..ea17244b1 100644 --- a/infer/src/pulse/PulseFormula.ml +++ b/infer/src/pulse/PulseFormula.ml @@ -660,46 +660,55 @@ module Term = struct (** more or less syntactic attempt at detecting when an arbitrary term is a linear formula; call {!Atom.eval_term} first for best results *) - let rec to_lin_arith t = - (* NOTE: don't duplicate simplifications between here and {!Atom.eval_term} *) - let open IOption.Let_syntax in - match t with - | Var v -> - Some (LinArith.of_var v) - | Const c -> - Some (LinArith.of_q c) - | Linear l -> - Some l - | Minus t -> - let+ l = to_lin_arith t in - LinArith.minus l - | Add (t1, t2) -> - let* l1 = to_lin_arith t1 in - let+ l2 = to_lin_arith t2 in - LinArith.add l1 l2 - | Mult (Const c, t) | Mult (t, Const c) -> - let+ l = to_lin_arith t in - LinArith.mult c l - | Div (t, Const c) when Q.is_not_zero c -> - let+ l = to_lin_arith t in - LinArith.mult (Q.inv c) l - | Mult _ - | Div _ - | Mod _ - | BitNot _ - | BitAnd _ - | BitOr _ - | BitShiftLeft _ - | BitShiftRight _ - | BitXor _ - | Not _ - | And _ - | Or _ - | LessThan _ - | LessEqual _ - | Equal _ - | NotEqual _ -> - None + let linearize t = + let rec aux_linearize t = + let open IOption.Let_syntax in + match t with + | Var v -> + Some (LinArith.of_var v) + | Const c -> + Some (LinArith.of_q c) + | Linear l -> + Some l + | Minus t -> + let+ l = aux_linearize t in + LinArith.minus l + | Add (t1, t2) -> + let* l1 = aux_linearize t1 in + let+ l2 = aux_linearize t2 in + LinArith.add l1 l2 + | Mult (Const c, t) | Mult (t, Const c) -> + let+ l = aux_linearize t in + LinArith.mult c l + | Div (t, Const c) when Q.is_not_zero c -> + let+ l = aux_linearize t in + LinArith.mult (Q.inv c) l + | Mult _ + | Div _ + | Mod _ + | BitNot _ + | BitAnd _ + | BitOr _ + | BitShiftLeft _ + | BitShiftRight _ + | BitXor _ + | Not _ + | And _ + | Or _ + | LessThan _ + | LessEqual _ + | Equal _ + | NotEqual _ -> + None + in + match aux_linearize t with None -> t | Some l -> Linear l + + + let simplify_linear = function + | Linear l -> ( + match LinArith.get_as_const l with Some c -> Const c | None -> Linear l ) + | t -> + t end (** Basically boolean terms, used to build the part of a formula that is not equalities between @@ -731,8 +740,6 @@ module Atom = struct F.fprintf fmt "%a ≠ %a" pp_term t1 pp_term t2 - let pp = pp_with_pp_var Var.pp - let get_terms atom = let (LessEqual (t1, t2) | LessThan (t1, t2) | Equal (t1, t2) | NotEqual (t1, t2)) = atom in (t1, t2) @@ -787,7 +794,8 @@ module Atom = struct let rec eval_term t = let t = - Term.map_direct_subterms ~f:eval_term t |> Term.simplify_shallow |> Term.eval_const_shallow + Term.map_direct_subterms ~f:eval_term t + |> Term.simplify_shallow |> Term.eval_const_shallow |> Term.linearize |> Term.simplify_linear in match (t : Term.t) with (* terms that are atoms can be simplified in [eval_atom] *) @@ -803,11 +811,7 @@ module Atom = struct t - (** This assumes that the terms in the atom have been normalized/evaluated already. - - TODO: probably a better way to implement this would be to rely entirely on - [eval_term (term_of_atom (atom))], possibly implementing it as something about observing the - sign/zero-ness of [t1 - t2]. *) + (** This assumes that the terms in the atom have been normalized/evaluated already. *) and eval_atom (atom : t) = let t1, t2 = get_terms atom in match (t1, t2) with @@ -821,6 +825,19 @@ module Atom = struct eval_result_of_bool (Q.leq c1 c2) | LessThan _ -> eval_result_of_bool (Q.lt c1 c2) ) + | Linear l1, Linear l2 -> + let l = LinArith.subtract l1 l2 in + let t = Term.simplify_linear (Linear l) in + eval_atom + ( match atom with + | Equal _ -> + Equal (t, Term.zero) + | NotEqual _ -> + NotEqual (t, Term.zero) + | LessEqual _ -> + LessEqual (t, Term.zero) + | LessThan _ -> + LessThan (t, Term.zero) ) | _ -> if Term.equal_syntax t1 t2 then match atom with @@ -865,6 +882,10 @@ module SatUnsatMonad = struct let ( let* ) phi f = bind_normalized f phi end +let sat_of_eval_result (eval_result : Atom.eval_result) = + match eval_result with True -> Sat None | False -> Unsat | Atom atom -> Sat (Some atom) + + module VarUF = UnionFind.Make (struct @@ -1055,14 +1076,7 @@ end = struct match q_opt with None -> VarSubst v_canon | Some q -> QSubst q ) in let atom' = Atom.map_terms atom ~f:(fun t -> normalize_term phi t) in - match Atom.eval atom' with - | True -> - Sat None - | False -> - L.d_printfln "Contradiction detected! %a ~> %a is False" Atom.pp atom Atom.pp atom' ; - Unsat - | Atom atom -> - Sat (Some atom) + Atom.eval atom' |> sat_of_eval_result let normalize_atoms phi = @@ -1075,15 +1089,12 @@ end = struct >>= function | None -> acc - | Some (Atom.Equal (t1, t2) as atom') -> ( - match Option.both (Term.to_lin_arith t1) (Term.to_lin_arith t2) with - | None -> - Sat (phi, Atom.Set.add atom' facts) - | Some (l1, l2) -> - (* an atom has been found to have become a linear equality, move it to the linear - equalities *) - let+ phi = solve_eq ~fuel:5 l1 l2 phi in - (phi, facts) ) + | Some (Atom.Equal (Linear l, Const c)) | Some (Atom.Equal (Const c, Linear l)) -> + (* NOTE: {!normalize_atom} calls {!Atom.eval}, which normalizes linear equalities so + they end up only on one side, hence only this match case is needed to detect linear + equalities *) + let+ phi = solve_eq ~fuel:5 l (LinArith.of_q c) phi in + (phi, facts) | Some atom' -> Sat (phi, Atom.Set.add atom' facts) ) in diff --git a/infer/src/pulse/unit/PulseFormulaTest.ml b/infer/src/pulse/unit/PulseFormulaTest.ml index 1317d1433..e7abadb75 100644 --- a/infer/src/pulse/unit/PulseFormulaTest.ml +++ b/infer/src/pulse/unit/PulseFormulaTest.ml @@ -130,7 +130,7 @@ let%test_module "normalization" = ( module struct let%expect_test _ = normalize (x < y) ; - [%expect {|true (no var=var) && true (no linear) && {x < y}|}] + [%expect {|true (no var=var) && true (no linear) && {[x + -y] < 0}|}] let%expect_test _ = normalize (x + i 1 - i 1 < x) ; @@ -166,6 +166,10 @@ let%test_module "normalization" = normalize (x = i 0 && x < i 0) ; [%expect {|unsat|}] + let%expect_test _ = + normalize (x + y < x + y) ; + [%expect {|true (no var=var) && v6 = x + y ∧ v7 = x + y && {[v6 + -v7] < 0}|}] + let%expect_test "nonlinear arithmetic" = normalize (z * (x + (v * y) + i 1) / w = i 0) ; [%expect @@ -174,7 +178,7 @@ let%test_module "normalization" = && v7 = x + v6 ∧ v8 = x + v6 +1 ∧ v10 = 0 && - {0 = v9÷w}∧{v6 = v×y}∧{v9 = z×v8} |}] + {0 = [v9]÷[w]}∧{[v6] = [v]×[y]}∧{[v9] = [z]×[v8]} |}] (* check that this becomes all linear equalities *) let%expect_test _ =