From f0a660792ecb8f8a98dbed877e06a85a9ac483e3 Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Mon, 27 Jan 2020 08:19:26 -0800 Subject: [PATCH] [sledge] Add Equality.solve_for_vars MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: ``` val solve_for_vars : Var.Set.t list -> t -> Subst.t (** [solve_for_vars \[v₁;…\] r] is a solution substitution that is entailed by [r] and consists of oriented equalities [x ↦ u] such that [fv x ⊈ vᵢ ⊇ fv u] where [i] is minimal such that [vᵢ] distinguishes [fv x] and [fv u], if one exists. *) ``` To be used for existential witnessing and quantifier elimination. Reviewed By: ngorogiannis Differential Revision: D19282636 fbshipit-source-id: c5b006cea --- sledge/src/import/import.ml | 3 + sledge/src/import/import.mli | 1 + sledge/src/llair/term.ml | 15 +- sledge/src/symbheap/equality.ml | 239 ++++++++++++++++++++++++++++--- sledge/src/symbheap/equality.mli | 11 ++ 5 files changed, 248 insertions(+), 21 deletions(-) diff --git a/sledge/src/import/import.ml b/sledge/src/import/import.ml index 5b5d7ccac..7ae104edf 100644 --- a/sledge/src/import/import.ml +++ b/sledge/src/import/import.ml @@ -215,6 +215,9 @@ module List = struct in remove_ [] xs + let remove ?equal xs x = + try Some (remove_exn ?equal xs x) with Not_found -> None + let rec rev_init n ~f = if n = 0 then [] else diff --git a/sledge/src/import/import.mli b/sledge/src/import/import.mli index a966f5174..2b7dd53b4 100644 --- a/sledge/src/import/import.mli +++ b/sledge/src/import/import.mli @@ -177,6 +177,7 @@ module List : sig argument, or raise [Not_found] if no such element exists. [equal] defaults to physical equality. *) + val remove : ?equal:('a -> 'a -> bool) -> 'a list -> 'a -> 'a list option val rev_init : int -> f:(int -> 'a) -> 'a list val symmetric_diff : diff --git a/sledge/src/llair/term.ml b/sledge/src/llair/term.ml index 3b098ce21..62ad271d6 100644 --- a/sledge/src/llair/term.ml +++ b/sledge/src/llair/term.ml @@ -1020,6 +1020,8 @@ let is_false = function Integer {data} -> Z.is_false data | _ -> false (** Solve *) let solve_zero_eq ?for_ e = + [%Trace.call fun {pf} -> pf "%a%a" pp e (Option.pp " for %a" pp) for_] + ; ( match e with | Add args -> let+ c, q = @@ -1034,7 +1036,12 @@ let solve_zero_eq ?for_ e = let r = div n d in (c, r) | _ -> None ) - |> check (fun soln -> - match (for_, soln) with - | Some f, Some (c, _) -> assert (equal f c) - | _ -> () ) + |> + [%Trace.retn fun {pf} s -> + pf "%a" + (Option.pp "%a" (fun fs (c, r) -> + Format.fprintf fs "%a ↦ %a" pp c pp r )) + s ; + match (for_, s) with + | Some f, Some (c, _) -> assert (equal f c) + | _ -> ()] diff --git a/sledge/src/symbheap/equality.ml b/sledge/src/symbheap/equality.ml index 50c04388b..6ddf3c6b4 100644 --- a/sledge/src/symbheap/equality.ml +++ b/sledge/src/symbheap/equality.ml @@ -10,7 +10,7 @@ (** Classification of Terms by Theory *) type kind = Interpreted | Simplified | Atomic | Uninterpreted -[@@deriving compare] +[@@deriving compare, equal] let classify e = match (e : Term.t) with @@ -20,6 +20,17 @@ let classify e = | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted | RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic +let rec fold_max_solvables e ~init ~f = + match classify e with + | Interpreted -> + Term.fold e ~init ~f:(fun d s -> fold_max_solvables ~f d ~init:s) + | _ -> f e init + +let rec iter_max_solvables e ~f = + match classify e with + | Interpreted -> Term.iter ~f:(iter_max_solvables ~f) e + | _ -> f e + (** Solution Substitutions *) module Subst : sig type t [@@deriving compare, equal, sexp] @@ -211,16 +222,27 @@ let pp_diff fs (r, s) = in Format.fprintf fs "@[{@[%t%t@]}@]" pp_sat pp_rep +let ppx_cls x = List.pp "@ = " (Term.ppx x) +let pp_cls = ppx_cls (fun _ -> None) +let pp_diff_cls = List.pp_diff ~compare:Term.compare "@ = " Term.pp + +let ppx_clss x fs cs = + List.pp "@ @<2>∧ " + (fun fs (key, data) -> + Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) key (ppx_cls x) + (List.sort ~compare:Term.compare data) ) + fs (Map.to_alist cs) + +let pp_clss fs cs = ppx_clss (fun _ -> None) fs cs + +let pp_diff_clss = + Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls pp_diff_cls + (** Invariant *) (** test membership in carrier *) let in_car r e = Subst.mem r.rep e -let rec iter_max_solvables e ~f = - match classify e with - | Interpreted -> Term.iter ~f:(iter_max_solvables ~f) e - | _ -> f e - let invariant r = Invariant.invariant [%here] r [%sexp_of: t] @@ fun () -> @@ -411,17 +433,8 @@ let fold_vars r ~init ~f = fold_terms r ~init ~f:(fun init -> Term.fold_vars ~f ~init) let fv e = fold_vars e ~f:Set.add ~init:Var.Set.empty - -let ppx_classes x fs r = - List.pp "@ @<2>∧ " - (fun fs (key, data) -> - Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) key - (List.pp "@ = " (Term.ppx x)) - (List.sort ~compare:Term.compare data) ) - fs - (Map.to_alist (classes r)) - -let pp_classes fs r = ppx_classes (fun _ -> None) fs r +let pp_classes fs r = pp_clss fs (classes r) +let ppx_classes x fs r = ppx_clss x fs (classes r) let ppx_classes_diff x fs (r, s) = let clss = classes s in @@ -439,3 +452,195 @@ let ppx_classes_diff x fs (r, s) = (List.pp "@ = " (Term.ppx x)) (List.sort ~compare:Term.compare cls) ) fs (Map.to_alist clss) + +(** Existential Witnessing and Elimination *) + +type 'a zom = Zero | One of 'a | Many + +(* try to find a [term] in [cls] such that [fv (poly - term) ⊆ us ∪ xs] and + [poly - term] has at most one maximal solvable subterm, [kill], where [fv + kill ⊈ us]; solve [poly = term] for [kill]; extend subst mapping [kill] + to the solution *) +let solve_interp_eq us us_xs poly (cls, subst) = + [%Trace.call fun {pf} -> + pf "poly: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Term.pp poly pp_cls cls + Subst.pp subst] + ; + ( if not (Set.is_subset (Term.fv poly) ~of_:us_xs) then None + else + List.find_map cls ~f:(fun trm -> + if not (Set.is_subset (Term.fv trm) ~of_:us_xs) then None + else + let diff = Subst.norm subst (Term.sub poly trm) in + let max_solvables_not_ito_us = + fold_max_solvables diff ~init:Zero ~f:(fun solvable_subterm -> + function + | Many -> Many + | zom when Set.is_subset (Term.fv solvable_subterm) ~of_:us -> + zom + | One _ -> Many + | Zero -> One solvable_subterm ) + in + match max_solvables_not_ito_us with + | One kill -> + let+ kill, keep = Term.solve_zero_eq diff ~for_:kill in + Subst.compose1 ~key:kill ~data:keep subst + | Many | Zero -> None ) ) + |> + [%Trace.retn fun {pf} subst' -> + pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst) ; + Option.iter subst' ~f:(fun subst' -> + Subst.iteri subst' ~f:(fun ~key ~data -> + assert (Set.is_subset (Term.fv key) ~of_:us_xs) ; + assert ( + Subst.mem subst key + || not (Set.is_subset (Term.fv key) ~of_:us) ) ; + assert (Set.is_subset (Term.fv data) ~of_:us) ) )] + +(* move equations from [cls] to [subst] which are between [Interpreted] + terms and can be expressed, after normalizing with [subst], as [x ↦ u] + where [us ∪ xs ⊇ fv x ⊈ us ⊇ fv u] *) +let rec solve_interp_eqs us us_xs (cls, subst) = + [%Trace.call fun {pf} -> + pf "cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst] + ; + let rec solve_interp_eqs_ cls' (cls, subst) = + match cls with + | [] -> (cls', subst) + | trm :: cls -> ( + let trm' = Subst.norm subst trm in + match classify trm' with + | Interpreted -> ( + match solve_interp_eq us us_xs trm' (cls', subst) with + | None -> ( + match solve_interp_eq us us_xs trm' (cls, subst) with + | None -> solve_interp_eqs_ (trm' :: cls') (cls, subst) + | Some subst -> solve_interp_eqs_ cls' (cls, subst) ) + | Some subst -> solve_interp_eqs_ cls' (cls, subst) ) + | _ -> solve_interp_eqs_ (trm' :: cls') (cls, subst) ) + in + let cls', subst' = solve_interp_eqs_ [] (cls, subst) in + ( if subst' != subst then solve_interp_eqs us us_xs (cls', subst') + else (cls', subst') ) + |> + [%Trace.retn fun {pf} (cls', subst') -> + pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff + (subst, subst')] + +(* move equations from [cls] (which is assumed to be normalized by [subst]) + to [subst] which are between non-[Interpreted] terms and can be expressed + as [x ↦ u] where [us ∪ xs ⊇ fv x ⊈ us ⊇ fv u] *) +let solve_uninterp_eqs us us_xs (cls, subst) = + [%Trace.call fun {pf} -> + pf "cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst] + ; + let rep_ito_us, cls_not_ito_us, cls_delay = + List.fold cls ~init:(None, [], []) + ~f:(fun (rep_ito_us, cls_not_ito_us, cls_delay) trm -> + if not (equal_kind (classify trm) Interpreted) then + let fv_trm = Term.fv trm in + if Set.is_subset fv_trm ~of_:us then + match rep_ito_us with + | Some rep when Term.compare rep trm <= 0 -> + (rep_ito_us, cls_not_ito_us, trm :: cls_delay) + | Some rep -> (Some trm, cls_not_ito_us, rep :: cls_delay) + | None -> (Some trm, cls_not_ito_us, cls_delay) + else if Set.is_subset fv_trm ~of_:us_xs then + (rep_ito_us, trm :: cls_not_ito_us, cls_delay) + else (rep_ito_us, cls_not_ito_us, trm :: cls_delay) + else (rep_ito_us, cls_not_ito_us, trm :: cls_delay) ) + in + ( match rep_ito_us with + | None -> (cls, subst) + | Some rep_ito_us -> + let cls = + if List.is_empty cls_delay then [] else rep_ito_us :: cls_delay + in + let subst = + List.fold cls_not_ito_us ~init:subst ~f:(fun subst trm_not_ito_us -> + Subst.compose1 ~key:trm_not_ito_us ~data:rep_ito_us subst ) + in + (cls, subst) ) + |> + [%Trace.retn fun {pf} (cls', subst') -> + pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff + (subst, subst') ; + Subst.iteri subst' ~f:(fun ~key ~data -> + assert (Set.is_subset (Term.fv key) ~of_:us_xs) ; + assert ( + Subst.mem subst key || not (Set.is_subset (Term.fv key) ~of_:us) + ) ; + assert (Set.is_subset (Term.fv data) ~of_:us) )] + +(* move equations between terms in [rep]'s class [cls] from [classes] to + [subst] which can be expressed, after normalizing with [subst], as [x ↦ + u] where [us ∪ xs ⊇ fv x ⊈ us ⊇ fv u] *) +let solve_class us us_xs ~key:rep ~data:cls (classes, subst) = + let classes0 = classes in + [%Trace.call fun {pf} -> + pf "rep: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Term.pp rep pp_cls cls + Subst.pp subst] + ; + let cls, subst = solve_interp_eqs us us_xs (rep :: cls, subst) in + let cls, subst = solve_uninterp_eqs us us_xs (cls, subst) in + let cls = + List.remove ~equal:Term.equal cls (Subst.norm subst rep) + |> Option.value ~default:cls + in + let classes = + if List.is_empty cls then Map.remove classes rep + else Map.set classes ~key:rep ~data:cls + in + (classes, subst) + |> + [%Trace.retn fun {pf} (classes', subst') -> + pf "subst: @[%a@]@ classes: @[%a@]" Subst.pp_diff (subst, subst') + pp_diff_clss (classes0, classes')] + +(* move equations from [classes] to [subst] which can be expressed, after + normalizing with [subst], as [x ↦ u] where [us ∪ xs ⊇ fv x ⊈ us ⊇ fv u] *) +let solve_classes (classes, subst, us) xs = + [%Trace.call fun {pf} -> pf "xs: {@[%a@]}" Var.Set.pp xs] + ; + let rec solve_classes_ (classes0, subst0, us_xs) = + let classes, subst = + Map.fold ~f:(solve_class us us_xs) classes0 ~init:(classes0, subst0) + in + if subst != subst0 then solve_classes_ (classes, subst, us_xs) + else (classes, subst, us_xs) + in + solve_classes_ (classes, subst, Set.union us xs) + |> + [%Trace.retn fun {pf} (classes', subst', _) -> + pf "subst: @[%a@]@ classes: @[%a@]" Subst.pp_diff (subst, subst') + pp_diff_clss (classes, classes')] + +let pp_vss fs vss = + Format.fprintf fs "[@[%a@]]" + (List.pp ";@ " (fun fs vs -> Format.fprintf fs "{@[%a@]}" Var.Set.pp vs)) + vss + +(* enumerate variable contexts vᵢ in [v₁;…] and accumulate a solution subst + with entries [x ↦ u] where [r] entails [x = u] and [⋃ⱼ₌₁ⁱ vⱼ ⊇ fv x ⊈ + ⋃ⱼ₌₁ⁱ⁻¹ vⱼ ⊇ fv u] *) +let solve_for_vars vss r = + [%Trace.call fun {pf} -> pf "%a@ @[%a@]" pp_vss vss pp_classes r] + ; + List.fold ~f:solve_classes + ~init:(classes r, Subst.empty, Var.Set.empty) + vss + |> snd3 + |> + [%Trace.retn fun {pf} subst -> + pf "%a" Subst.pp subst ; + Subst.iteri subst ~f:(fun ~key ~data -> + assert (entails_eq r key data) ; + assert ( + List.exists vss ~f:(fun vs -> + match + ( Set.is_subset (Term.fv key) ~of_:vs + , Set.is_subset (Term.fv data) ~of_:vs ) + with + | false, true -> true + | true, false -> assert false + | _ -> false ) ) )] diff --git a/sledge/src/symbheap/equality.mli b/sledge/src/symbheap/equality.mli index 4e6ea5ae1..7c67abf6f 100644 --- a/sledge/src/symbheap/equality.mli +++ b/sledge/src/symbheap/equality.mli @@ -62,3 +62,14 @@ val difference : t -> Term.t -> Term.t -> Z.t option offset. *) val fold_terms : t -> init:'a -> f:('a -> Term.t -> 'a) -> 'a + +(** Solution Substitutions *) +module Subst : sig + type t [@@deriving compare, equal, sexp] +end + +val solve_for_vars : Var.Set.t list -> t -> Subst.t +(** [solve_for_vars \[v₁;…\] r] is a solution substitution that is + entailed by [r] and consists of oriented equalities [x ↦ u] such that + [fv x ⊈ vᵢ ⊇ fv u] where [i] is minimal such that [vᵢ] + distinguishes [fv x] and [fv u], if one exists. *)