[pulse] cleanup arithmetic

Summary:
Mostly cosmetic except for a change in [solve_eq] to try harder at
normalization (improves unit tests!). Add more comments and do minor
renamings.

Reviewed By: skcho

Differential Revision: D23243629

fbshipit-source-id: 55bdaf8a8
master
Jules Villard 4 years ago committed by Facebook GitHub Bot
parent 8b23fee8f8
commit 50b94dbbd6

@ -430,7 +430,6 @@ let of_post pdesc astate =
let astate = filter_for_summary astate in
let astate, live_addresses, _ = discard_unreachable astate in
let astate =
(* this also forces the lazy path condition for safe serialization *)
{astate with path_condition= PathCondition.simplify ~keep:live_addresses astate.path_condition}
in
invalidate_locals pdesc astate

@ -16,6 +16,7 @@ type operand = LiteralOperand of IntLit.t | AbstractValueOperand of Var.t
that could result in discovering something is unsatisfiable *)
type 'a normalized = Unsat | Sat of 'a
(** {!Q} from zarith with a few convenience functions added *)
module Q = struct
include Q
@ -38,7 +39,7 @@ module Q = struct
let to_bigint q = conv_protect Q.to_bigint q
end
(** Linear Arithmetic*)
(** Linear Arithmetic *)
module LinArith : sig
(** linear combination of variables, eg [2·x + 3/4·y + 12] *)
type t [@@deriving compare]
@ -77,9 +78,9 @@ module LinArith : sig
val get_variables : t -> Var.t Seq.t
val fold_map_variables : t -> init:'a -> f:('a -> Var.t -> 'a * subst_target) -> 'a * t
val fold_subst_variables : t -> init:'a -> f:('a -> Var.t -> 'a * subst_target) -> 'a * t
val map_variables : t -> f:(Var.t -> subst_target) -> t
val subst_variables : t -> f:(Var.t -> subst_target) -> t
end = struct
(** invariant: the representation is always "canonical": coefficients cannot be [Q.zero] *)
type t = Q.t Var.Map.t * Q.t [@@deriving compare]
@ -179,7 +180,7 @@ end = struct
let of_subst_target = function QSubst q -> of_q q | VarSubst v -> of_var v | LinSubst l -> l
let fold_map_variables ((vs_foreign, c) as l0) ~init ~f =
let fold_subst_variables ((vs_foreign, c) as l0) ~init ~f =
let changed = ref false in
let acc_f, l' =
Var.Map.fold
@ -194,7 +195,7 @@ end = struct
(acc_f, l')
let map_variables l ~f = fold_map_variables l ~init:() ~f:(fun () v -> ((), f v)) |> snd
let subst_variables l ~f = fold_subst_variables l ~init:() ~f:(fun () v -> ((), f v)) |> snd
let get_variables (vs, _) = Var.Map.to_seq vs |> Seq.map fst
end
@ -393,32 +394,13 @@ module Term = struct
let t' =
if phys_equal t_not t_not' then t
else
match t with
match[@warning "-8"] t with
| Minus _ ->
Minus t_not'
| BitNot _ ->
BitNot t_not'
| Not _ ->
Not t_not'
| Var _
| Const _
| Linear _
| Add _
| Mult _
| Div _
| Mod _
| BitAnd _
| BitOr _
| BitShiftLeft _
| BitShiftRight _
| BitXor _
| And _
| Or _
| LessThan _
| LessEqual _
| Equal _
| NotEqual _ ->
assert false
in
(acc, t')
| Add (t1, t2)
@ -441,7 +423,7 @@ module Term = struct
let t' =
if phys_equal t1 t1' && phys_equal t2 t2' then t
else
match t with
match[@warning "-8"] t with
| Add _ ->
Add (t1', t2')
| Mult _ ->
@ -472,8 +454,6 @@ module Term = struct
Equal (t1', t2')
| NotEqual _ ->
NotEqual (t1', t2')
| Var _ | Const _ | Linear _ | Minus _ | BitNot _ | Not _ ->
assert false
in
(acc, t')
@ -482,27 +462,27 @@ module Term = struct
fold_map_direct_subterms t ~init:() ~f:(fun () t' -> ((), f t')) |> snd
let rec fold_map_variables t ~init ~f =
let rec fold_subst_variables t ~init ~f =
match t with
| Var v ->
let acc, op = f init v in
let t' = match op with VarSubst v' when Var.equal v v' -> t | _ -> of_subst_target op in
(acc, t')
| Linear l ->
let acc, l' = LinArith.fold_map_variables l ~init ~f in
let acc, l' = LinArith.fold_subst_variables l ~init ~f in
let t' = if phys_equal l l' then t else Linear l' in
(acc, t')
| _ ->
fold_map_direct_subterms t ~init ~f:(fun acc t' -> fold_map_variables t' ~init:acc ~f)
fold_map_direct_subterms t ~init ~f:(fun acc t' -> fold_subst_variables t' ~init:acc ~f)
let fold_variables t ~init ~f =
fold_map_variables t ~init ~f:(fun acc v -> (f acc v, VarSubst v)) |> fst
fold_subst_variables t ~init ~f:(fun acc v -> (f acc v, VarSubst v)) |> fst
let iter_variables t ~f = fold_variables t ~init:() ~f:(fun () v -> f v)
let map_variables t ~f = fold_map_variables t ~init:() ~f:(fun () v -> ((), f v)) |> snd
let subst_variables t ~f = fold_subst_variables t ~init:() ~f:(fun () v -> ((), f v)) |> snd
let has_var_notin vars t =
Container.exists t ~iter:iter_variables ~f:(fun v -> not (Var.Set.mem v vars))
@ -934,10 +914,10 @@ module Atom = struct
t
let eval (atom : t) = map_terms atom ~f:eval_term |> eval_atom
let eval atom = map_terms atom ~f:eval_term |> eval_atom
let fold_map_variables a ~init ~f =
fold_map_terms a ~init ~f:(fun acc t -> Term.fold_map_variables t ~init:acc ~f)
let fold_subst_variables a ~init ~f =
fold_map_terms a ~init ~f:(fun acc t -> Term.fold_subst_variables t ~init:acc ~f)
let has_var_notin vars atom =
@ -1052,7 +1032,7 @@ end = struct
(** substitute vars in [l] *once* with their linear form to discover more simplification
opportunities *)
let apply phi l =
LinArith.map_variables l ~f:(fun v ->
LinArith.subst_variables l ~f:(fun v ->
let repr = (get_repr phi v :> Var.t) in
match Var.Map.find_opt repr phi.linear_eqs with
| None ->
@ -1061,29 +1041,35 @@ end = struct
LinSubst l' )
let rec solve_eq ~fuel t1 t2 phi =
LinArith.solve_eq t1 t2
>>= function None -> Sat phi | Some (x, l) -> merge_var_linarith ~fuel x l phi
and merge_var_linarith ~fuel v l phi =
let v = (get_repr phi v :> Var.t) in
let l = apply phi l in
match LinArith.get_as_var l with
| Some v' ->
merge_vars ~fuel (v :> Var.t) v' phi
| None -> (
match Var.Map.find_opt (v :> Var.t) phi.linear_eqs with
| None ->
(* this is probably dodgy as nothing guarantees that [l] does not mention [v] *)
Sat {phi with linear_eqs= Var.Map.add (v :> Var.t) l phi.linear_eqs}
| Some l' ->
(* This is the only step that consumes fuel: discovering an equality [l = l']: because we
do not record these anywhere (except when there consequence can be recorded as [y =
l''] or [y = y'], we could potentially discover the same equality over and over and
diverge otherwise *)
if fuel > 0 then solve_eq ~fuel:(fuel - 1) l l' phi
else (* [fuel = 0]: give up simplifying further for fear of diverging *) Sat phi )
let rec solve_normalized_eq ~fuel l1 l2 phi =
LinArith.solve_eq l1 l2
>>= function
| None ->
Sat phi
| Some (v, l) -> (
match LinArith.get_as_var l with
| Some v' ->
merge_vars ~fuel (v :> Var.t) v' phi
| None -> (
match Var.Map.find_opt (v :> Var.t) phi.linear_eqs with
| None ->
(* this can break the (as a result non-)invariant that variables in the domain of
[linear_eqs] do not appear in the range of [linear_eqs] *)
Sat {phi with linear_eqs= Var.Map.add (v :> Var.t) l phi.linear_eqs}
| Some l' ->
(* This is the only step that consumes fuel: discovering an equality [l = l']: because we
do not record these anywhere (except when there consequence can be recorded as [y =
l''] or [y = y'], we could potentially discover the same equality over and over and
diverge otherwise. Or could we? *)
(* [l'] is possibly not normalized w.r.t. the current [phi] so take this opportunity to
normalize it *)
if fuel > 0 then (
L.d_printfln "Consuming fuel solving linear equality (from %d)" fuel ;
solve_normalized_eq ~fuel:(fuel - 1) l (apply phi l') phi )
else (
(* [fuel = 0]: give up simplifying further for fear of diverging *)
L.d_printfln "Ran out of fuel solving linear equality" ;
Sat phi ) ) )
and merge_vars ~fuel v1 v2 phi =
@ -1097,9 +1083,11 @@ end = struct
(* new equality [v_old = v_new]: we need to update a potential [v_old = l] to be [v_new =
l], and if [v_new = l'] was known we need to also explore the consequences of [l = l'] *)
(* NOTE: we try to maintain the invariant that for all [x=l] in [phi.linear_eqs], [x ∉
vars(l)], because other Shostak techniques do so (in fact, they impose a stricter
condition that the domain and the range of [phi.linear_eqs] mention distinct variables),
but some other steps of the reasoning may break that. Not sure why we bother. *)
vars(l)]. We also try to stay as close as possible (without going back and re-normalizing
every linear equality every time we learn new equalities) to the invariant that the
domain and the range of [phi.linear_eqs] mention distinct variables. This is to speed up
normalization steps: when the stronger invariant holds we can normalize in one step (in
[normalize_linear_eqs]). *)
let v_new = (v_new :> Var.t) in
let phi, l_new =
match Var.Map.find_opt v_new phi.linear_eqs with
@ -1131,15 +1119,17 @@ end = struct
(* no need to consume fuel here as we can only go through this branch finitely many
times because there are finitely many variables in a given formula *)
(* TODO: we may want to keep the "simpler" representative for [v_new] between [l1] and [l2] *)
solve_eq ~fuel l1 l2 phi )
solve_normalized_eq ~fuel l1 l2 phi )
(** an arbitrary value *)
let fuel = 5
let base_fuel = 5
let solve_eq t1 t2 phi = solve_normalized_eq ~fuel:base_fuel (apply phi t1) (apply phi t2) phi
let and_var_linarith v l phi = solve_eq ~fuel l (LinArith.of_var v) phi
let and_var_linarith v l phi = solve_eq l (LinArith.of_var v) phi
let and_var_var v1 v2 phi = merge_vars ~fuel v1 v2 phi
let and_var_var v1 v2 phi = merge_vars ~fuel:base_fuel v1 v2 phi
let rec normalize_linear_eqs ~fuel phi0 =
let* changed, phi' =
@ -1155,18 +1145,18 @@ end = struct
in
if changed then
if fuel > 0 then (
L.d_printfln "going around one more time normalizing the linear equalities" ;
(* do another pass if we can affort it *)
(* do another pass if we can afford it *)
L.d_printfln "consuming fuel normalizing linear equalities (from %d)" fuel ;
normalize_linear_eqs ~fuel:(fuel - 1) phi' )
else (
L.d_printfln "ran out of fuel normalizing the linear equalities" ;
L.d_printfln "ran out of fuel normalizing linear equalities" ;
Sat phi' )
else Sat phi0
let normalize_atom phi (atom : Atom.t) =
let normalize_term phi t =
Term.map_variables t ~f:(fun v ->
Term.subst_variables t ~f:(fun v ->
let v_canon = (VarUF.find phi.var_eqs v :> Var.t) in
match Var.Map.find_opt v_canon phi.linear_eqs with
| None ->
@ -1187,7 +1177,7 @@ end = struct
(* NOTE: {!normalize_atom} calls {!Atom.eval}, which normalizes linear equalities so
they end up only on one side, hence only this match case is needed to detect linear
equalities *)
solve_eq ~fuel l (LinArith.of_q c) phi
solve_eq l (LinArith.of_q c) phi
| Some atom' ->
Sat {phi with atoms= Atom.Set.add atom' phi.atoms}
@ -1200,7 +1190,7 @@ end = struct
and_atom atom phi )
let normalize phi = normalize_linear_eqs ~fuel phi >>= normalize_atoms
let normalize phi = normalize_linear_eqs ~fuel:base_fuel phi >>= normalize_atoms
end
let and_mk_atom mk_atom op1 op2 phi =
@ -1232,7 +1222,11 @@ let prune_binop ~negated (bop : Binop.t) x y phi =
let normalize phi = Normalizer.normalize phi
(** translate each variable in [phi_foreign] according to [f] then incorporate each fact into [phi0] *)
let and_fold_map_variables phi0 ~up_to_f:phi_foreign ~init ~f =
let and_fold_subst_variables phi0 ~up_to_f:phi_foreign ~init ~f:f_var =
let f_subst acc v =
let acc', v' = f_var acc v in
(acc', VarSubst v')
in
(* propagate [Unsat] faster using this exception *)
let exception Contradiction in
let sat_value_exn (norm : 'a normalized) =
@ -1241,33 +1235,25 @@ let and_fold_map_variables phi0 ~up_to_f:phi_foreign ~init ~f =
let and_var_eqs acc =
VarUF.fold_congruences phi_foreign.var_eqs ~init:acc
~f:(fun (acc_f, phi) (repr_foreign, vs_foreign) ->
let acc_f, repr = f acc_f (repr_foreign :> Var.t) in
let acc_f, repr = f_var acc_f (repr_foreign :> Var.t) in
IContainer.fold_of_pervasives_set_fold Var.Set.fold vs_foreign ~init:(acc_f, phi)
~f:(fun (acc_f, phi) v_foreign ->
let acc_f, v = f acc_f v_foreign in
let acc_f, v = f_var acc_f v_foreign in
let phi = Normalizer.and_var_var repr v phi |> sat_value_exn in
(acc_f, phi) ) )
in
let and_linear_eqs acc =
IContainer.fold_of_pervasives_map_fold Var.Map.fold phi_foreign.linear_eqs ~init:acc
~f:(fun (acc_f, phi) (v_foreign, l_foreign) ->
let acc_f, v = f acc_f v_foreign in
let acc_f, l =
LinArith.fold_map_variables l_foreign ~init:acc_f ~f:(fun acc v ->
let acc', v' = f acc v in
(acc', VarSubst v') )
in
let acc_f, v = f_var acc_f v_foreign in
let acc_f, l = LinArith.fold_subst_variables l_foreign ~init:acc_f ~f:f_subst in
let phi = Normalizer.and_var_linarith v l phi |> sat_value_exn in
(acc_f, phi) )
in
let and_atoms acc =
IContainer.fold_of_pervasives_set_fold Atom.Set.fold phi_foreign.atoms ~init:acc
~f:(fun (acc_f, phi) atom_foreign ->
let acc_f, atom =
Atom.fold_map_variables atom_foreign ~init:acc_f ~f:(fun acc_f v ->
let acc_f, v' = f acc_f v in
(acc_f, VarSubst v') )
in
let acc_f, atom = Atom.fold_subst_variables atom_foreign ~init:acc_f ~f:f_subst in
let phi = Normalizer.and_atom atom phi |> sat_value_exn in
(acc_f, phi) )
in

@ -51,7 +51,7 @@ val normalize : t -> t normalized
val simplify : keep:Var.Set.t -> t -> t normalized
val and_fold_map_variables :
val and_fold_subst_variables :
t -> up_to_f:t -> init:'acc -> f:('acc -> Var.t -> 'acc * Var.t) -> ('acc * t) normalized
val is_known_zero : t -> Var.t -> bool

@ -186,7 +186,7 @@ let and_citvs_callee subst citvs_caller citvs_callee =
let and_formula_callee subst formula_caller ~callee:formula_callee =
(* need to translate callee variables to make sense for the caller, thereby possibly extending
the current substitution *)
Formula.and_fold_map_variables formula_caller ~up_to_f:formula_callee ~f:subst_find_or_new
Formula.and_fold_subst_variables formula_caller ~up_to_f:formula_callee ~f:subst_find_or_new
~init:subst

@ -221,7 +221,7 @@ let%test_module "variable elimination" =
(* should keep most of this or realize that [w = z] hence this boils down to [z+1 = 0] *)
let%expect_test _ =
simplify ~keep:[y_var; z_var] (x = y + z && w = x - y && v = w + i 1 && v = i 0) ;
[%expect {|x=v6 && x = y -1 z = -1 && true (no atoms)|}]
[%expect {|x=v6 z=w=v7 && x = y -1 z = -1 && true (no atoms)|}]
let%expect_test _ =
simplify ~keep:[x_var; y_var] (x = y + z && w + x + y = i 0 && v = w + i 1) ;

Loading…
Cancel
Save