[sledge] Add Equality.solve_for_vars

Summary:
```
val solve_for_vars : Var.Set.t list -> t -> Subst.t
(** [solve_for_vars \[v₁;…\] r] is a solution substitution that is
    entailed by [r] and consists of oriented equalities [x ↦ u] such that
    [fv x ⊈ vᵢ ⊇ fv u] where [i] is minimal such that [vᵢ]
    distinguishes [fv x] and [fv u], if one exists. *)
```

To be used for existential witnessing and quantifier elimination.

Reviewed By: ngorogiannis

Differential Revision: D19282636

fbshipit-source-id: c5b006cea
master
Josh Berdine 5 years ago committed by Facebook Github Bot
parent 7cb11b587a
commit f0a660792e

@ -215,6 +215,9 @@ module List = struct
in
remove_ [] xs
let remove ?equal xs x =
try Some (remove_exn ?equal xs x) with Not_found -> None
let rec rev_init n ~f =
if n = 0 then []
else

@ -177,6 +177,7 @@ module List : sig
argument, or raise [Not_found] if no such element exists. [equal]
defaults to physical equality. *)
val remove : ?equal:('a -> 'a -> bool) -> 'a list -> 'a -> 'a list option
val rev_init : int -> f:(int -> 'a) -> 'a list
val symmetric_diff :

@ -1020,6 +1020,8 @@ let is_false = function Integer {data} -> Z.is_false data | _ -> false
(** Solve *)
let solve_zero_eq ?for_ e =
[%Trace.call fun {pf} -> pf "%a%a" pp e (Option.pp " for %a" pp) for_]
;
( match e with
| Add args ->
let+ c, q =
@ -1034,7 +1036,12 @@ let solve_zero_eq ?for_ e =
let r = div n d in
(c, r)
| _ -> None )
|> check (fun soln ->
match (for_, soln) with
|>
[%Trace.retn fun {pf} s ->
pf "%a"
(Option.pp "%a" (fun fs (c, r) ->
Format.fprintf fs "%a ↦ %a" pp c pp r ))
s ;
match (for_, s) with
| Some f, Some (c, _) -> assert (equal f c)
| _ -> () )
| _ -> ()]

@ -10,7 +10,7 @@
(** Classification of Terms by Theory *)
type kind = Interpreted | Simplified | Atomic | Uninterpreted
[@@deriving compare]
[@@deriving compare, equal]
let classify e =
match (e : Term.t) with
@ -20,6 +20,17 @@ let classify e =
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted
| RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic
let rec fold_max_solvables e ~init ~f =
match classify e with
| Interpreted ->
Term.fold e ~init ~f:(fun d s -> fold_max_solvables ~f d ~init:s)
| _ -> f e init
let rec iter_max_solvables e ~f =
match classify e with
| Interpreted -> Term.iter ~f:(iter_max_solvables ~f) e
| _ -> f e
(** Solution Substitutions *)
module Subst : sig
type t [@@deriving compare, equal, sexp]
@ -211,16 +222,27 @@ let pp_diff fs (r, s) =
in
Format.fprintf fs "@[{@[<hv>%t%t@]}@]" pp_sat pp_rep
let ppx_cls x = List.pp "@ = " (Term.ppx x)
let pp_cls = ppx_cls (fun _ -> None)
let pp_diff_cls = List.pp_diff ~compare:Term.compare "@ = " Term.pp
let ppx_clss x fs cs =
List.pp "@ @<2>∧ "
(fun fs (key, data) ->
Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) key (ppx_cls x)
(List.sort ~compare:Term.compare data) )
fs (Map.to_alist cs)
let pp_clss fs cs = ppx_clss (fun _ -> None) fs cs
let pp_diff_clss =
Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls pp_diff_cls
(** Invariant *)
(** test membership in carrier *)
let in_car r e = Subst.mem r.rep e
let rec iter_max_solvables e ~f =
match classify e with
| Interpreted -> Term.iter ~f:(iter_max_solvables ~f) e
| _ -> f e
let invariant r =
Invariant.invariant [%here] r [%sexp_of: t]
@@ fun () ->
@ -411,17 +433,8 @@ let fold_vars r ~init ~f =
fold_terms r ~init ~f:(fun init -> Term.fold_vars ~f ~init)
let fv e = fold_vars e ~f:Set.add ~init:Var.Set.empty
let ppx_classes x fs r =
List.pp "@ @<2>∧ "
(fun fs (key, data) ->
Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) key
(List.pp "@ = " (Term.ppx x))
(List.sort ~compare:Term.compare data) )
fs
(Map.to_alist (classes r))
let pp_classes fs r = ppx_classes (fun _ -> None) fs r
let pp_classes fs r = pp_clss fs (classes r)
let ppx_classes x fs r = ppx_clss x fs (classes r)
let ppx_classes_diff x fs (r, s) =
let clss = classes s in
@ -439,3 +452,195 @@ let ppx_classes_diff x fs (r, s) =
(List.pp "@ = " (Term.ppx x))
(List.sort ~compare:Term.compare cls) )
fs (Map.to_alist clss)
(** Existential Witnessing and Elimination *)
type 'a zom = Zero | One of 'a | Many
(* try to find a [term] in [cls] such that [fv (poly - term) ⊆ us xs] and
[poly - term] has at most one maximal solvable subterm, [kill], where [fv
kill us]; solve [poly = term] for [kill]; extend subst mapping [kill]
to the solution *)
let solve_interp_eq us us_xs poly (cls, subst) =
[%Trace.call fun {pf} ->
pf "poly: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Term.pp poly pp_cls cls
Subst.pp subst]
;
( if not (Set.is_subset (Term.fv poly) ~of_:us_xs) then None
else
List.find_map cls ~f:(fun trm ->
if not (Set.is_subset (Term.fv trm) ~of_:us_xs) then None
else
let diff = Subst.norm subst (Term.sub poly trm) in
let max_solvables_not_ito_us =
fold_max_solvables diff ~init:Zero ~f:(fun solvable_subterm ->
function
| Many -> Many
| zom when Set.is_subset (Term.fv solvable_subterm) ~of_:us ->
zom
| One _ -> Many
| Zero -> One solvable_subterm )
in
match max_solvables_not_ito_us with
| One kill ->
let+ kill, keep = Term.solve_zero_eq diff ~for_:kill in
Subst.compose1 ~key:kill ~data:keep subst
| Many | Zero -> None ) )
|>
[%Trace.retn fun {pf} subst' ->
pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst) ;
Option.iter subst' ~f:(fun subst' ->
Subst.iteri subst' ~f:(fun ~key ~data ->
assert (Set.is_subset (Term.fv key) ~of_:us_xs) ;
assert (
Subst.mem subst key
|| not (Set.is_subset (Term.fv key) ~of_:us) ) ;
assert (Set.is_subset (Term.fv data) ~of_:us) ) )]
(* move equations from [cls] to [subst] which are between [Interpreted]
terms and can be expressed, after normalizing with [subst], as [x u]
where [us xs fv x us fv u] *)
let rec solve_interp_eqs us us_xs (cls, subst) =
[%Trace.call fun {pf} ->
pf "cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst]
;
let rec solve_interp_eqs_ cls' (cls, subst) =
match cls with
| [] -> (cls', subst)
| trm :: cls -> (
let trm' = Subst.norm subst trm in
match classify trm' with
| Interpreted -> (
match solve_interp_eq us us_xs trm' (cls', subst) with
| None -> (
match solve_interp_eq us us_xs trm' (cls, subst) with
| None -> solve_interp_eqs_ (trm' :: cls') (cls, subst)
| Some subst -> solve_interp_eqs_ cls' (cls, subst) )
| Some subst -> solve_interp_eqs_ cls' (cls, subst) )
| _ -> solve_interp_eqs_ (trm' :: cls') (cls, subst) )
in
let cls', subst' = solve_interp_eqs_ [] (cls, subst) in
( if subst' != subst then solve_interp_eqs us us_xs (cls', subst')
else (cls', subst') )
|>
[%Trace.retn fun {pf} (cls', subst') ->
pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff
(subst, subst')]
(* move equations from [cls] (which is assumed to be normalized by [subst])
to [subst] which are between non-[Interpreted] terms and can be expressed
as [x u] where [us xs fv x us fv u] *)
let solve_uninterp_eqs us us_xs (cls, subst) =
[%Trace.call fun {pf} ->
pf "cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst]
;
let rep_ito_us, cls_not_ito_us, cls_delay =
List.fold cls ~init:(None, [], [])
~f:(fun (rep_ito_us, cls_not_ito_us, cls_delay) trm ->
if not (equal_kind (classify trm) Interpreted) then
let fv_trm = Term.fv trm in
if Set.is_subset fv_trm ~of_:us then
match rep_ito_us with
| Some rep when Term.compare rep trm <= 0 ->
(rep_ito_us, cls_not_ito_us, trm :: cls_delay)
| Some rep -> (Some trm, cls_not_ito_us, rep :: cls_delay)
| None -> (Some trm, cls_not_ito_us, cls_delay)
else if Set.is_subset fv_trm ~of_:us_xs then
(rep_ito_us, trm :: cls_not_ito_us, cls_delay)
else (rep_ito_us, cls_not_ito_us, trm :: cls_delay)
else (rep_ito_us, cls_not_ito_us, trm :: cls_delay) )
in
( match rep_ito_us with
| None -> (cls, subst)
| Some rep_ito_us ->
let cls =
if List.is_empty cls_delay then [] else rep_ito_us :: cls_delay
in
let subst =
List.fold cls_not_ito_us ~init:subst ~f:(fun subst trm_not_ito_us ->
Subst.compose1 ~key:trm_not_ito_us ~data:rep_ito_us subst )
in
(cls, subst) )
|>
[%Trace.retn fun {pf} (cls', subst') ->
pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff
(subst, subst') ;
Subst.iteri subst' ~f:(fun ~key ~data ->
assert (Set.is_subset (Term.fv key) ~of_:us_xs) ;
assert (
Subst.mem subst key || not (Set.is_subset (Term.fv key) ~of_:us)
) ;
assert (Set.is_subset (Term.fv data) ~of_:us) )]
(* move equations between terms in [rep]'s class [cls] from [classes] to
[subst] which can be expressed, after normalizing with [subst], as [x
u] where [us xs fv x us fv u] *)
let solve_class us us_xs ~key:rep ~data:cls (classes, subst) =
let classes0 = classes in
[%Trace.call fun {pf} ->
pf "rep: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Term.pp rep pp_cls cls
Subst.pp subst]
;
let cls, subst = solve_interp_eqs us us_xs (rep :: cls, subst) in
let cls, subst = solve_uninterp_eqs us us_xs (cls, subst) in
let cls =
List.remove ~equal:Term.equal cls (Subst.norm subst rep)
|> Option.value ~default:cls
in
let classes =
if List.is_empty cls then Map.remove classes rep
else Map.set classes ~key:rep ~data:cls
in
(classes, subst)
|>
[%Trace.retn fun {pf} (classes', subst') ->
pf "subst: @[%a@]@ classes: @[%a@]" Subst.pp_diff (subst, subst')
pp_diff_clss (classes0, classes')]
(* move equations from [classes] to [subst] which can be expressed, after
normalizing with [subst], as [x u] where [us xs fv x us fv u] *)
let solve_classes (classes, subst, us) xs =
[%Trace.call fun {pf} -> pf "xs: {@[%a@]}" Var.Set.pp xs]
;
let rec solve_classes_ (classes0, subst0, us_xs) =
let classes, subst =
Map.fold ~f:(solve_class us us_xs) classes0 ~init:(classes0, subst0)
in
if subst != subst0 then solve_classes_ (classes, subst, us_xs)
else (classes, subst, us_xs)
in
solve_classes_ (classes, subst, Set.union us xs)
|>
[%Trace.retn fun {pf} (classes', subst', _) ->
pf "subst: @[%a@]@ classes: @[%a@]" Subst.pp_diff (subst, subst')
pp_diff_clss (classes, classes')]
let pp_vss fs vss =
Format.fprintf fs "[@[%a@]]"
(List.pp ";@ " (fun fs vs -> Format.fprintf fs "{@[%a@]}" Var.Set.pp vs))
vss
(* enumerate variable contexts vᵢ in [v₁;…] and accumulate a solution subst
with entries [x u] where [r] entails [x = u] and [ v fv x
¹ v fv u] *)
let solve_for_vars vss r =
[%Trace.call fun {pf} -> pf "%a@ @[%a@]" pp_vss vss pp_classes r]
;
List.fold ~f:solve_classes
~init:(classes r, Subst.empty, Var.Set.empty)
vss
|> snd3
|>
[%Trace.retn fun {pf} subst ->
pf "%a" Subst.pp subst ;
Subst.iteri subst ~f:(fun ~key ~data ->
assert (entails_eq r key data) ;
assert (
List.exists vss ~f:(fun vs ->
match
( Set.is_subset (Term.fv key) ~of_:vs
, Set.is_subset (Term.fv data) ~of_:vs )
with
| false, true -> true
| true, false -> assert false
| _ -> false ) ) )]

@ -62,3 +62,14 @@ val difference : t -> Term.t -> Term.t -> Z.t option
offset. *)
val fold_terms : t -> init:'a -> f:('a -> Term.t -> 'a) -> 'a
(** Solution Substitutions *)
module Subst : sig
type t [@@deriving compare, equal, sexp]
end
val solve_for_vars : Var.Set.t list -> t -> Subst.t
(** [solve_for_vars \[v₁;…\] r] is a solution substitution that is
entailed by [r] and consists of oriented equalities [x u] such that
[fv x v fv u] where [i] is minimal such that [v]
distinguishes [fv x] and [fv u], if one exists. *)

Loading…
Cancel
Save