[sledge] Factor out representation of equality classes into Context.Cls

Reviewed By: jvillard

Differential Revision: D25756566

fbshipit-source-id: 01ce88f1a
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 9b4059b07c
commit 1a9737dbe5

@ -35,6 +35,7 @@ let remove_one ~eq x xs =
let remove ~eq x xs = remove ~eq ~key:x xs
let filter xs ~f = filter ~f xs
let partition xs ~f = partition ~f xs
let map xs ~f = map ~f xs
let map_endo t ~f = map_endo map t ~f

@ -44,6 +44,7 @@ val remove_one_exn : eq:('a -> 'a -> bool) -> 'a -> 'a list -> 'a list
val remove_one : eq:('a -> 'a -> bool) -> 'a -> 'a list -> 'a list option
val remove : eq:('a -> 'a -> bool) -> 'a -> 'a list -> 'a list
val filter : 'a list -> f:('a -> bool) -> 'a list
val partition : 'a list -> f:('a -> bool) -> 'a list * 'a list
val map : 'a t -> f:('a -> 'b) -> 'b t
val map_endo : 'a t -> f:('a -> 'a) -> 'a t

@ -107,6 +107,7 @@ end) : S with type elt = Elt.t = struct
let map s ~f = S.map f s
let filter s ~f = S.filter f s
let partition s ~f = S.partition f s
let iter s ~f = S.iter f s
let exists s ~f = S.exists f s
let for_all s ~f = S.for_all f s

@ -83,6 +83,7 @@ module type S = sig
val map : t -> f:(elt -> elt) -> t
val filter : t -> f:(elt -> bool) -> t
val partition : t -> f:(elt -> bool) -> t * t
(** {1 Traverse} *)

