From b16e85d10da96b2e0c269037515d0cc018241a35 Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Mon, 2 Mar 2020 08:43:59 -0800 Subject: [PATCH] [sledge] Eliminate redundant existential quantifiers Summary: This diff changes `Sh.simplify` from a logically-weakening syntactic simplification to an equivalence-preserving rewrite. The implementation is based on `Equality.solve_for_vars` which is also used by `Solver` to witness existential variables. Reviewed By: jvillard Differential Revision: D20120274 fbshipit-source-id: 5e11659ea --- sledge/src/import/import.ml | 5 + sledge/src/import/import.mli | 3 + sledge/src/llair/term.ml | 1 + sledge/src/llair/term.mli | 1 + sledge/src/symbheap/equality.ml | 20 ++++ sledge/src/symbheap/equality.mli | 11 +- sledge/src/symbheap/sh.ml | 187 ++++++++++++++++++++++++++----- sledge/src/symbheap/sh_test.ml | 15 ++- 8 files changed, 212 insertions(+), 31 deletions(-) diff --git a/sledge/src/import/import.ml b/sledge/src/import/import.ml index d9a5be4d0..daa64eec5 100644 --- a/sledge/src/import/import.ml +++ b/sledge/src/import/import.ml @@ -222,6 +222,11 @@ module List = struct let map_preserving_phys_equal t ~f = map_preserving_phys_equal map t ~f + let rev_map_unzip xs ~f = + fold xs ~init:([], []) ~f:(fun (ys, zs) x -> + let y, z = f x in + (y :: ys, z :: zs) ) + let remove_exn ?(equal = phys_equal) xs x = let rec remove_ ys = function | [] -> raise Not_found diff --git a/sledge/src/import/import.mli b/sledge/src/import/import.mli index 57e6c5d87..467aed91d 100644 --- a/sledge/src/import/import.mli +++ b/sledge/src/import/import.mli @@ -176,6 +176,9 @@ module List : sig (** Like filter_map, but preserves [phys_equal] if [f] preserves [phys_equal] of every element. *) + val rev_map_unzip : 'a t -> f:('a -> 'b * 'c) -> 'b list * 'c list + (** [rev_map_unzip ~f xs] is [unzip (rev_map ~f xs)] but more efficient. *) + val remove_exn : ?equal:('a -> 'a -> bool) -> 'a list -> 'a -> 'a list (** Returns the input list without the first element [equal] to the argument, or raise [Not_found] if no such element exists. [equal] diff --git a/sledge/src/llair/term.ml b/sledge/src/llair/term.ml index 5273319e6..edfdc1a23 100644 --- a/sledge/src/llair/term.ml +++ b/sledge/src/llair/term.ml @@ -338,6 +338,7 @@ module Var = struct let of_option = Option.fold ~f:Set.add ~init:empty let of_list = Set.of_list (module T) let of_vector = Set.of_vector (module T) + let union_list = Set.union_list (module T) end let invariant x = diff --git a/sledge/src/llair/term.mli b/sledge/src/llair/term.mli index f2444415e..49977f386 100644 --- a/sledge/src/llair/term.mli +++ b/sledge/src/llair/term.mli @@ -107,6 +107,7 @@ module Var : sig val of_option : var option -> t val of_list : var list -> t val of_vector : var vector -> t + val union_list : t list -> t end module Map : sig diff --git a/sledge/src/symbheap/equality.ml b/sledge/src/symbheap/equality.ml index bac3fadb1..a4b6c2ef9 100644 --- a/sledge/src/symbheap/equality.ml +++ b/sledge/src/symbheap/equality.ml @@ -47,6 +47,7 @@ module Subst : sig val iteri : t -> f:(key:Term.t -> data:Term.t -> unit) -> unit val for_alli : t -> f:(key:Term.t -> data:Term.t -> bool) -> bool val apply : t -> Term.t -> Term.t + val subst : t -> Term.t -> Term.t val norm : t -> Term.t -> Term.t val compose : t -> t -> t val compose1 : key:Term.t -> data:Term.t -> t -> t @@ -75,6 +76,8 @@ end = struct (** look up a term in a substitution *) let apply s a = Map.find s a |> Option.value ~default:a + let rec subst s a = apply s (Term.map ~f:(subst s) a) + (** apply a substitution to maximal non-interpreted subterms *) let rec norm s a = match classify a with @@ -550,6 +553,18 @@ let difference r a b = [%Trace.retn fun {pf} -> function Some d -> pf "%a" Z.pp_print d | None -> pf ""] +let apply_subst us s r = + [%Trace.call fun {pf} -> pf "%a@ %a" Subst.pp s pp r] + ; + Map.fold (classes r) ~init:true_ ~f:(fun ~key:rep ~data:cls r -> + let rep' = Subst.subst s rep in + List.fold cls ~init:r ~f:(fun r trm -> + let trm' = Subst.subst s trm in + and_eq us trm' rep' r ) ) + |> extract_xs + |> + [%Trace.retn fun {pf} (xs, r') -> pf "%a%a" Var.Set.pp_xs xs pp r'] + let and_ us r s = ( if not r.sat then r else if not s.sat then s @@ -585,6 +600,11 @@ let or_ us r s = |> [%Trace.retn fun {pf} (_, r) -> pf "%a" pp r] +let orN us rs = + match rs with + | [] -> (us, false_) + | r :: rs -> List.fold ~f:(fun (us, s) r -> or_ us s r) ~init:(us, r) rs + let rec and_term_ us e r = let eq_false b r = and_eq us b Term.false_ r in match (e : Term.t) with diff --git a/sledge/src/symbheap/equality.mli b/sledge/src/symbheap/equality.mli index 98a4a1e78..ef0469241 100644 --- a/sledge/src/symbheap/equality.mli +++ b/sledge/src/symbheap/equality.mli @@ -32,6 +32,9 @@ val and_ : Var.Set.t -> t -> t -> Var.Set.t * t val or_ : Var.Set.t -> t -> t -> Var.Set.t * t (** Disjunction. *) +val orN : Var.Set.t -> t list -> Var.Set.t * t +(** Nary disjunction. *) + val rename : t -> Var.Subst.t -> t (** Apply a renaming substitution to the relation. *) @@ -73,7 +76,9 @@ module Subst : sig val pp : t pp val is_empty : t -> bool val fold : t -> init:'a -> f:(key:Term.t -> data:Term.t -> 'a -> 'a) -> 'a - val norm : t -> Term.t -> Term.t + + val subst : t -> Term.t -> Term.t + (** Apply a substitution recursively to subterms. *) val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t (** Partition ∃xs. σ into equivalent ∃xs. τ ∧ ∃ks. ν where ks @@ -81,6 +86,10 @@ module Subst : sig ks ∩ fv(τ) = ∅. *) end +val apply_subst : Var.Set.t -> Subst.t -> t -> Var.Set.t * t +(** Relation induced by applying a substitution to a set of equations + generating the argument relation. *) + val solve_for_vars : Var.Set.t list -> t -> Subst.t (** [solve_for_vars vss r] is a solution substitution that is entailed by [r] and consists of oriented equalities [x ↦ e] that map terms [x] diff --git a/sledge/src/symbheap/sh.ml b/sledge/src/symbheap/sh.ml index 5c028ca27..e69ccc451 100644 --- a/sledge/src/symbheap/sh.ml +++ b/sledge/src/symbheap/sh.ml @@ -295,22 +295,6 @@ let rec invariant q = invariant sjn ) ) with exc -> [%Trace.info "%a" pp q] ; raise exc -let rec simplify {us; xs; cong; pure; heap; djns} = - [%Trace.call fun {pf} -> pf "%a" pp {us; xs; cong; pure; heap; djns}] - ; - let heap = List.map heap ~f:(map_seg ~f:(Equality.normalize cong)) in - let pure = List.map pure ~f:(Equality.normalize cong) in - let cong = Equality.true_ in - let djns = List.map djns ~f:(List.map ~f:simplify) in - let all_vars = - fv {us= Set.union us xs; xs= Var.Set.empty; cong; pure; heap; djns} - in - let xs = Set.inter all_vars xs in - let us = Set.inter all_vars us in - {us; xs; cong; pure; heap; djns} |> check invariant - |> - [%Trace.retn fun {pf} s -> pf "%a" pp s] - (** Quantification and Vocabulary *) (** primitive application of a substitution, ignores us and xs, may violate @@ -405,6 +389,11 @@ let exists xs q = |> [%Trace.retn fun {pf} -> pf "%a" pp] +(** remove quantification on variables disjoint from vocabulary *) +let elim_exists xs q = + assert (Set.disjoint xs q.us) ; + {q with us= Set.union q.us xs; xs= Set.diff q.xs xs} + (** Construct *) let emp = @@ -418,13 +407,17 @@ let emp = let false_ us = {emp with us; djns= [[]]} |> check invariant +(** conjoin an equality relation assuming vocabulary is compatible *) +let and_cong_ cong q = + assert (Set.is_subset (Equality.fv cong) ~of_:q.us) ; + let xs, cong = Equality.and_ (Set.union q.us q.xs) q.cong cong in + if Equality.is_false cong then false_ q.us + else exists_fresh xs {q with cong} + let and_cong cong q = [%Trace.call fun {pf} -> pf "%a@ %a" Equality.pp cong pp q] ; - let q = extend_us (Equality.fv cong) q in - let xs, cong = Equality.and_ (Set.union q.us q.xs) q.cong cong in - ( if Equality.is_false cong then false_ q.us - else exists_fresh xs {q with cong} ) + and_cong_ cong (extend_us (Equality.fv cong) q) |> [%Trace.retn fun {pf} q -> pf "%a" pp q ; invariant q] @@ -467,7 +460,7 @@ let star q1 q2 = invariant q ; assert (Set.equal q.us (Set.union q1.us q2.us))] -let stars = function +let starN = function | [] -> emp | [q] -> q | q :: qs -> List.fold ~f:star ~init:q qs @@ -503,6 +496,11 @@ let or_ q1 q2 = invariant q ; assert (Set.equal q.us (Set.union q1.us q2.us))] +let orN = function + | [] -> false_ Var.Set.empty + | [q] -> q + | q :: qs -> List.fold ~f:or_ ~init:q qs + let rec pure (e : Term.t) = [%Trace.call fun {pf} -> pf "%a" Term.pp e] ; @@ -616,7 +614,7 @@ let dnf q = ; let conj sjn conjuncts = sjn :: conjuncts in let disj (xs, conjuncts) disjuncts = - exists xs (stars conjuncts) :: disjuncts + exists xs (starN conjuncts) :: disjuncts in fold_dnf ~conj ~disj q (Var.Set.empty, []) [] |> @@ -624,9 +622,148 @@ let dnf q = (** Simplify *) -let rec norm s q = - [%Trace.call fun {pf} -> pf "@[%a@]@ %a" Equality.Subst.pp s pp q] +let rec norm_ s q = + [%Trace.call fun {pf} -> pf "@[%a@]@ %a" Equality.Subst.pp s pp_raw q] ; - map q ~f_sjn:(norm s) ~f_cong:Fn.id ~f_trm:(Equality.Subst.norm s) + let q = + map q ~f_sjn:(norm_ s) ~f_cong:Fn.id ~f_trm:(Equality.Subst.subst s) + in + let xs, cong = Equality.apply_subst (Set.union q.us q.xs) s q.cong in + exists_fresh xs {q with cong} + |> + [%Trace.retn fun {pf} q' -> pf "%a" pp_raw q' ; invariant q'] + +let norm s q = + [%Trace.call fun {pf} -> pf "@[%a@]@ %a" Equality.Subst.pp s pp_raw q] + ; + (if Equality.Subst.is_empty s then q else norm_ s q) + |> + [%Trace.retn fun {pf} q' -> pf "%a" pp_raw q' ; invariant q'] + +(** rename existentially quantified variables to avoid shadowing, and reduce + quantifier scopes by sinking them as low as possible into disjunctions *) +let rec freshen_nested_xs q = + [%Trace.call fun {pf} -> pf "%a" pp q] + ; + (* trim xs to those that appear in the stem and sink the rest *) + let fv_stem = fv {q with xs= Var.Set.empty; djns= []} in + let xs_sink, xs = Set.diff_inter q.xs fv_stem in + let xs_below, djns = + List.fold_map ~init:Var.Set.empty q.djns ~f:(fun xs_below djn -> + List.fold_map ~init:xs_below djn ~f:(fun xs_below dj -> + (* quantify xs not in stem and freshen disjunct *) + let dj' = + freshen_nested_xs (exists (Set.inter xs_sink dj.us) dj) + in + let xs_below' = Set.union xs_below dj'.xs in + (xs_below', dj') ) ) + in + (* rename xs to miss all xs in subformulas *) + freshen_xs {q with xs; djns} ~wrt:(Set.union q.us xs_below) |> [%Trace.retn fun {pf} q' -> pf "%a" pp q' ; invariant q'] + +let rec propagate_equality_ ancestor_vs ancestor_cong q = + [%Trace.call fun {pf} -> + pf "(%a)@ %a" Equality.pp_classes ancestor_cong pp q] + ; + (* extend vocabulary with variables in scope above *) + let ancestor_vs = Set.union ancestor_vs (Set.union q.us q.xs) in + (* decompose formula *) + let xs, stem, djns = + (q.xs, {q with us= ancestor_vs; xs= emp.xs; djns= emp.djns}, q.djns) + in + (* strengthen equality relation with that from above *) + let ancestor_stem = and_cong_ ancestor_cong stem in + let ancestor_cong = ancestor_stem.cong in + exists xs + (List.fold djns ~init:ancestor_stem ~f:(fun q' djn -> + let dj_congs, djn = + List.rev_map_unzip djn ~f:(fun dj -> + let dj = propagate_equality_ ancestor_vs ancestor_cong dj in + (dj.cong, dj) ) + in + let new_xs, djn_cong = Equality.orN ancestor_vs dj_congs in + (* hoist xs appearing in disjunction's equality relation *) + let djn_xs = Set.diff (Equality.fv djn_cong) q'.us in + let djn = List.map ~f:(elim_exists djn_xs) djn in + let cong_djn = and_cong_ djn_cong (orN djn) in + assert (is_false cong_djn || Set.is_subset new_xs ~of_:djn_xs) ; + star (exists djn_xs cong_djn) q' )) + |> + [%Trace.retn fun {pf} q' -> pf "%a" pp q' ; invariant q'] + +let propagate_equality ancestor_vs ancestor_cong q = + [%Trace.call fun {pf} -> + pf "(%a)@ %a" Equality.pp_classes ancestor_cong pp q] + ; + propagate_equality_ ancestor_vs ancestor_cong q + |> + [%Trace.retn fun {pf} q' -> pf "%a" pp q' ; invariant q'] + +let pp_vss fs vss = + Format.fprintf fs "[@[%a@]]" + (List.pp ";@ " (fun fs vs -> Format.fprintf fs "{@[%a@]}" Var.Set.pp vs)) + vss + +let remove_absent_xs ks q = + let ks = Set.inter ks q.xs in + if Set.is_empty ks then q + else + let xs = Set.diff q.xs ks in + let djns = + let rec trim_ks ks djns = + List.map djns ~f:(fun djn -> + List.map djn ~f:(fun sjn -> + {sjn with us= Set.diff sjn.us ks; djns= trim_ks ks sjn.djns} + ) ) + in + trim_ks ks q.djns + in + {q with xs; djns} + +let rec simplify_ us rev_xss q = + [%Trace.call fun {pf} -> pf "%a@ %a" pp_vss (List.rev rev_xss) pp_raw q] + ; + let rev_xss = q.xs :: rev_xss in + (* recursively simplify subformulas *) + let q = + exists q.xs + (starN + ( {q with us= Set.union q.us q.xs; xs= emp.xs; djns= []} + :: List.map q.djns ~f:(fun djn -> + orN (List.map djn ~f:(fun sjn -> simplify_ us rev_xss sjn)) + ) )) + in + (* try to solve equations in cong for variables in xss *) + let subst = Equality.solve_for_vars (us :: List.rev rev_xss) q.cong in + (* simplification can reveal inconsistency *) + ( if is_false q then false_ q.us + else if Equality.Subst.is_empty subst then q + else + (* normalize wrt solutions *) + let q = norm subst q in + (* reconjoin only non-redundant equations *) + let removed = + Set.diff + (Var.Set.union_list rev_xss) + (fv ~ignore_cong:() (elim_exists q.xs q)) + in + let keep, removed, _ = Equality.Subst.partition_valid removed subst in + let q = and_subst keep q in + (* remove the eliminated variables from xs and subformulas' us *) + remove_absent_xs removed q ) + |> + [%Trace.retn fun {pf} q' -> + pf "%a@ %a" Equality.Subst.pp subst pp_raw q' ; + invariant q'] + +let simplify q = + [%Trace.call fun {pf} -> pf "%a" pp_raw q] + ; + let q = freshen_nested_xs q in + let q = propagate_equality Var.Set.empty Equality.true_ q in + let q = simplify_ q.us [] q in + q + |> + [%Trace.retn fun {pf} q' -> pf "@\n" ; invariant q'] diff --git a/sledge/src/symbheap/sh_test.ml b/sledge/src/symbheap/sh_test.ml index c47c0e2dc..527af711f 100644 --- a/sledge/src/symbheap/sh_test.ml +++ b/sledge/src/symbheap/sh_test.ml @@ -120,7 +120,7 @@ let%test_module _ = ∨ ( ( ( 1 = _ = %y_7 ∧ emp) ∨ ( 2 = _ ∧ emp) )) ) - ( ( emp) ∨ ( ( ( emp) ∨ ( emp) )) ) |}] + ( ( 1 = %y_7 ∧ emp) ∨ ( emp) ∨ ( emp) ) |}] let of_eqs l = List.fold ~init:emp ~f:(fun q (a, b) -> and_ (Term.eq a b) q) l @@ -139,9 +139,14 @@ let%test_module _ = ∧ ((u8) %y_7) = ((u8) (((u8) %y_7) + 1)) ∧ emp - -1 ∧ emp + (((u8) %y_7) + 1) = %y_7 + ∧ ((u8) %y_7) = ((u8) (((u8) %y_7) + 1)) + ∧ ((%y_7 + -1) = ((u8) %y_7)) + ∧ emp - emp |}] + (((u8) %y_7) + 1) = %y_7 + ∧ ((u8) %y_7) = ((u8) (((u8) %y_7) + 1)) + ∧ emp |}] let%expect_test _ = let q = @@ -172,7 +177,7 @@ let%test_module _ = ∧ emp) ) - -1 ∧ emp * ( ( (%x_6 ≠ 0) ∧ emp) ∨ ( -1 ∧ emp) ) + ( ( emp) ∨ ( (%x_6 ≠ 0) ∧ emp) ) - ( ( (%x_6 ≠ 0) ∧ emp) ∨ ( emp) ) |}] + ( ( emp) ∨ ( (%x_6 ≠ 0) ∧ emp) ) |}] end )