diff --git a/infer/src/pulse/PulseFormula.ml b/infer/src/pulse/PulseFormula.ml index 801086958..f0b8684bc 100644 --- a/infer/src/pulse/PulseFormula.ml +++ b/infer/src/pulse/PulseFormula.ml @@ -65,10 +65,6 @@ module LinArith : sig val of_var : Var.t -> t - val of_intlit : IntLit.t -> t - - val of_operand : operand -> t - val get_as_const : t -> Q.t option (** [get_as_const l] is [Some c] if [l=c], else [None] *) @@ -158,10 +154,6 @@ end = struct let of_q q = (Var.Map.empty, q) - let of_intlit i = IntLit.to_big_int i |> Q.of_bigint |> of_q - - let of_operand = function AbstractValueOperand v -> of_var v | LiteralOperand i -> of_intlit i - let get_as_const (vs, c) = if Var.Map.is_empty vs then Some c else None let get_as_var (vs, c) = @@ -725,8 +717,6 @@ module Atom = struct | NotEqual of Term.t * Term.t [@@deriving compare] - type atom = t - let pp_with_pp_var pp_var fmt atom = (* add parens around terms that look like atoms to disambiguate *) let needs_paren (t : Term.t) = @@ -770,6 +760,23 @@ module Atom = struct (acc, t') + let equal t1 t2 = Equal (t1, t2) + + let less_equal t1 t2 = LessEqual (t1, t2) + + let less_than t1 t2 = LessThan (t1, t2) + + let nnot = function + | Equal (t1, t2) -> + NotEqual (t1, t2) + | NotEqual (t1, t2) -> + Equal (t1, t2) + | LessEqual (t1, t2) -> + LessThan (t2, t1) + | LessThan (t1, t2) -> + LessEqual (t2, t1) + + let map_terms atom ~f = fold_map_terms atom ~init:() ~f:(fun () t -> ((), f t)) |> snd let to_term : t -> Term.t = function @@ -796,30 +803,40 @@ module Atom = struct to_term atom - let rec eval_term t = - let t = - Term.map_direct_subterms ~f:eval_term t - |> Term.simplify_shallow |> Term.eval_const_shallow |> Term.linearize |> Term.simplify_linear - in - match (t : Term.t) with + let atom_of_term : Term.t -> t option = function (* terms that are atoms can be simplified in [eval_atom] *) | LessEqual (t1, t2) -> - eval_atom (LessEqual (t1, t2) : atom) |> term_of_eval_result + Some (LessEqual (t1, t2)) | LessThan (t1, t2) -> - eval_atom (LessThan (t1, t2) : atom) |> term_of_eval_result + Some (LessThan (t1, t2)) | Equal (t1, t2) -> - eval_atom (Equal (t1, t2) : atom) |> term_of_eval_result + Some (Equal (t1, t2)) | NotEqual (t1, t2) -> - eval_atom (NotEqual (t1, t2) : atom) |> term_of_eval_result + Some (NotEqual (t1, t2)) | _ -> + None + + + let term_is_atom t = atom_of_term t |> Option.is_some + + let rec eval_term t = + let t = + Term.map_direct_subterms ~f:eval_term t + |> Term.simplify_shallow |> Term.eval_const_shallow |> Term.linearize |> Term.simplify_linear + in + match atom_of_term t with + | Some atom -> + (* terms that are atoms can be simplified in [eval_atom] *) + eval_atom atom |> term_of_eval_result + | None -> t (** This assumes that the terms in the atom have been normalized/evaluated already. *) and eval_atom (atom : t) = let t1, t2 = get_terms atom in - match (t1, t2) with - | Const c1, Const c2 -> ( + match (atom, t1, t2) with + | _, Const c1, Const c2 -> ( match atom with | Equal _ -> eval_result_of_bool (Q.equal c1 c2) @@ -829,7 +846,7 @@ module Atom = struct eval_result_of_bool (Q.leq c1 c2) | LessThan _ -> eval_result_of_bool (Q.lt c1 c2) ) - | Linear l1, Linear l2 -> + | _, Linear l1, Linear l2 -> let l = LinArith.subtract l1 l2 in let t = Term.simplify_linear (Linear l) in eval_atom @@ -842,18 +859,29 @@ module Atom = struct LessEqual (t, Term.zero) | LessThan _ -> LessThan (t, Term.zero) ) + | (Equal _ | NotEqual _), _, _ + when (Term.is_zero t1 && term_is_atom t2) || (Term.is_zero t2 && term_is_atom t1) -> ( + (* [atom = 0] or [atom ≠ 0] can be interpreted as [!atom] and [atom], respectively *) + match (atom, atom_of_term t1, atom_of_term t2) with + | _, Some _, Some _ | _, None, None | LessEqual _, _, _ | LessThan _, _, _ -> + (* impossible thanks to the match pattern and guard *) + assert false + | NotEqual _, Some atom, None | NotEqual _, None, Some atom -> + (* [atom] is true *) eval_atom atom + | Equal _, Some atom, None | Equal _, None, Some atom -> + (* [atom] is false *) eval_atom (nnot atom) ) + | _ when Term.equal_syntax t1 t2 -> ( + match atom with + | Equal _ -> + True + | NotEqual _ -> + False + | LessEqual _ -> + True + | LessThan _ -> + False ) | _ -> - if Term.equal_syntax t1 t2 then - match atom with - | Equal _ -> - True - | NotEqual _ -> - False - | LessEqual _ -> - True - | LessThan _ -> - False - else Atom atom + Atom atom let eval (atom : t) = map_terms atom ~f:eval_term |> eval_atom @@ -937,6 +965,8 @@ module Normalizer : sig val and_var_var : Var.t -> Var.t -> t -> t normalized + val and_atom : Atom.t -> t -> t normalized + val normalize : t -> t normalized end = struct (* Use the monadic notations when normalizing formulas. *) @@ -1057,16 +1087,31 @@ end = struct (** an arbitrary value *) let fuel = 5 - let and_var_linarith v l phi = merge_var_linarith ~fuel v l phi + let and_var_linarith v l phi = solve_eq ~fuel l (LinArith.of_var v) phi let and_var_var v1 v2 phi = merge_vars ~fuel v1 v2 phi - let normalize_linear_eqs phi0 = - (* reconstruct the relation from scratch *) - Var.Map.fold - (fun v l phi -> bind_normalized (and_var_linarith v (apply phi0 l)) phi) - phi0.linear_eqs - (Sat {phi0 with linear_eqs= Var.Map.empty}) + let rec normalize_linear_eqs ~fuel phi0 = + let* changed, phi' = + (* reconstruct the relation from scratch *) + Var.Map.fold + (fun v l acc -> + let* changed, phi = acc in + let l' = apply phi0 l in + let+ phi' = and_var_linarith v l' phi in + (changed || not (phys_equal l l'), phi') ) + phi0.linear_eqs + (Sat (false, {phi0 with linear_eqs= Var.Map.empty})) + in + if changed then + if fuel > 0 then ( + L.d_printfln "going around one more time normalizing the linear equalities" ; + (* do another pass if we can affort it *) + normalize_linear_eqs ~fuel:(fuel - 1) phi' ) + else ( + L.d_printfln "ran out of fuel normalizing the linear equalities" ; + Sat phi' ) + else Sat phi0 let normalize_atom phi (atom : Atom.t) = @@ -1083,142 +1128,55 @@ end = struct Atom.eval atom' |> sat_of_eval_result - let normalize_atoms phi = - let+ phi, atoms = - IContainer.fold_of_pervasives_set_fold Atom.Set.fold phi.atoms - ~init:(Sat (phi, Atom.Set.empty)) - ~f:(fun acc atom -> - let* phi, facts = acc in - normalize_atom phi atom - >>= function - | None -> - acc - | Some (Atom.Equal (Linear l, Const c)) | Some (Atom.Equal (Const c, Linear l)) -> - (* NOTE: {!normalize_atom} calls {!Atom.eval}, which normalizes linear equalities so - they end up only on one side, hence only this match case is needed to detect linear - equalities *) - let+ phi = solve_eq ~fuel:5 l (LinArith.of_q c) phi in - (phi, facts) - | Some atom' -> - Sat (phi, Atom.Set.add atom' facts) ) - in - {phi with atoms} + let and_atom atom phi = + normalize_atom phi atom + >>= function + | None -> + Sat phi + | Some (Atom.Equal (Linear l, Const c)) | Some (Atom.Equal (Const c, Linear l)) -> + (* NOTE: {!normalize_atom} calls {!Atom.eval}, which normalizes linear equalities so + they end up only on one side, hence only this match case is needed to detect linear + equalities *) + solve_eq ~fuel l (LinArith.of_q c) phi + | Some atom' -> + Sat {phi with atoms= Atom.Set.add atom' phi.atoms} - let normalize phi = normalize_linear_eqs phi >>= normalize_atoms -end + let normalize_atoms phi = + let atoms0 = phi.atoms in + let init = Sat {phi with atoms= Atom.Set.empty} in + IContainer.fold_of_pervasives_set_fold Atom.Set.fold atoms0 ~init ~f:(fun acc atom -> + let* phi = acc in + and_atom atom phi ) -let and_equal op1 op2 phi = - match (op1, op2) with - | LiteralOperand i1, LiteralOperand i2 -> - if IntLit.eq i1 i2 then Sat phi else Unsat - | AbstractValueOperand v, LiteralOperand i | LiteralOperand i, AbstractValueOperand v -> - Normalizer.and_var_linarith v (LinArith.of_intlit i) phi - | AbstractValueOperand v1, AbstractValueOperand v2 -> - Normalizer.and_var_var v1 v2 phi + let normalize phi = normalize_linear_eqs ~fuel phi >>= normalize_atoms +end -let and_atom atom phi = {phi with atoms= Atom.Set.add atom phi.atoms} +let and_mk_atom mk_atom op1 op2 phi = + Normalizer.and_atom (mk_atom (Term.of_operand op1) (Term.of_operand op2)) phi -let and_less_equal op1 op2 phi = - match (op1, op2) with - | LiteralOperand i1, LiteralOperand i2 -> - if IntLit.leq i1 i2 then Sat phi else Unsat - | _ -> - Sat (and_atom (LessEqual (Term.of_operand op1, Term.of_operand op2)) phi) +let and_equal = and_mk_atom Atom.equal -let and_less_than op1 op2 phi = - match (op1, op2) with - | LiteralOperand i1, LiteralOperand i2 -> - if IntLit.lt i1 i2 then Sat phi else Unsat - | _ -> - Sat (and_atom (LessThan (Term.of_operand op1, Term.of_operand op2)) phi) +let and_less_equal = and_mk_atom Atom.less_equal +let and_less_than = and_mk_atom Atom.less_than let and_equal_unop v (op : Unop.t) x phi = - match op with - | Neg -> - Normalizer.and_var_linarith v LinArith.(minus (of_operand x)) phi - | BNot | LNot -> - Sat (and_atom (Equal (Term.Var v, Term.of_unop op (Term.of_operand x))) phi) + Normalizer.and_atom (Equal (Var v, Term.of_unop op (Term.of_operand x))) phi let and_equal_binop v (bop : Binop.t) x y phi = - let and_linear_eq l = Normalizer.and_var_linarith v l phi in - match bop with - | PlusA _ | PlusPI -> - LinArith.(add (of_operand x) (of_operand y)) |> and_linear_eq - | MinusA _ | MinusPI | MinusPP -> - LinArith.(subtract (of_operand x) (of_operand y)) |> and_linear_eq - (* TODO: some of the below could become linear arithmetic after simplifications (e.g. up to constants) *) - | Mult _ - | Div - | Mod - | Shiftlt - | Shiftrt - | BAnd - | BXor - | BOr - (* TODO: (most) logical operators should be translated into the formula structure *) - | Lt - | Gt - | Le - | Ge - | Eq - | Ne - | LAnd - | LOr -> - Sat - (and_atom - (Equal (Term.Var v, Term.of_binop bop (Term.of_operand x) (Term.of_operand y))) - phi) + Normalizer.and_atom (Equal (Var v, Term.of_binop bop (Term.of_operand x) (Term.of_operand y))) phi let prune_binop ~negated (bop : Binop.t) x y phi = - let atom op x y = - let tx = Term.of_operand x in - let ty = Term.of_operand y in - let atom : Atom.t = - match op with - | `LessThan -> - LessThan (tx, ty) - | `LessEqual -> - LessEqual (tx, ty) - | `NotEqual -> - NotEqual (tx, ty) - in - Sat (and_atom atom phi) - in - match (bop, negated) with - | Eq, false | Ne, true -> - and_equal x y phi - | Ne, false | Eq, true -> - atom `NotEqual x y - | Lt, false | Ge, true -> - atom `LessThan x y - | Le, false | Gt, true -> - atom `LessEqual x y - | Ge, false | Lt, true -> - atom `LessEqual y x - | Gt, false | Le, true -> - atom `LessThan y x - | LAnd, _ - | LOr, _ - | Mult _, _ - | Div, _ - | Mod, _ - | Shiftlt, _ - | Shiftrt, _ - | BAnd, _ - | BXor, _ - | BOr, _ - | PlusA _, _ - | PlusPI, _ - | MinusA _, _ - | MinusPI, _ - | MinusPP, _ -> - Sat phi + let tx = Term.of_operand x in + let ty = Term.of_operand y in + let t = Term.of_binop bop tx ty in + let atom = if negated then Atom.Equal (t, Term.zero) else Atom.NotEqual (t, Term.zero) in + Normalizer.and_atom atom phi let normalize phi = Normalizer.normalize phi @@ -1260,7 +1218,7 @@ let and_fold_map_variables phi0 ~up_to_f:phi_foreign ~init ~f = let acc_f, v' = f acc_f v in (acc_f, VarSubst v') ) in - let phi = and_atom atom phi in + let phi = Normalizer.and_atom atom phi |> sat_value_exn in (acc_f, phi) ) in try Sat (and_var_eqs (init, phi0) |> and_linear_eqs |> and_atoms) with Contradiction -> Unsat diff --git a/infer/src/pulse/unit/PulseFormulaTest.ml b/infer/src/pulse/unit/PulseFormulaTest.ml index 8fc56464d..b215aa35d 100644 --- a/infer/src/pulse/unit/PulseFormulaTest.ml +++ b/infer/src/pulse/unit/PulseFormulaTest.ml @@ -176,9 +176,9 @@ let%test_module "normalization" = {| true (no var=var) && - v7 = x + v6 ∧ v8 = x + v6 +1 ∧ v10 = 0 + x = -v6 + v8 -1 ∧ v7 = v8 -1 ∧ v10 = 0 && - {0 = [v9]÷[w]}∧{[v6] = [v]×[y]}∧{[v9] = [z]×[x + v6 +1]} |}] + {0 = [v9]÷[w]}∧{[v6] = [v]×[y]}∧{[v9] = [z]×[v8]} |}] (* check that this becomes all linear equalities *) let%expect_test _ = @@ -187,7 +187,7 @@ let%test_module "normalization" = {| true (no var=var) && - x = -v6 + 1/12·v9 -1 ∧ y = 1/3·v6 ∧ v7 = x + v6 ∧ v8 = x + v6 +1 ∧ v9 = 0 ∧ v10 = 0 + x = -v6 -1 ∧ y = 1/3·v6 ∧ v7 = -1 ∧ v8 = 0 ∧ v9 = 0 ∧ v10 = 0 && true (no atoms)|}] @@ -198,8 +198,8 @@ let%test_module "normalization" = {| true (no var=var) && - x = -v6 + 1/12·v9 -1 ∧ y = 1/3·v6 ∧ z = 12 ∧ w = 7 ∧ v = 3 ∧ v7 = x + v6 - ∧ v8 = x + v6 +1 ∧ v9 = 0 ∧ v10 = 0 + x = -v6 + v8 -1 ∧ y = 1/3·v6 ∧ z = 12 ∧ w = 7 ∧ v = 3 ∧ v7 = v8 -1 + ∧ v8 = 1/12·v9 ∧ v9 = 0 ∧ v10 = 0 && true (no atoms)|}] end ) @@ -221,17 +221,12 @@ let%test_module "variable elimination" = (* should keep most of this or realize that [w = z] hence this boils down to [z+1 = 0] *) let%expect_test _ = simplify ~keep:[y_var; z_var] (x = y + z && w = x - y && v = w + i 1 && v = i 0) ; - [%expect {|x=v6 ∧ z=w=v7 && x = y -1 ∧ z = -1 && true (no atoms)|}] + [%expect {|x=v6 && x = y -1 ∧ z = -1 && true (no atoms)|}] let%expect_test _ = simplify ~keep:[x_var; y_var] (x = y + z && w + x + y = i 0 && v = w + i 1) ; [%expect - {| - x=v6 ∧ v=v9 - && - x = 1/2·z + -1/2·w ∧ y = -1/2·z + -1/2·w ∧ v = w +1 ∧ v7 = 1/2·z + 1/2·w - && - true (no atoms)|}] + {|x=v6 ∧ v=v9 && x = -v + v7 +1 ∧ y = -v7 ∧ z = -v + 2·v7 +1 ∧ w = v -1 && true (no atoms)|}] let%expect_test _ = simplify ~keep:[x_var; y_var] (x = y + i 4 && x = w && y = z) ; @@ -242,9 +237,9 @@ let%test_module "non-linear simplifications" = ( module struct let%expect_test "zero propagation" = simplify ~keep:[w_var] (((i 0 / (x * z)) & v) * v mod y = w) ; - [%expect {|w=v10 && w = 0 && true (no atoms)|}] + [%expect {|true (no var=var) && w = 0 && true (no atoms)|}] let%expect_test "constant propagation: bitshift" = simplify ~keep:[x_var] (of_binop Shiftlt (of_binop Shiftrt (i 0b111) (i 2)) (i 2) = x) ; - [%expect {|x=v7 && x = 4 && true (no atoms)|}] + [%expect {|true (no var=var) && x = 4 && true (no atoms)|}] end )