@ -414,6 +414,48 @@ let solve ?f ~wrt ~xs d e =
| Some (xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s
| None -> pf "false"]
(** Equality classes *)
module Cls : sig
type t [@@deriving equal]
val empty : t
val of_ : Trm.t -> t
val add : Trm.t -> t -> t
val remove : Trm.t -> t -> t
val union : t -> t -> t
val is_empty : t -> bool
val pop : t -> (Trm.t * t) option
val filter : t -> f:(Trm.t -> bool) -> t
val partition : t -> f:(Trm.t -> bool) -> t * t
val fold : t -> 's -> f:(Trm.t -> 's -> 's) -> 's
val to_iter : t -> Trm.t iter
val to_set : t -> Trm.Set.t
val sort : t -> t
val ppx : Trm.Var.strength -> t pp
val pp : t pp
val pp_diff : (t * t) pp
end = struct
type t = Trm.t list [@@deriving equal]
let empty = []
let of_ e = [e]
let add = List.cons
let remove = List.remove ~eq:Trm.equal
let union = List.rev_append
let is_empty = List.is_empty
let pop = function [] -> None | x :: xs -> Some (x, xs)
let filter = List.filter
let partition = List.partition
let fold = List.fold
let to_iter = List.to_iter
let to_set = Trm.Set.of_list
let sort = List.sort ~cmp:Trm.compare
let ppx x = List.pp "@ = " (Trm.ppx x)
let pp = ppx (fun _ -> None)
let pp_diff = List.pp_diff ~cmp:Trm.compare "@ = " Trm.pp
end
(** Equality Relations *)
(** see also [invariant] *)
@ -430,11 +472,13 @@ type t =
let classes r =
Subst.fold r.rep Trm.Map.empty ~f:(fun ~key:elt ~data:rep cls ->
if Trm.equal elt rep then cls
else Trm.Map.add_multi ~key:rep ~data:elt cls )
else
Trm.Map.update rep cls ~f:(fun cls0 ->
Cls.add elt (Option.value cls0 ~default:Cls.empty) ) )
let cls_of r e =
let e' = Subst.apply r.rep e in
Trm.Map.find e' (classes r) |> Option.value ~default:[e']
Trm.Map.find e' (classes r) |> Option.value ~default:(Cls.of_ e')
(** Pretty-printing *)
@ -461,22 +505,16 @@ let pp_diff fs (r, s) =
in
Format.fprintf fs "@[{@[<hv>%t%t@]}@]" pp_sat pp_rep
let ppx_cls x = List.pp "@ = " (Trm.ppx x)
let pp_cls = ppx_cls (fun _ -> None)
let pp_diff_cls = List.pp_diff ~cmp:Trm.compare "@ = " Trm.pp
let ppx_classes x fs clss =
List.pp "@ @<2>∧ "
(fun fs (rep, cls) ->
if not (List.is_empty cls) then
Format.fprintf fs "@[%a@ = %a@]" (Trm.ppx x) rep (ppx_cls x) cls )
if not (Cls.is_empty cls) then
Format.fprintf fs "@[%a@ = %a@]" (Trm.ppx x) rep (Cls.ppx x) cls )
fs
(Iter.to_list (Trm.Map.to_iter clss))
let pp_classes fs r = ppx_classes (fun _ -> None) fs (classes r)
let pp_diff_clss =
Trm.Map.pp_diff ~eq:(List.equal Trm.equal) Trm.pp pp_cls pp_diff_cls
let pp_diff_clss = Trm.Map.pp_diff ~eq:Cls.equal Trm.pp Cls.pp Cls.pp_diff
let pp fs r =
let clss = classes r in
@ -486,7 +524,7 @@ let pp fs r =
let ppx var_strength fs clss noneqs =
let without_anon_vars =
List.filter ~f:(fun e ->
Cls.filter ~f:(fun e ->
match Var.of_trm e with
| Some v -> Poly.(var_strength v <> Some `Anonymous)
| None -> true )
@ -494,8 +532,8 @@ let ppx var_strength fs clss noneqs =
let clss =
Trm.Map.fold clss Trm.Map.empty ~f:(fun ~key:rep ~data:cls m ->
let cls = without_anon_vars cls in
if not (List.is_empty cls) then
Trm.Map.add ~key:rep ~data:(List.sort ~cmp:Trm.compare cls) m
if not (Cls.is_empty cls) then
Trm.Map.add ~key:rep ~data:(Cls.sort cls) m
else m )
in
let first = Trm.Map.is_empty clss in
@ -690,16 +728,15 @@ let normalize r e = Term.map_trms ~f:(canon r) e
let class_of r e =
match Term.get_trm (normalize r e) with
| Some e' ->
List.map ~f:Term.of_trm (e' :: Trm.Map.find_multi e' (classes r))
Iter.to_list (Iter.map ~f:Term.of_trm (Cls.to_iter (cls_of r e')))
| None -> []
let diff_classes r s =
Trm.Map.filter_mapi (classes r) ~f:(fun ~key:rep ~data:cls ->
match
List.filter cls ~f:(fun exp -> not (implies s (Fml.eq rep exp)))
with
| [] -> None
| cls -> Some cls )
let cls' =
Cls.filter cls ~f:(fun exp -> not (implies s (Fml.eq rep exp)))
in
if Cls.is_empty cls' then None else Some cls' )
let ppx_diff var_strength fs parent_ctx fml ctx =
let fml' = canon_f ctx fml in
@ -729,7 +766,7 @@ let apply_subst wrt s r =
Trm.Map.fold (classes r) {r with rep= Subst.empty}
~f:(fun ~key:rep ~data:cls r ->
let rep' = Subst.subst_ s rep in
List.fold cls r ~f:(fun trm r ->
Cls.fold cls r ~f:(fun trm r ->
let trm' = Subst.subst_ s trm in
and_eq_ ~wrt trm' rep' r ) ) )
|> extract_xs
@ -762,7 +799,7 @@ let inter wrt r s =
else
let merge_mems rs r s =
Trm.Map.fold (classes s) rs ~f:(fun ~key:rep ~data:cls rs ->
List.fold cls
Cls.fold cls
([rep], rs)
~f:(fun exp (reps, rs) ->
match
@ -906,10 +943,10 @@ let solve_seq_eq ~wrt us e' f' subst =
let solve_interp_eq ~wrt us e' (cls, subst) =
[%Trace.call fun {pf} ->
pf "@ trm: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp e' pp_cls cls
pf "@ trm: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp e' Cls.pp cls
Subst.pp subst]
;
List.find_map cls ~f:(fun f ->
Iter.find_map (Cls.to_iter cls) ~f:(fun f ->
let f' = Subst.norm subst f in
match solve_seq_eq ~wrt us e' f' subst with
| Some subst -> Some subst
@ -925,32 +962,32 @@ let solve_interp_eq ~wrt us e' (cls, subst) =
[fv u us xs] *)
let rec solve_interp_eqs ~wrt us (cls, subst) =
[%Trace.call fun {pf} ->
pf "@ cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst]
pf "@ cls: @[%a@]@ subst: @[%a@]" Cls.pp cls Subst.pp subst]
;
let rec solve_interp_eqs_ cls' (cls, subst) =
match cls with
| [] -> (cls', subst)
| trm :: cls ->
match Cls.pop cls with
| None -> (cls', subst)
| Some (trm, cls) ->
let trm' = Subst.norm subst trm in
if is_interpreted trm' then
match solve_interp_eq ~wrt us trm' (cls, subst) with
| Some subst -> solve_interp_eqs_ cls' (cls, subst)
| None -> solve_interp_eqs_ (trm' :: cls') (cls, subst)
else solve_interp_eqs_ (trm' :: cls') (cls, subst)
| None -> solve_interp_eqs_ (Cls.add trm' cls') (cls, subst)
else solve_interp_eqs_ (Cls.add trm' cls') (cls, subst)
in
let cls', subst' = solve_interp_eqs_ [] (cls, subst) in
let cls', subst' = solve_interp_eqs_ Cls.empty (cls, subst) in
( if subst' != subst then solve_interp_eqs ~wrt us (cls', subst')
else (cls', subst') )
|>
[%Trace.retn fun {pf} (cls', subst') ->
pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff
pf "cls: @[%a@]@ subst: @[%a@]" Cls.pp_diff (cls, cls') Subst.pp_diff
(subst, subst')]
type cls_solve_state =
{ rep_us: Trm.t option (** rep, that is ito us, for class *)
; cls_us: Trm.t list (** cls that is ito us, or interpreted *)
; cls_us: Cls.t (** cls that is ito us, or interpreted *)
; rep_xs: Trm.t option (** rep, that is *not* ito us, for class *)
; cls_xs: Trm.t list (** cls that is *not* ito us *) }
; cls_xs: Cls.t (** cls that is *not* ito us *) }
let dom_trm e =
match (e : Trm.t) with
@ -964,47 +1001,49 @@ let dom_trm e =
[fv u us xs] *)
let solve_uninterp_eqs us (cls, subst) =
[%Trace.call fun {pf} ->
pf "@ cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst]
pf "@ cls: @[%a@]@ subst: @[%a@]" Cls.pp cls Subst.pp subst]
;
let compare e f =
[%compare: kind * Trm.t] (classify e, e) (classify f, f)
in
let {rep_us; cls_us; rep_xs; cls_xs} =
List.fold cls {rep_us= None; cls_us= []; rep_xs= None; cls_xs= []}
Cls.fold cls
{rep_us= None; cls_us= Cls.empty; rep_xs= None; cls_xs= Cls.empty}
~f:(fun trm ({rep_us; cls_us; rep_xs; cls_xs} as s) ->
if Var.Set.subset (Trm.fv trm) ~of_:us then
match rep_us with
| Some rep when compare rep trm <= 0 ->
{s with cls_us= trm :: cls_us}
| Some rep -> {s with rep_us= Some trm; cls_us= rep :: cls_us}
{s with cls_us= Cls.add trm cls_us}
| Some rep -> {s with rep_us= Some trm; cls_us= Cls.add rep cls_us}
| None -> {s with rep_us= Some trm}
else
match rep_xs with
| Some rep -> (
if compare rep trm <= 0 then
match dom_trm trm with
| Some trm -> {s with cls_xs= trm :: cls_xs}
| None -> {s with cls_us= trm :: cls_us}
| Some trm -> {s with cls_xs= Cls.add trm cls_xs}
| None -> {s with cls_us= Cls.add trm cls_us}
else
match dom_trm rep with
| Some rep ->
{s with rep_xs= Some trm; cls_xs= rep :: cls_xs}
| None -> {s with rep_xs= Some trm; cls_us= rep :: cls_us} )
{s with rep_xs= Some trm; cls_xs= Cls.add rep cls_xs}
| None ->
{s with rep_xs= Some trm; cls_us= Cls.add rep cls_us} )
| None -> {s with rep_xs= Some trm} )
in
( match rep_us with
| Some rep_us ->
let cls = rep_us :: cls_us in
let cls = Cls.add rep_us cls_us in
let cls, cls_xs =
match rep_xs with
| Some rep -> (
match dom_trm rep with
| Some rep -> (cls, rep :: cls_xs)
| None -> (rep :: cls, cls_xs) )
| Some rep -> (cls, Cls.add rep cls_xs)
| None -> (Cls.add rep cls, cls_xs) )
| None -> (cls, cls_xs)
in
let subst =
List.fold cls_xs subst ~f:(fun trm_xs subst ->
Cls.fold cls_xs subst ~f:(fun trm_xs subst ->
let trm_xs = Subst.subst_ subst trm_xs in
let rep_us = Subst.subst_ subst rep_us in
Subst.compose1 ~key:trm_xs ~data:rep_us subst )
@ -1013,16 +1052,16 @@ let solve_uninterp_eqs us (cls, subst) =
| None -> (
match rep_xs with
| Some rep_xs ->
let cls = rep_xs :: cls_us in
let cls = Cls.add rep_xs cls_us in
let subst =
List.fold cls_xs subst ~f:(fun trm_xs subst ->
Cls.fold cls_xs subst ~f:(fun trm_xs subst ->
Subst.compose1 ~key:trm_xs ~data:rep_xs subst )
in
(cls, subst)
| None -> (cls, subst) ) )
|>
[%Trace.retn fun {pf} (cls', subst') ->
pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff
pf "cls: @[%a@]@ subst: @[%a@]" Cls.pp_diff (cls, cls') Subst.pp_diff
(subst, subst') ;
subst_invariant us subst subst']
@ -1033,20 +1072,20 @@ let solve_uninterp_eqs us (cls, subst) =
let solve_class ~wrt us us_xs ~key:rep ~data:cls (classes, subst) =
let classes0 = classes in
[%Trace.call fun {pf} ->
pf "@ rep: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp rep pp_cls cls
pf "@ rep: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp rep Cls.pp cls
Subst.pp subst]
;
let cls, cls_not_ito_us_xs =
List.partition
Cls.partition
~f:(fun e -> Var.Set.subset (Trm.fv e) ~of_:us_xs)
(rep :: cls)
(Cls.add rep cls)
in
let cls, subst = solve_interp_eqs ~wrt us (cls, subst) in
let cls, subst = solve_uninterp_eqs us (cls, subst) in
let cls = List.rev_append cls_not_ito_us_xs cls in
let cls = List.remove ~eq:Trm.equal (Subst.norm subst rep) cls in
let cls = Cls.union cls_not_ito_us_xs cls in
let cls = Cls.remove (Subst.norm subst rep) cls in
let classes =
if List.is_empty cls then Trm.Map.remove rep classes
if Cls.is_empty cls then Trm.Map.remove rep classes
else Trm.Map.add ~key:rep ~data:cls classes
in
(classes, subst)
@ -1092,7 +1131,7 @@ let solve_concat_extracts r us x (classes, subst, us_xs) =
List.filter_map (solve_concat_extracts_eq r x) ~f:(fun rev_extracts ->
Iter.fold_opt (Iter.of_list rev_extracts) [] ~f:(fun e suffix ->
let+ rep_ito_us =
List.fold (cls_of r e) None ~f:(fun trm rep_ito_us ->
Cls.fold (cls_of r e) None ~f:(fun trm rep_ito_us ->
match rep_ito_us with
| Some rep when Trm.compare rep trm <= 0 -> rep_ito_us
| _ when Var.Set.subset (Trm.fv trm) ~of_:us -> Some trm
@ -1208,12 +1247,13 @@ let trim ks r =
in
let clss =
Trm.Set.fold reps (classes r) ~f:(fun rep clss ->
Trm.Map.add_multi ~key:rep ~data:rep clss )
Trm.Map.update rep clss ~f:(fun cls0 ->
Cls.add rep (Option.value cls0 ~default:Cls.empty) ) )
in
(* trim classes to those that intersect kills *)
let clss =
Trm.Map.filter_mapi clss ~f:(fun ~key:_ ~data:cls ->
let cls = Trm.Set.of_list cls in
let cls = Cls.to_set cls in
if Trm.Set.disjoint kills cls then None else Some cls )
in
(* enumerate affected classes and update solution subst *)

@ -42,6 +42,7 @@ module Set : sig
val t_of_sexp : Sexp.t -> t
val pp : t pp
val pp_diff : (t * t) pp
end
module Map : sig

Loading…
Cancel
Save