[pulse] cleanup arithmetic

Summary:
Mostly cosmetic except for a change in [solve_eq] to try harder at
normalization (improves unit tests!). Add more comments and do minor
renamings.

Reviewed By: skcho

Differential Revision: D23243629

fbshipit-source-id: 55bdaf8a8
master
Jules Villard 4 years ago committed by Facebook GitHub Bot
parent 8b23fee8f8
commit 50b94dbbd6

@ -430,7 +430,6 @@ let of_post pdesc astate =
let astate = filter_for_summary astate in let astate = filter_for_summary astate in
let astate, live_addresses, _ = discard_unreachable astate in let astate, live_addresses, _ = discard_unreachable astate in
let astate = let astate =
(* this also forces the lazy path condition for safe serialization *)
{astate with path_condition= PathCondition.simplify ~keep:live_addresses astate.path_condition} {astate with path_condition= PathCondition.simplify ~keep:live_addresses astate.path_condition}
in in
invalidate_locals pdesc astate invalidate_locals pdesc astate

@ -16,6 +16,7 @@ type operand = LiteralOperand of IntLit.t | AbstractValueOperand of Var.t
that could result in discovering something is unsatisfiable *) that could result in discovering something is unsatisfiable *)
type 'a normalized = Unsat | Sat of 'a type 'a normalized = Unsat | Sat of 'a
(** {!Q} from zarith with a few convenience functions added *)
module Q = struct module Q = struct
include Q include Q
@ -38,7 +39,7 @@ module Q = struct
let to_bigint q = conv_protect Q.to_bigint q let to_bigint q = conv_protect Q.to_bigint q
end end
(** Linear Arithmetic*) (** Linear Arithmetic *)
module LinArith : sig module LinArith : sig
(** linear combination of variables, eg [2·x + 3/4·y + 12] *) (** linear combination of variables, eg [2·x + 3/4·y + 12] *)
type t [@@deriving compare] type t [@@deriving compare]
@ -77,9 +78,9 @@ module LinArith : sig
val get_variables : t -> Var.t Seq.t val get_variables : t -> Var.t Seq.t
val fold_map_variables : t -> init:'a -> f:('a -> Var.t -> 'a * subst_target) -> 'a * t val fold_subst_variables : t -> init:'a -> f:('a -> Var.t -> 'a * subst_target) -> 'a * t
val map_variables : t -> f:(Var.t -> subst_target) -> t val subst_variables : t -> f:(Var.t -> subst_target) -> t
end = struct end = struct
(** invariant: the representation is always "canonical": coefficients cannot be [Q.zero] *) (** invariant: the representation is always "canonical": coefficients cannot be [Q.zero] *)
type t = Q.t Var.Map.t * Q.t [@@deriving compare] type t = Q.t Var.Map.t * Q.t [@@deriving compare]
@ -179,7 +180,7 @@ end = struct
let of_subst_target = function QSubst q -> of_q q | VarSubst v -> of_var v | LinSubst l -> l let of_subst_target = function QSubst q -> of_q q | VarSubst v -> of_var v | LinSubst l -> l
let fold_map_variables ((vs_foreign, c) as l0) ~init ~f = let fold_subst_variables ((vs_foreign, c) as l0) ~init ~f =
let changed = ref false in let changed = ref false in
let acc_f, l' = let acc_f, l' =
Var.Map.fold Var.Map.fold
@ -194,7 +195,7 @@ end = struct
(acc_f, l') (acc_f, l')
let map_variables l ~f = fold_map_variables l ~init:() ~f:(fun () v -> ((), f v)) |> snd let subst_variables l ~f = fold_subst_variables l ~init:() ~f:(fun () v -> ((), f v)) |> snd
let get_variables (vs, _) = Var.Map.to_seq vs |> Seq.map fst let get_variables (vs, _) = Var.Map.to_seq vs |> Seq.map fst
end end
@ -393,32 +394,13 @@ module Term = struct
let t' = let t' =
if phys_equal t_not t_not' then t if phys_equal t_not t_not' then t
else else
match t with match[@warning "-8"] t with
| Minus _ -> | Minus _ ->
Minus t_not' Minus t_not'
| BitNot _ -> | BitNot _ ->
BitNot t_not' BitNot t_not'
| Not _ -> | Not _ ->
Not t_not' Not t_not'
| Var _
| Const _
| Linear _
| Add _
| Mult _
| Div _
| Mod _
| BitAnd _
| BitOr _
| BitShiftLeft _
| BitShiftRight _
| BitXor _
| And _
| Or _
| LessThan _
| LessEqual _
| Equal _
| NotEqual _ ->
assert false
in in
(acc, t') (acc, t')
| Add (t1, t2) | Add (t1, t2)
@ -441,7 +423,7 @@ module Term = struct
let t' = let t' =
if phys_equal t1 t1' && phys_equal t2 t2' then t if phys_equal t1 t1' && phys_equal t2 t2' then t
else else
match t with match[@warning "-8"] t with
| Add _ -> | Add _ ->
Add (t1', t2') Add (t1', t2')
| Mult _ -> | Mult _ ->
@ -472,8 +454,6 @@ module Term = struct
Equal (t1', t2') Equal (t1', t2')
| NotEqual _ -> | NotEqual _ ->
NotEqual (t1', t2') NotEqual (t1', t2')
| Var _ | Const _ | Linear _ | Minus _ | BitNot _ | Not _ ->
assert false
in in
(acc, t') (acc, t')
@ -482,27 +462,27 @@ module Term = struct
fold_map_direct_subterms t ~init:() ~f:(fun () t' -> ((), f t')) |> snd fold_map_direct_subterms t ~init:() ~f:(fun () t' -> ((), f t')) |> snd
let rec fold_map_variables t ~init ~f = let rec fold_subst_variables t ~init ~f =
match t with match t with
| Var v -> | Var v ->
let acc, op = f init v in let acc, op = f init v in
let t' = match op with VarSubst v' when Var.equal v v' -> t | _ -> of_subst_target op in let t' = match op with VarSubst v' when Var.equal v v' -> t | _ -> of_subst_target op in
(acc, t') (acc, t')
| Linear l -> | Linear l ->
let acc, l' = LinArith.fold_map_variables l ~init ~f in let acc, l' = LinArith.fold_subst_variables l ~init ~f in
let t' = if phys_equal l l' then t else Linear l' in let t' = if phys_equal l l' then t else Linear l' in
(acc, t') (acc, t')
| _ -> | _ ->
fold_map_direct_subterms t ~init ~f:(fun acc t' -> fold_map_variables t' ~init:acc ~f) fold_map_direct_subterms t ~init ~f:(fun acc t' -> fold_subst_variables t' ~init:acc ~f)
let fold_variables t ~init ~f = let fold_variables t ~init ~f =
fold_map_variables t ~init ~f:(fun acc v -> (f acc v, VarSubst v)) |> fst fold_subst_variables t ~init ~f:(fun acc v -> (f acc v, VarSubst v)) |> fst
let iter_variables t ~f = fold_variables t ~init:() ~f:(fun () v -> f v) let iter_variables t ~f = fold_variables t ~init:() ~f:(fun () v -> f v)
let map_variables t ~f = fold_map_variables t ~init:() ~f:(fun () v -> ((), f v)) |> snd let subst_variables t ~f = fold_subst_variables t ~init:() ~f:(fun () v -> ((), f v)) |> snd
let has_var_notin vars t = let has_var_notin vars t =
Container.exists t ~iter:iter_variables ~f:(fun v -> not (Var.Set.mem v vars)) Container.exists t ~iter:iter_variables ~f:(fun v -> not (Var.Set.mem v vars))
@ -934,10 +914,10 @@ module Atom = struct
t t
let eval (atom : t) = map_terms atom ~f:eval_term |> eval_atom let eval atom = map_terms atom ~f:eval_term |> eval_atom
let fold_map_variables a ~init ~f = let fold_subst_variables a ~init ~f =
fold_map_terms a ~init ~f:(fun acc t -> Term.fold_map_variables t ~init:acc ~f) fold_map_terms a ~init ~f:(fun acc t -> Term.fold_subst_variables t ~init:acc ~f)
let has_var_notin vars atom = let has_var_notin vars atom =
@ -1052,7 +1032,7 @@ end = struct
(** substitute vars in [l] *once* with their linear form to discover more simplification (** substitute vars in [l] *once* with their linear form to discover more simplification
opportunities *) opportunities *)
let apply phi l = let apply phi l =
LinArith.map_variables l ~f:(fun v -> LinArith.subst_variables l ~f:(fun v ->
let repr = (get_repr phi v :> Var.t) in let repr = (get_repr phi v :> Var.t) in
match Var.Map.find_opt repr phi.linear_eqs with match Var.Map.find_opt repr phi.linear_eqs with
| None -> | None ->
@ -1061,29 +1041,35 @@ end = struct
LinSubst l' ) LinSubst l' )
let rec solve_eq ~fuel t1 t2 phi = let rec solve_normalized_eq ~fuel l1 l2 phi =
LinArith.solve_eq t1 t2 LinArith.solve_eq l1 l2
>>= function None -> Sat phi | Some (x, l) -> merge_var_linarith ~fuel x l phi >>= function
| None ->
Sat phi
and merge_var_linarith ~fuel v l phi = | Some (v, l) -> (
let v = (get_repr phi v :> Var.t) in match LinArith.get_as_var l with
let l = apply phi l in | Some v' ->
match LinArith.get_as_var l with merge_vars ~fuel (v :> Var.t) v' phi
| Some v' -> | None -> (
merge_vars ~fuel (v :> Var.t) v' phi match Var.Map.find_opt (v :> Var.t) phi.linear_eqs with
| None -> ( | None ->
match Var.Map.find_opt (v :> Var.t) phi.linear_eqs with (* this can break the (as a result non-)invariant that variables in the domain of
| None -> [linear_eqs] do not appear in the range of [linear_eqs] *)
(* this is probably dodgy as nothing guarantees that [l] does not mention [v] *) Sat {phi with linear_eqs= Var.Map.add (v :> Var.t) l phi.linear_eqs}
Sat {phi with linear_eqs= Var.Map.add (v :> Var.t) l phi.linear_eqs} | Some l' ->
| Some l' -> (* This is the only step that consumes fuel: discovering an equality [l = l']: because we
(* This is the only step that consumes fuel: discovering an equality [l = l']: because we do not record these anywhere (except when there consequence can be recorded as [y =
do not record these anywhere (except when there consequence can be recorded as [y = l''] or [y = y'], we could potentially discover the same equality over and over and
l''] or [y = y'], we could potentially discover the same equality over and over and diverge otherwise. Or could we? *)
diverge otherwise *) (* [l'] is possibly not normalized w.r.t. the current [phi] so take this opportunity to
if fuel > 0 then solve_eq ~fuel:(fuel - 1) l l' phi normalize it *)
else (* [fuel = 0]: give up simplifying further for fear of diverging *) Sat phi ) if fuel > 0 then (
L.d_printfln "Consuming fuel solving linear equality (from %d)" fuel ;
solve_normalized_eq ~fuel:(fuel - 1) l (apply phi l') phi )
else (
(* [fuel = 0]: give up simplifying further for fear of diverging *)
L.d_printfln "Ran out of fuel solving linear equality" ;
Sat phi ) ) )
and merge_vars ~fuel v1 v2 phi = and merge_vars ~fuel v1 v2 phi =
@ -1097,9 +1083,11 @@ end = struct
(* new equality [v_old = v_new]: we need to update a potential [v_old = l] to be [v_new = (* new equality [v_old = v_new]: we need to update a potential [v_old = l] to be [v_new =
l], and if [v_new = l'] was known we need to also explore the consequences of [l = l'] *) l], and if [v_new = l'] was known we need to also explore the consequences of [l = l'] *)
(* NOTE: we try to maintain the invariant that for all [x=l] in [phi.linear_eqs], [x ∉ (* NOTE: we try to maintain the invariant that for all [x=l] in [phi.linear_eqs], [x ∉
vars(l)], because other Shostak techniques do so (in fact, they impose a stricter vars(l)]. We also try to stay as close as possible (without going back and re-normalizing
condition that the domain and the range of [phi.linear_eqs] mention distinct variables), every linear equality every time we learn new equalities) to the invariant that the
but some other steps of the reasoning may break that. Not sure why we bother. *) domain and the range of [phi.linear_eqs] mention distinct variables. This is to speed up
normalization steps: when the stronger invariant holds we can normalize in one step (in
[normalize_linear_eqs]). *)
let v_new = (v_new :> Var.t) in let v_new = (v_new :> Var.t) in
let phi, l_new = let phi, l_new =
match Var.Map.find_opt v_new phi.linear_eqs with match Var.Map.find_opt v_new phi.linear_eqs with
@ -1131,15 +1119,17 @@ end = struct
(* no need to consume fuel here as we can only go through this branch finitely many (* no need to consume fuel here as we can only go through this branch finitely many
times because there are finitely many variables in a given formula *) times because there are finitely many variables in a given formula *)
(* TODO: we may want to keep the "simpler" representative for [v_new] between [l1] and [l2] *) (* TODO: we may want to keep the "simpler" representative for [v_new] between [l1] and [l2] *)
solve_eq ~fuel l1 l2 phi ) solve_normalized_eq ~fuel l1 l2 phi )
(** an arbitrary value *) (** an arbitrary value *)
let fuel = 5 let base_fuel = 5
let solve_eq t1 t2 phi = solve_normalized_eq ~fuel:base_fuel (apply phi t1) (apply phi t2) phi
let and_var_linarith v l phi = solve_eq ~fuel l (LinArith.of_var v) phi let and_var_linarith v l phi = solve_eq l (LinArith.of_var v) phi
let and_var_var v1 v2 phi = merge_vars ~fuel v1 v2 phi let and_var_var v1 v2 phi = merge_vars ~fuel:base_fuel v1 v2 phi
let rec normalize_linear_eqs ~fuel phi0 = let rec normalize_linear_eqs ~fuel phi0 =
let* changed, phi' = let* changed, phi' =
@ -1155,18 +1145,18 @@ end = struct
in in
if changed then if changed then
if fuel > 0 then ( if fuel > 0 then (
L.d_printfln "going around one more time normalizing the linear equalities" ; (* do another pass if we can afford it *)
(* do another pass if we can affort it *) L.d_printfln "consuming fuel normalizing linear equalities (from %d)" fuel ;
normalize_linear_eqs ~fuel:(fuel - 1) phi' ) normalize_linear_eqs ~fuel:(fuel - 1) phi' )
else ( else (
L.d_printfln "ran out of fuel normalizing the linear equalities" ; L.d_printfln "ran out of fuel normalizing linear equalities" ;
Sat phi' ) Sat phi' )
else Sat phi0 else Sat phi0
let normalize_atom phi (atom : Atom.t) = let normalize_atom phi (atom : Atom.t) =
let normalize_term phi t = let normalize_term phi t =
Term.map_variables t ~f:(fun v -> Term.subst_variables t ~f:(fun v ->
let v_canon = (VarUF.find phi.var_eqs v :> Var.t) in let v_canon = (VarUF.find phi.var_eqs v :> Var.t) in
match Var.Map.find_opt v_canon phi.linear_eqs with match Var.Map.find_opt v_canon phi.linear_eqs with
| None -> | None ->
@ -1187,7 +1177,7 @@ end = struct
(* NOTE: {!normalize_atom} calls {!Atom.eval}, which normalizes linear equalities so (* NOTE: {!normalize_atom} calls {!Atom.eval}, which normalizes linear equalities so
they end up only on one side, hence only this match case is needed to detect linear they end up only on one side, hence only this match case is needed to detect linear
equalities *) equalities *)
solve_eq ~fuel l (LinArith.of_q c) phi solve_eq l (LinArith.of_q c) phi
| Some atom' -> | Some atom' ->
Sat {phi with atoms= Atom.Set.add atom' phi.atoms} Sat {phi with atoms= Atom.Set.add atom' phi.atoms}
@ -1200,7 +1190,7 @@ end = struct
and_atom atom phi ) and_atom atom phi )
let normalize phi = normalize_linear_eqs ~fuel phi >>= normalize_atoms let normalize phi = normalize_linear_eqs ~fuel:base_fuel phi >>= normalize_atoms
end end
let and_mk_atom mk_atom op1 op2 phi = let and_mk_atom mk_atom op1 op2 phi =
@ -1232,7 +1222,11 @@ let prune_binop ~negated (bop : Binop.t) x y phi =
let normalize phi = Normalizer.normalize phi let normalize phi = Normalizer.normalize phi
(** translate each variable in [phi_foreign] according to [f] then incorporate each fact into [phi0] *) (** translate each variable in [phi_foreign] according to [f] then incorporate each fact into [phi0] *)
let and_fold_map_variables phi0 ~up_to_f:phi_foreign ~init ~f = let and_fold_subst_variables phi0 ~up_to_f:phi_foreign ~init ~f:f_var =
let f_subst acc v =
let acc', v' = f_var acc v in
(acc', VarSubst v')
in
(* propagate [Unsat] faster using this exception *) (* propagate [Unsat] faster using this exception *)
let exception Contradiction in let exception Contradiction in
let sat_value_exn (norm : 'a normalized) = let sat_value_exn (norm : 'a normalized) =
@ -1241,33 +1235,25 @@ let and_fold_map_variables phi0 ~up_to_f:phi_foreign ~init ~f =
let and_var_eqs acc = let and_var_eqs acc =
VarUF.fold_congruences phi_foreign.var_eqs ~init:acc VarUF.fold_congruences phi_foreign.var_eqs ~init:acc
~f:(fun (acc_f, phi) (repr_foreign, vs_foreign) -> ~f:(fun (acc_f, phi) (repr_foreign, vs_foreign) ->
let acc_f, repr = f acc_f (repr_foreign :> Var.t) in let acc_f, repr = f_var acc_f (repr_foreign :> Var.t) in
IContainer.fold_of_pervasives_set_fold Var.Set.fold vs_foreign ~init:(acc_f, phi) IContainer.fold_of_pervasives_set_fold Var.Set.fold vs_foreign ~init:(acc_f, phi)
~f:(fun (acc_f, phi) v_foreign -> ~f:(fun (acc_f, phi) v_foreign ->
let acc_f, v = f acc_f v_foreign in let acc_f, v = f_var acc_f v_foreign in
let phi = Normalizer.and_var_var repr v phi |> sat_value_exn in let phi = Normalizer.and_var_var repr v phi |> sat_value_exn in
(acc_f, phi) ) ) (acc_f, phi) ) )
in in
let and_linear_eqs acc = let and_linear_eqs acc =
IContainer.fold_of_pervasives_map_fold Var.Map.fold phi_foreign.linear_eqs ~init:acc IContainer.fold_of_pervasives_map_fold Var.Map.fold phi_foreign.linear_eqs ~init:acc
~f:(fun (acc_f, phi) (v_foreign, l_foreign) -> ~f:(fun (acc_f, phi) (v_foreign, l_foreign) ->
let acc_f, v = f acc_f v_foreign in let acc_f, v = f_var acc_f v_foreign in
let acc_f, l = let acc_f, l = LinArith.fold_subst_variables l_foreign ~init:acc_f ~f:f_subst in
LinArith.fold_map_variables l_foreign ~init:acc_f ~f:(fun acc v ->
let acc', v' = f acc v in
(acc', VarSubst v') )
in
let phi = Normalizer.and_var_linarith v l phi |> sat_value_exn in let phi = Normalizer.and_var_linarith v l phi |> sat_value_exn in
(acc_f, phi) ) (acc_f, phi) )
in in
let and_atoms acc = let and_atoms acc =
IContainer.fold_of_pervasives_set_fold Atom.Set.fold phi_foreign.atoms ~init:acc IContainer.fold_of_pervasives_set_fold Atom.Set.fold phi_foreign.atoms ~init:acc
~f:(fun (acc_f, phi) atom_foreign -> ~f:(fun (acc_f, phi) atom_foreign ->
let acc_f, atom = let acc_f, atom = Atom.fold_subst_variables atom_foreign ~init:acc_f ~f:f_subst in
Atom.fold_map_variables atom_foreign ~init:acc_f ~f:(fun acc_f v ->
let acc_f, v' = f acc_f v in
(acc_f, VarSubst v') )
in
let phi = Normalizer.and_atom atom phi |> sat_value_exn in let phi = Normalizer.and_atom atom phi |> sat_value_exn in
(acc_f, phi) ) (acc_f, phi) )
in in

@ -51,7 +51,7 @@ val normalize : t -> t normalized
val simplify : keep:Var.Set.t -> t -> t normalized val simplify : keep:Var.Set.t -> t -> t normalized
val and_fold_map_variables : val and_fold_subst_variables :
t -> up_to_f:t -> init:'acc -> f:('acc -> Var.t -> 'acc * Var.t) -> ('acc * t) normalized t -> up_to_f:t -> init:'acc -> f:('acc -> Var.t -> 'acc * Var.t) -> ('acc * t) normalized
val is_known_zero : t -> Var.t -> bool val is_known_zero : t -> Var.t -> bool

@ -186,7 +186,7 @@ let and_citvs_callee subst citvs_caller citvs_callee =
let and_formula_callee subst formula_caller ~callee:formula_callee = let and_formula_callee subst formula_caller ~callee:formula_callee =
(* need to translate callee variables to make sense for the caller, thereby possibly extending (* need to translate callee variables to make sense for the caller, thereby possibly extending
the current substitution *) the current substitution *)
Formula.and_fold_map_variables formula_caller ~up_to_f:formula_callee ~f:subst_find_or_new Formula.and_fold_subst_variables formula_caller ~up_to_f:formula_callee ~f:subst_find_or_new
~init:subst ~init:subst

@ -221,7 +221,7 @@ let%test_module "variable elimination" =
(* should keep most of this or realize that [w = z] hence this boils down to [z+1 = 0] *) (* should keep most of this or realize that [w = z] hence this boils down to [z+1 = 0] *)
let%expect_test _ = let%expect_test _ =
simplify ~keep:[y_var; z_var] (x = y + z && w = x - y && v = w + i 1 && v = i 0) ; simplify ~keep:[y_var; z_var] (x = y + z && w = x - y && v = w + i 1 && v = i 0) ;
[%expect {|x=v6 && x = y -1 z = -1 && true (no atoms)|}] [%expect {|x=v6 z=w=v7 && x = y -1 z = -1 && true (no atoms)|}]
let%expect_test _ = let%expect_test _ =
simplify ~keep:[x_var; y_var] (x = y + z && w + x + y = i 0 && v = w + i 1) ; simplify ~keep:[x_var; y_var] (x = y + z && w + x + y = i 0 && v = w + i 1) ;

Loading…
Cancel
Save