From 45894a7dd9ab8740d07f38e5d365f4e4223012d4 Mon Sep 17 00:00:00 2001 From: Jules Villard Date: Tue, 25 Aug 2020 01:53:08 -0700 Subject: [PATCH] [pulse] move LinArith before Term Summary: This is needed for the rest of the stack that introduces a `Linear of LinArith.t` variant in `Term.t` to enable more normalisation inside of terms. Reviewed By: skcho Differential Revision: D23241353 fbshipit-source-id: ad765cd13 --- infer/src/pulse/PulseFormula.ml | 426 ++++++++++++++++---------------- 1 file changed, 214 insertions(+), 212 deletions(-) diff --git a/infer/src/pulse/PulseFormula.ml b/infer/src/pulse/PulseFormula.ml index b8c097530..3f6b4e883 100644 --- a/infer/src/pulse/PulseFormula.ml +++ b/infer/src/pulse/PulseFormula.ml @@ -12,6 +12,10 @@ module Var = PulseAbstractValue type operand = LiteralOperand of IntLit.t | AbstractValueOperand of Var.t +(** "normalized" is not to be taken too seriously, it just means *some* normalization was applied + that could result in discovering something is unsatisfiable *) +type 'a normalized = Unsat | Sat of 'a + module Q = struct include Q @@ -34,6 +38,173 @@ module Q = struct let to_bigint q = conv_protect Q.to_bigint q end +(** Linear Arithmetic*) +module LinArith : sig + (** linear combination of variables, eg [2·x + 3/4·y + 12] *) + type t + + val pp : (F.formatter -> Var.t -> unit) -> F.formatter -> t -> unit + + val is_zero : t -> bool + + val add : t -> t -> t + + val minus : t -> t + + val subtract : t -> t -> t + + val mult : Q.t -> t -> t + + val solve_eq : t -> t -> (Var.t * t) option normalized + (** [solve_eq l1 l2] is [Sat (Some (x, l))] if [l1=l2 <=> x=l], [Sat None] if [l1 = l2] is always + true, and [Unsat] if it is always false *) + + val of_q : Q.t -> t + + val of_var : Var.t -> t + + val of_intlit : IntLit.t -> t + + val of_operand : operand -> t + + val get_as_const : t -> Q.t option + (** [get_as_const l] is [Some c] if [l=c], else [None] *) + + val get_as_var : t -> Var.t option + (** [get_as_var l] is [Some x] if [l=x], else [None] *) + + val has_var : Var.t -> t -> bool + + val subst : Var.t -> Var.t -> t -> t + + val subst_vars : f:(Var.t -> t) -> t -> t + + val get_variables : t -> Var.t Seq.t + + val fold_map_variables : t -> init:'a -> f:('a -> Var.t -> 'a * Var.t) -> 'a * t +end = struct + (** invariant: the representation is always "canonical": coefficients cannot be [Q.zero] *) + type t = Q.t Var.Map.t * Q.t + + let pp pp_var fmt (vs, c) = + if Var.Map.is_empty vs then Q.pp_print fmt c + else + let pp_c fmt c = + if Q.is_zero c then () + else + let plusminus, c_pos = if Q.geq c Q.zero then ('+', c) else ('-', Q.neg c) in + F.fprintf fmt " %c%a" plusminus Q.pp_print c_pos + in + let pp_coeff fmt q = + if Q.is_one q then () + else if Q.is_minus_one q then F.pp_print_string fmt "-" + else F.fprintf fmt "%a·" Q.pp_print q + in + let pp_vs fmt vs = + Pp.collection ~sep:" + " + ~fold:(IContainer.fold_of_pervasives_map_fold Var.Map.fold) + ~pp_item:(fun fmt (v, q) -> F.fprintf fmt "%a%a" pp_coeff q pp_var v) + fmt vs + in + F.fprintf fmt "@[%a%a@]" pp_vs vs pp_c c + + + let add (vs1, c1) (vs2, c2) = + ( Var.Map.union + (fun _v c1 c2 -> + let c = Q.add c1 c2 in + if Q.is_zero c then None else Some c ) + vs1 vs2 + , Q.add c1 c2 ) + + + let minus (vs, c) = (Var.Map.map (fun c -> Q.neg c) vs, Q.neg c) + + let subtract l1 l2 = add l1 (minus l2) + + let zero = (Var.Map.empty, Q.zero) + + let is_zero (vs, c) = Q.is_zero c && Var.Map.is_empty vs + + let mult q ((vs, c) as l) = + if Q.is_zero q then (* needed for correction: coeffs cannot be zero *) zero + else if Q.is_one q then (* purely an optimisation *) l + else (Var.Map.map (fun c -> Q.mul q c) vs, Q.mul q c) + + + let solve_eq_zero (vs, c) = + match Var.Map.min_binding_opt vs with + | None -> + if Q.is_zero c then Sat None else Unsat + | Some (x, coeff) -> + let d = Q.neg coeff in + let vs' = + Var.Map.fold + (fun v' coeff' vs' -> + if Var.equal v' x then vs' else Var.Map.add v' (Q.div coeff' d) vs' ) + vs Var.Map.empty + in + let c' = Q.div c d in + Sat (Some (x, (vs', c'))) + + + let solve_eq l1 l2 = solve_eq_zero (subtract l1 l2) + + let of_var v = (Var.Map.singleton v Q.one, Q.zero) + + let of_q q = (Var.Map.empty, q) + + let of_intlit i = IntLit.to_big_int i |> Q.of_bigint |> of_q + + let of_operand = function AbstractValueOperand v -> of_var v | LiteralOperand i -> of_intlit i + + let get_as_const (vs, c) = if Var.Map.is_empty vs then Some c else None + + let get_as_var (vs, c) = + if Q.is_zero c then + match Var.Map.is_singleton_or_more vs with + | Singleton (x, cx) when Q.is_one cx -> + Some x + | _ -> + None + else None + + + let has_var x (vs, _) = Var.Map.mem x vs + + let subst x y ((vs, c) as l) = + match Var.Map.find_opt x vs with + | None -> + l + | Some cx -> + let vs' = Var.Map.remove x vs |> Var.Map.add y cx in + (vs', c) + + + let subst_vars ~f (vs, c) = Var.Map.fold (fun v q l -> mult q (f v) |> add l) vs (Var.Map.empty, c) + + let fold_map_variables (vs_foreign, c) ~init ~f = + let acc_f, vs = + Var.Map.fold + (fun v_foreign q0 (acc_f, vs) -> + let acc_f, v = f acc_f v_foreign in + let vs = + match Var.Map.find_opt v vs with + | None -> + Var.Map.add v q0 vs + | Some q -> + let q' = Q.add q q0 in + if Q.is_zero q' then Var.Map.remove v vs else Var.Map.add v q vs + in + (acc_f, vs) ) + vs_foreign (init, Var.Map.empty) + in + (acc_f, (vs, c)) + + + let get_variables (vs, _) = Var.Map.to_seq vs |> Seq.map fst +end + (** Expressive term structure to be able to express all of SIL, but the main smarts of the formulas are for the equality between variables and linear arithmetic subsets. Terms (and atoms, below) are kept as a last-resort for when outside that fragment. *) @@ -463,6 +634,48 @@ module Term = struct (* [t ∨ false = t] *) t1 | _ -> t + + + (** more or less syntactic attempt at detecting when an arbitrary term is a linear formula; call + {!Atom.eval_term} first for best results *) + let rec to_lin_arith t = + (* NOTE: don't duplicate simplifications between here and {!Atom.eval_term} *) + let open IOption.Let_syntax in + match t with + | Var v -> + Some (LinArith.of_var v) + | Const c -> + Some (LinArith.of_q c) + | Minus t -> + let+ l = to_lin_arith t in + LinArith.minus l + | Add (t1, t2) -> + let* l1 = to_lin_arith t1 in + let+ l2 = to_lin_arith t2 in + LinArith.add l1 l2 + | Mult (Const c, t) | Mult (t, Const c) -> + let+ l = to_lin_arith t in + LinArith.mult c l + | Div (t, Const c) when Q.is_not_zero c -> + let+ l = to_lin_arith t in + LinArith.mult (Q.inv c) l + | Mult _ + | Div _ + | Mod _ + | BitNot _ + | BitAnd _ + | BitOr _ + | BitShiftLeft _ + | BitShiftRight _ + | BitXor _ + | Not _ + | And _ + | Or _ + | LessThan _ + | LessEqual _ + | Equal _ + | NotEqual _ -> + None end (** Basically boolean terms, used to build the part of a formula that is not equalities between @@ -614,10 +827,6 @@ module Atom = struct end) end -(** "normalized" is not to be taken too seriously, it just means *some* normalization was applied - that could result in discovering something is unsatisfiable *) -type 'a normalized = Unsat | Sat of 'a - module SatUnsatMonad = struct let map_normalized f norm = match norm with Unsat -> Unsat | Sat phi -> Sat (f phi) @@ -632,213 +841,6 @@ module SatUnsatMonad = struct let ( let* ) phi f = bind_normalized f phi end -(** Linear Arithmetic*) -module LinArith : sig - (** linear combination of variables, eg [2·x + 3/4·y + 12] *) - type t - - val pp : (F.formatter -> Var.t -> unit) -> F.formatter -> t -> unit - - val is_zero : t -> bool - - val add : t -> t -> t - - val minus : t -> t - - val subtract : t -> t -> t - - val solve_eq : t -> t -> (Var.t * t) option normalized - (** [solve_eq l1 l2] is [Sat (Some (x, l))] if [l1=l2 <=> x=l], [Sat None] if [l1 = l2] is always - true, and [Unsat] if it is always false *) - - val of_var : Var.t -> t - - val of_intlit : IntLit.t -> t - - val of_operand : operand -> t - - val of_term : Term.t -> t option - (** more or less syntactic attempt at detecting when an arbitrary term is a linear formula; call - {!Atom.eval_term} first for best results *) - - val get_as_const : t -> Q.t option - (** [get_as_const l] is [Some c] if [l=c], else [None] *) - - val get_as_var : t -> Var.t option - (** [get_as_var l] is [Some x] if [l=x], else [None] *) - - val has_var : Var.t -> t -> bool - - val subst : Var.t -> Var.t -> t -> t - - val subst_vars : f:(Var.t -> t) -> t -> t - - val get_variables : t -> Var.t Seq.t - - val fold_map_variables : t -> init:'a -> f:('a -> Var.t -> 'a * Var.t) -> 'a * t -end = struct - (** invariant: the representation is always "canonical": coefficients cannot be [Q.zero] *) - type t = Q.t Var.Map.t * Q.t - - let pp pp_var fmt (vs, c) = - if Var.Map.is_empty vs then Q.pp_print fmt c - else - let pp_c fmt c = - if Q.is_zero c then () - else - let plusminus, c_pos = if Q.geq c Q.zero then ('+', c) else ('-', Q.neg c) in - F.fprintf fmt " %c%a" plusminus Q.pp_print c_pos - in - let pp_coeff fmt q = - if Q.is_one q then () - else if Q.is_minus_one q then F.pp_print_string fmt "-" - else F.fprintf fmt "%a·" Q.pp_print q - in - let pp_vs fmt vs = - Pp.collection ~sep:" + " - ~fold:(IContainer.fold_of_pervasives_map_fold Var.Map.fold) - ~pp_item:(fun fmt (v, q) -> F.fprintf fmt "%a%a" pp_coeff q pp_var v) - fmt vs - in - F.fprintf fmt "@[%a%a@]" pp_vs vs pp_c c - - - let add (vs1, c1) (vs2, c2) = - ( Var.Map.union - (fun _v c1 c2 -> - let c = Q.add c1 c2 in - if Q.is_zero c then None else Some c ) - vs1 vs2 - , Q.add c1 c2 ) - - - let minus (vs, c) = (Var.Map.map (fun c -> Q.neg c) vs, Q.neg c) - - let subtract l1 l2 = add l1 (minus l2) - - let zero = (Var.Map.empty, Q.zero) - - let is_zero (vs, c) = Q.is_zero c && Var.Map.is_empty vs - - let mult q ((vs, c) as l) = - if Q.is_zero q then (* needed for correction: coeffs cannot be zero *) zero - else if Q.is_one q then (* purely an optimisation *) l - else (Var.Map.map (fun c -> Q.mul q c) vs, Q.mul q c) - - - let solve_eq_zero (vs, c) = - match Var.Map.min_binding_opt vs with - | None -> - if Q.is_zero c then Sat None else Unsat - | Some (x, coeff) -> - let d = Q.neg coeff in - let vs' = - Var.Map.fold - (fun v' coeff' vs' -> - if Var.equal v' x then vs' else Var.Map.add v' (Q.div coeff' d) vs' ) - vs Var.Map.empty - in - let c' = Q.div c d in - Sat (Some (x, (vs', c'))) - - - let solve_eq l1 l2 = solve_eq_zero (subtract l1 l2) - - let of_var v = (Var.Map.singleton v Q.one, Q.zero) - - let of_q q = (Var.Map.empty, q) - - let of_intlit i = IntLit.to_big_int i |> Q.of_bigint |> of_q - - let of_operand = function AbstractValueOperand v -> of_var v | LiteralOperand i -> of_intlit i - - (* don't duplicate simplifications between here and {!Atom.eval_term} *) - let rec of_term (t : Term.t) = - let open IOption.Let_syntax in - match t with - | Var v -> - Some (of_var v) - | Const c -> - Some (of_q c) - | Minus t -> - let+ l = of_term t in - minus l - | Add (t1, t2) -> - let* l1 = of_term t1 in - let+ l2 = of_term t2 in - add l1 l2 - | Mult (Const c, t) | Mult (t, Const c) -> - let+ l = of_term t in - mult c l - | Div (t, Const c) when Q.is_not_zero c -> - let+ l = of_term t in - mult (Q.inv c) l - | Mult _ - | Div _ - | Mod _ - | BitNot _ - | BitAnd _ - | BitOr _ - | BitShiftLeft _ - | BitShiftRight _ - | BitXor _ - | Not _ - | And _ - | Or _ - | LessThan _ - | LessEqual _ - | Equal _ - | NotEqual _ -> - None - - - let get_as_const (vs, c) = if Var.Map.is_empty vs then Some c else None - - let get_as_var (vs, c) = - if Q.is_zero c then - match Var.Map.is_singleton_or_more vs with - | Singleton (x, cx) when Q.is_one cx -> - Some x - | _ -> - None - else None - - - let has_var x (vs, _) = Var.Map.mem x vs - - let subst x y ((vs, c) as l) = - match Var.Map.find_opt x vs with - | None -> - l - | Some cx -> - let vs' = Var.Map.remove x vs |> Var.Map.add y cx in - (vs', c) - - - let subst_vars ~f (vs, c) = Var.Map.fold (fun v q l -> mult q (f v) |> add l) vs (Var.Map.empty, c) - - let fold_map_variables (vs_foreign, c) ~init ~f = - let acc_f, vs = - Var.Map.fold - (fun v_foreign q0 (acc_f, vs) -> - let acc_f, v = f acc_f v_foreign in - let vs = - match Var.Map.find_opt v vs with - | None -> - Var.Map.add v q0 vs - | Some q -> - let q' = Q.add q q0 in - if Q.is_zero q' then Var.Map.remove v vs else Var.Map.add v q vs - in - (acc_f, vs) ) - vs_foreign (init, Var.Map.empty) - in - (acc_f, (vs, c)) - - - let get_variables (vs, _) = Var.Map.to_seq vs |> Seq.map fst -end - module VarUF = UnionFind.Make (struct @@ -1050,7 +1052,7 @@ end = struct | None -> acc | Some (Atom.Equal (t1, t2) as atom') -> ( - match Option.both (LinArith.of_term t1) (LinArith.of_term t2) with + match Option.both (Term.to_lin_arith t1) (Term.to_lin_arith t2) with | None -> Sat (phi, Atom.Set.add atom' facts) | Some (l1, l2) ->