From 1a9737dbe5fcfc4c995646b9a1043aba219efedb Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Tue, 12 Jan 2021 04:28:54 -0800 Subject: [PATCH] [sledge] Factor out representation of equality classes into Context.Cls Reviewed By: jvillard Differential Revision: D25756566 fbshipit-source-id: 01ce88f1a --- sledge/nonstdlib/list.ml | 1 + sledge/nonstdlib/list.mli | 1 + sledge/nonstdlib/set.ml | 1 + sledge/nonstdlib/set_intf.ml | 1 + sledge/src/fol/context.ml | 156 ++++++++++++++++++++++------------- sledge/src/fol/trm.mli | 1 + 6 files changed, 103 insertions(+), 58 deletions(-) diff --git a/sledge/nonstdlib/list.ml b/sledge/nonstdlib/list.ml index dfa484906..2a2457a59 100644 --- a/sledge/nonstdlib/list.ml +++ b/sledge/nonstdlib/list.ml @@ -35,6 +35,7 @@ let remove_one ~eq x xs = let remove ~eq x xs = remove ~eq ~key:x xs let filter xs ~f = filter ~f xs +let partition xs ~f = partition ~f xs let map xs ~f = map ~f xs let map_endo t ~f = map_endo map t ~f diff --git a/sledge/nonstdlib/list.mli b/sledge/nonstdlib/list.mli index 7372bf36b..2825703e3 100644 --- a/sledge/nonstdlib/list.mli +++ b/sledge/nonstdlib/list.mli @@ -44,6 +44,7 @@ val remove_one_exn : eq:('a -> 'a -> bool) -> 'a -> 'a list -> 'a list val remove_one : eq:('a -> 'a -> bool) -> 'a -> 'a list -> 'a list option val remove : eq:('a -> 'a -> bool) -> 'a -> 'a list -> 'a list val filter : 'a list -> f:('a -> bool) -> 'a list +val partition : 'a list -> f:('a -> bool) -> 'a list * 'a list val map : 'a t -> f:('a -> 'b) -> 'b t val map_endo : 'a t -> f:('a -> 'a) -> 'a t diff --git a/sledge/nonstdlib/set.ml b/sledge/nonstdlib/set.ml index 72a684f7e..8783dbc81 100644 --- a/sledge/nonstdlib/set.ml +++ b/sledge/nonstdlib/set.ml @@ -107,6 +107,7 @@ end) : S with type elt = Elt.t = struct let map s ~f = S.map f s let filter s ~f = S.filter f s + let partition s ~f = S.partition f s let iter s ~f = S.iter f s let exists s ~f = S.exists f s let for_all s ~f = S.for_all f s diff --git a/sledge/nonstdlib/set_intf.ml b/sledge/nonstdlib/set_intf.ml index b06520e7c..30b1a6c8d 100644 --- a/sledge/nonstdlib/set_intf.ml +++ b/sledge/nonstdlib/set_intf.ml @@ -83,6 +83,7 @@ module type S = sig val map : t -> f:(elt -> elt) -> t val filter : t -> f:(elt -> bool) -> t + val partition : t -> f:(elt -> bool) -> t * t (** {1 Traverse} *) diff --git a/sledge/src/fol/context.ml b/sledge/src/fol/context.ml index 148ea4ed9..5b304be1a 100644 --- a/sledge/src/fol/context.ml +++ b/sledge/src/fol/context.ml @@ -414,6 +414,48 @@ let solve ?f ~wrt ~xs d e = | Some (xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s | None -> pf "false"] +(** Equality classes *) + +module Cls : sig + type t [@@deriving equal] + + val empty : t + val of_ : Trm.t -> t + val add : Trm.t -> t -> t + val remove : Trm.t -> t -> t + val union : t -> t -> t + val is_empty : t -> bool + val pop : t -> (Trm.t * t) option + val filter : t -> f:(Trm.t -> bool) -> t + val partition : t -> f:(Trm.t -> bool) -> t * t + val fold : t -> 's -> f:(Trm.t -> 's -> 's) -> 's + val to_iter : t -> Trm.t iter + val to_set : t -> Trm.Set.t + val sort : t -> t + val ppx : Trm.Var.strength -> t pp + val pp : t pp + val pp_diff : (t * t) pp +end = struct + type t = Trm.t list [@@deriving equal] + + let empty = [] + let of_ e = [e] + let add = List.cons + let remove = List.remove ~eq:Trm.equal + let union = List.rev_append + let is_empty = List.is_empty + let pop = function [] -> None | x :: xs -> Some (x, xs) + let filter = List.filter + let partition = List.partition + let fold = List.fold + let to_iter = List.to_iter + let to_set = Trm.Set.of_list + let sort = List.sort ~cmp:Trm.compare + let ppx x = List.pp "@ = " (Trm.ppx x) + let pp = ppx (fun _ -> None) + let pp_diff = List.pp_diff ~cmp:Trm.compare "@ = " Trm.pp +end + (** Equality Relations *) (** see also [invariant] *) @@ -430,11 +472,13 @@ type t = let classes r = Subst.fold r.rep Trm.Map.empty ~f:(fun ~key:elt ~data:rep cls -> if Trm.equal elt rep then cls - else Trm.Map.add_multi ~key:rep ~data:elt cls ) + else + Trm.Map.update rep cls ~f:(fun cls0 -> + Cls.add elt (Option.value cls0 ~default:Cls.empty) ) ) let cls_of r e = let e' = Subst.apply r.rep e in - Trm.Map.find e' (classes r) |> Option.value ~default:[e'] + Trm.Map.find e' (classes r) |> Option.value ~default:(Cls.of_ e') (** Pretty-printing *) @@ -461,22 +505,16 @@ let pp_diff fs (r, s) = in Format.fprintf fs "@[{@[%t%t@]}@]" pp_sat pp_rep -let ppx_cls x = List.pp "@ = " (Trm.ppx x) -let pp_cls = ppx_cls (fun _ -> None) -let pp_diff_cls = List.pp_diff ~cmp:Trm.compare "@ = " Trm.pp - let ppx_classes x fs clss = List.pp "@ @<2>∧ " (fun fs (rep, cls) -> - if not (List.is_empty cls) then - Format.fprintf fs "@[%a@ = %a@]" (Trm.ppx x) rep (ppx_cls x) cls ) + if not (Cls.is_empty cls) then + Format.fprintf fs "@[%a@ = %a@]" (Trm.ppx x) rep (Cls.ppx x) cls ) fs (Iter.to_list (Trm.Map.to_iter clss)) let pp_classes fs r = ppx_classes (fun _ -> None) fs (classes r) - -let pp_diff_clss = - Trm.Map.pp_diff ~eq:(List.equal Trm.equal) Trm.pp pp_cls pp_diff_cls +let pp_diff_clss = Trm.Map.pp_diff ~eq:Cls.equal Trm.pp Cls.pp Cls.pp_diff let pp fs r = let clss = classes r in @@ -486,7 +524,7 @@ let pp fs r = let ppx var_strength fs clss noneqs = let without_anon_vars = - List.filter ~f:(fun e -> + Cls.filter ~f:(fun e -> match Var.of_trm e with | Some v -> Poly.(var_strength v <> Some `Anonymous) | None -> true ) @@ -494,8 +532,8 @@ let ppx var_strength fs clss noneqs = let clss = Trm.Map.fold clss Trm.Map.empty ~f:(fun ~key:rep ~data:cls m -> let cls = without_anon_vars cls in - if not (List.is_empty cls) then - Trm.Map.add ~key:rep ~data:(List.sort ~cmp:Trm.compare cls) m + if not (Cls.is_empty cls) then + Trm.Map.add ~key:rep ~data:(Cls.sort cls) m else m ) in let first = Trm.Map.is_empty clss in @@ -690,16 +728,15 @@ let normalize r e = Term.map_trms ~f:(canon r) e let class_of r e = match Term.get_trm (normalize r e) with | Some e' -> - List.map ~f:Term.of_trm (e' :: Trm.Map.find_multi e' (classes r)) + Iter.to_list (Iter.map ~f:Term.of_trm (Cls.to_iter (cls_of r e'))) | None -> [] let diff_classes r s = Trm.Map.filter_mapi (classes r) ~f:(fun ~key:rep ~data:cls -> - match - List.filter cls ~f:(fun exp -> not (implies s (Fml.eq rep exp))) - with - | [] -> None - | cls -> Some cls ) + let cls' = + Cls.filter cls ~f:(fun exp -> not (implies s (Fml.eq rep exp))) + in + if Cls.is_empty cls' then None else Some cls' ) let ppx_diff var_strength fs parent_ctx fml ctx = let fml' = canon_f ctx fml in @@ -729,7 +766,7 @@ let apply_subst wrt s r = Trm.Map.fold (classes r) {r with rep= Subst.empty} ~f:(fun ~key:rep ~data:cls r -> let rep' = Subst.subst_ s rep in - List.fold cls r ~f:(fun trm r -> + Cls.fold cls r ~f:(fun trm r -> let trm' = Subst.subst_ s trm in and_eq_ ~wrt trm' rep' r ) ) ) |> extract_xs @@ -762,7 +799,7 @@ let inter wrt r s = else let merge_mems rs r s = Trm.Map.fold (classes s) rs ~f:(fun ~key:rep ~data:cls rs -> - List.fold cls + Cls.fold cls ([rep], rs) ~f:(fun exp (reps, rs) -> match @@ -906,10 +943,10 @@ let solve_seq_eq ~wrt us e' f' subst = let solve_interp_eq ~wrt us e' (cls, subst) = [%Trace.call fun {pf} -> - pf "@ trm: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp e' pp_cls cls + pf "@ trm: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp e' Cls.pp cls Subst.pp subst] ; - List.find_map cls ~f:(fun f -> + Iter.find_map (Cls.to_iter cls) ~f:(fun f -> let f' = Subst.norm subst f in match solve_seq_eq ~wrt us e' f' subst with | Some subst -> Some subst @@ -925,32 +962,32 @@ let solve_interp_eq ~wrt us e' (cls, subst) = [fv u ⊆ us ∪ xs] *) let rec solve_interp_eqs ~wrt us (cls, subst) = [%Trace.call fun {pf} -> - pf "@ cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst] + pf "@ cls: @[%a@]@ subst: @[%a@]" Cls.pp cls Subst.pp subst] ; let rec solve_interp_eqs_ cls' (cls, subst) = - match cls with - | [] -> (cls', subst) - | trm :: cls -> + match Cls.pop cls with + | None -> (cls', subst) + | Some (trm, cls) -> let trm' = Subst.norm subst trm in if is_interpreted trm' then match solve_interp_eq ~wrt us trm' (cls, subst) with | Some subst -> solve_interp_eqs_ cls' (cls, subst) - | None -> solve_interp_eqs_ (trm' :: cls') (cls, subst) - else solve_interp_eqs_ (trm' :: cls') (cls, subst) + | None -> solve_interp_eqs_ (Cls.add trm' cls') (cls, subst) + else solve_interp_eqs_ (Cls.add trm' cls') (cls, subst) in - let cls', subst' = solve_interp_eqs_ [] (cls, subst) in + let cls', subst' = solve_interp_eqs_ Cls.empty (cls, subst) in ( if subst' != subst then solve_interp_eqs ~wrt us (cls', subst') else (cls', subst') ) |> [%Trace.retn fun {pf} (cls', subst') -> - pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff + pf "cls: @[%a@]@ subst: @[%a@]" Cls.pp_diff (cls, cls') Subst.pp_diff (subst, subst')] type cls_solve_state = { rep_us: Trm.t option (** rep, that is ito us, for class *) - ; cls_us: Trm.t list (** cls that is ito us, or interpreted *) + ; cls_us: Cls.t (** cls that is ito us, or interpreted *) ; rep_xs: Trm.t option (** rep, that is *not* ito us, for class *) - ; cls_xs: Trm.t list (** cls that is *not* ito us *) } + ; cls_xs: Cls.t (** cls that is *not* ito us *) } let dom_trm e = match (e : Trm.t) with @@ -964,47 +1001,49 @@ let dom_trm e = [fv u ⊆ us ∪ xs] *) let solve_uninterp_eqs us (cls, subst) = [%Trace.call fun {pf} -> - pf "@ cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst] + pf "@ cls: @[%a@]@ subst: @[%a@]" Cls.pp cls Subst.pp subst] ; let compare e f = [%compare: kind * Trm.t] (classify e, e) (classify f, f) in let {rep_us; cls_us; rep_xs; cls_xs} = - List.fold cls {rep_us= None; cls_us= []; rep_xs= None; cls_xs= []} + Cls.fold cls + {rep_us= None; cls_us= Cls.empty; rep_xs= None; cls_xs= Cls.empty} ~f:(fun trm ({rep_us; cls_us; rep_xs; cls_xs} as s) -> if Var.Set.subset (Trm.fv trm) ~of_:us then match rep_us with | Some rep when compare rep trm <= 0 -> - {s with cls_us= trm :: cls_us} - | Some rep -> {s with rep_us= Some trm; cls_us= rep :: cls_us} + {s with cls_us= Cls.add trm cls_us} + | Some rep -> {s with rep_us= Some trm; cls_us= Cls.add rep cls_us} | None -> {s with rep_us= Some trm} else match rep_xs with | Some rep -> ( if compare rep trm <= 0 then match dom_trm trm with - | Some trm -> {s with cls_xs= trm :: cls_xs} - | None -> {s with cls_us= trm :: cls_us} + | Some trm -> {s with cls_xs= Cls.add trm cls_xs} + | None -> {s with cls_us= Cls.add trm cls_us} else match dom_trm rep with | Some rep -> - {s with rep_xs= Some trm; cls_xs= rep :: cls_xs} - | None -> {s with rep_xs= Some trm; cls_us= rep :: cls_us} ) + {s with rep_xs= Some trm; cls_xs= Cls.add rep cls_xs} + | None -> + {s with rep_xs= Some trm; cls_us= Cls.add rep cls_us} ) | None -> {s with rep_xs= Some trm} ) in ( match rep_us with | Some rep_us -> - let cls = rep_us :: cls_us in + let cls = Cls.add rep_us cls_us in let cls, cls_xs = match rep_xs with | Some rep -> ( match dom_trm rep with - | Some rep -> (cls, rep :: cls_xs) - | None -> (rep :: cls, cls_xs) ) + | Some rep -> (cls, Cls.add rep cls_xs) + | None -> (Cls.add rep cls, cls_xs) ) | None -> (cls, cls_xs) in let subst = - List.fold cls_xs subst ~f:(fun trm_xs subst -> + Cls.fold cls_xs subst ~f:(fun trm_xs subst -> let trm_xs = Subst.subst_ subst trm_xs in let rep_us = Subst.subst_ subst rep_us in Subst.compose1 ~key:trm_xs ~data:rep_us subst ) @@ -1013,16 +1052,16 @@ let solve_uninterp_eqs us (cls, subst) = | None -> ( match rep_xs with | Some rep_xs -> - let cls = rep_xs :: cls_us in + let cls = Cls.add rep_xs cls_us in let subst = - List.fold cls_xs subst ~f:(fun trm_xs subst -> + Cls.fold cls_xs subst ~f:(fun trm_xs subst -> Subst.compose1 ~key:trm_xs ~data:rep_xs subst ) in (cls, subst) | None -> (cls, subst) ) ) |> [%Trace.retn fun {pf} (cls', subst') -> - pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff + pf "cls: @[%a@]@ subst: @[%a@]" Cls.pp_diff (cls, cls') Subst.pp_diff (subst, subst') ; subst_invariant us subst subst'] @@ -1033,20 +1072,20 @@ let solve_uninterp_eqs us (cls, subst) = let solve_class ~wrt us us_xs ~key:rep ~data:cls (classes, subst) = let classes0 = classes in [%Trace.call fun {pf} -> - pf "@ rep: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp rep pp_cls cls + pf "@ rep: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp rep Cls.pp cls Subst.pp subst] ; let cls, cls_not_ito_us_xs = - List.partition + Cls.partition ~f:(fun e -> Var.Set.subset (Trm.fv e) ~of_:us_xs) - (rep :: cls) + (Cls.add rep cls) in let cls, subst = solve_interp_eqs ~wrt us (cls, subst) in let cls, subst = solve_uninterp_eqs us (cls, subst) in - let cls = List.rev_append cls_not_ito_us_xs cls in - let cls = List.remove ~eq:Trm.equal (Subst.norm subst rep) cls in + let cls = Cls.union cls_not_ito_us_xs cls in + let cls = Cls.remove (Subst.norm subst rep) cls in let classes = - if List.is_empty cls then Trm.Map.remove rep classes + if Cls.is_empty cls then Trm.Map.remove rep classes else Trm.Map.add ~key:rep ~data:cls classes in (classes, subst) @@ -1092,7 +1131,7 @@ let solve_concat_extracts r us x (classes, subst, us_xs) = List.filter_map (solve_concat_extracts_eq r x) ~f:(fun rev_extracts -> Iter.fold_opt (Iter.of_list rev_extracts) [] ~f:(fun e suffix -> let+ rep_ito_us = - List.fold (cls_of r e) None ~f:(fun trm rep_ito_us -> + Cls.fold (cls_of r e) None ~f:(fun trm rep_ito_us -> match rep_ito_us with | Some rep when Trm.compare rep trm <= 0 -> rep_ito_us | _ when Var.Set.subset (Trm.fv trm) ~of_:us -> Some trm @@ -1208,12 +1247,13 @@ let trim ks r = in let clss = Trm.Set.fold reps (classes r) ~f:(fun rep clss -> - Trm.Map.add_multi ~key:rep ~data:rep clss ) + Trm.Map.update rep clss ~f:(fun cls0 -> + Cls.add rep (Option.value cls0 ~default:Cls.empty) ) ) in (* trim classes to those that intersect kills *) let clss = Trm.Map.filter_mapi clss ~f:(fun ~key:_ ~data:cls -> - let cls = Trm.Set.of_list cls in + let cls = Cls.to_set cls in if Trm.Set.disjoint kills cls then None else Some cls ) in (* enumerate affected classes and update solution subst *) diff --git a/sledge/src/fol/trm.mli b/sledge/src/fol/trm.mli index a273a11f6..969a3cd4a 100644 --- a/sledge/src/fol/trm.mli +++ b/sledge/src/fol/trm.mli @@ -42,6 +42,7 @@ module Set : sig val t_of_sexp : Sexp.t -> t val pp : t pp + val pp_diff : (t * t) pp end module Map : sig