@ -12,7 +12,7 @@ open Exp
(* Solution Substitutions ================================================= *)
module Subst : sig
type t [@@ deriving compare , equal , sexp ]
type t = Trm . t Trm . Map . t [@@ deriving compare , equal , sexp ]
val pp : t pp
val pp_diff : ( t * t ) pp
@ -34,6 +34,7 @@ module Subst : sig
val extend : Trm . t -> t -> t option
val map_entries : f : ( Trm . t -> Trm . t ) -> t -> t
val to_iter : t -> ( Trm . t * Trm . t ) iter
val to_list : t -> ( Trm . t * Trm . t ) list
val partition_valid : Var . Set . t -> t -> t * Var . Set . t * t
(* direct representation manipulation *)
@ -58,6 +59,7 @@ end = struct
let iteri = Trm . Map . iteri
let for_alli = Trm . Map . for_alli
let to_iter = Trm . Map . to_iter
let to_list = Trm . Map . to_list
(* * look up a term in a substitution *)
let apply s a = Trm . Map . find a s | > Option . value ~ default : a
@ -92,7 +94,8 @@ end = struct
else (
assert (
Option . for_all ~ f : ( Trm . equal key ) ( Trm . Map . find key r )
| | fail " domains intersect: %a " Trm . pp key () ) ;
| | fail " domains intersect: %a ↦ %a in %a " Trm . pp key Trm . pp data
pp r () ) ;
let s = Trm . Map . singleton key data in
let r' = Trm . Map . map_endo ~ f : ( norm s ) r in
Trm . Map . add ~ key ~ data r' )
@ -119,7 +122,7 @@ end = struct
if Trm . equal data' data then s else Trm . Map . add ~ key ~ data : data' s
else
let s = Trm . Map . remove key s in
match ( key : Trm . t ) with
match ( key ' : Trm . t ) with
| Z _ | Q _ -> s
| _ -> Trm . Map . add_exn ~ key : key' ~ data : data' s )
@ -178,7 +181,7 @@ end
(* Equality classes ======================================================= *)
module Cls : sig
type t [ @@ deriving equal]
type t [ @@ deriving compare, equal, sexp ]
val empty : t
val of_ : Trm . t -> t
@ -190,14 +193,16 @@ module Cls : sig
val filter : t -> f : ( Trm . t -> bool ) -> t
val partition : t -> f : ( Trm . t -> bool ) -> t * t
val fold : t -> ' s -> f : ( Trm . t -> ' s -> ' s ) -> ' s
val map : t -> f : ( Trm . t -> Trm . t ) -> t
val to_iter : t -> Trm . t iter
val to_set : t -> Trm . Set . t
val sort : t -> t
val of_set : Trm . Set . t -> t
val ppx : Trm . Var . strength -> t pp
val pp : t pp
val pp_raw : t pp
val pp_diff : ( t * t ) pp
end = struct
type t = Trm . t list [ @@ deriving equal]
type t = Trm . t list [ @@ deriving compare, equal, sexp ]
let empty = []
let of_ e = [ e ]
@ -209,11 +214,19 @@ end = struct
let filter = List . filter
let partition = List . partition
let fold = List . fold
let map = List . map_endo
let to_iter = List . to_iter
let to_set = Trm . Set . of_list
let sort = List . sort ~ cmp : Trm . compare
let ppx x = List . pp " @ = " ( Trm . ppx x )
let of_set s = Iter . to_list ( Trm . Set . to_iter s )
let ppx x fs es =
List . pp " @ = " ( Trm . ppx x ) fs ( List . sort_uniq ~ cmp : Trm . compare es )
let pp = ppx ( fun _ -> None )
let pp_raw fs es =
Trm . Set . pp_full ~ pre : " {@[ " ~ suf : " @]} " ~ sep : " ,@ " Trm . pp fs ( to_set es )
let pp_diff = List . pp_diff ~ cmp : Trm . compare " @ = " Trm . pp
end
@ -222,49 +235,41 @@ end
(* * see also [invariant] *)
type t =
{ xs : Var . Set . t
(* * existential variables that did not appear in input equation s *)
(* * existential variables that did not appear in input formula s *)
; sat : bool (* * [false] only if constraints are inconsistent *)
; rep : Subst . t
(* * functional set of oriented equations: map [a] to [a'],
indicating that [ a = a' ] holds , and that [ a' ] is the
' rep ( resentative ) ' of [ a ] * )
; cls : Cls . t Trm . Map . t
(* * map each representative to the set of terms in its class *)
; pnd : ( Trm . t * Trm . t ) list
(* * pending equations to add ( once invariants are reestablished ) *)
}
[ @@ deriving compare , equal , sexp ]
let classes r =
Subst . fold r . rep Trm . Map . empty ~ f : ( fun ~ key : elt ~ data : rep cls ->
if Trm . equal elt rep then cls
else
Trm . Map . update rep cls ~ f : ( fun cls0 ->
Cls . add elt ( Option . value cls0 ~ default : Cls . empty ) ) )
let cls_of r e =
let e' = Subst . apply r . rep e in
Trm . Map . find e' ( classes r ) | > Option . value ~ default : ( Cls . of_ e' )
(* Pretty-printing ======================================================== *)
let pp_eq fs ( e , f ) = Format . fprintf fs " @[%a = %a@] " Trm . pp e Trm . pp f
let pp_pnd = List . pp " ;@ " pp_eq
let pp_raw fs { sat ; rep ; pnd} =
let pp_raw fs { sat ; rep ; cls ; pnd } =
let pp_alist pp_k pp_v fs alist =
let pp_assoc fs ( k , v ) =
Format . fprintf fs " [@[%a@ @<2>↦ %a@]] " pp_k k pp_v ( k , v )
in
Format . fprintf fs " [@[<hv>%a@]] " ( List . pp " ;@ " pp_assoc ) alist
in
let pp_term_v fs ( k , v ) = if not ( Trm . equal k v ) then Trm . pp fs v in
let pp_trm_v fs ( k , v ) = if not ( Trm . equal k v ) then Trm . pp fs v in
let pp_cls_v fs ( _ , cls ) = Cls . pp_raw fs cls in
let pp_pnd fs pnd =
if not ( List . is_empty pnd ) then
Format . fprintf fs " ;@ pnd= @[%a@] " pp_pnd pnd
Format . fprintf fs " ;@ pnd= @[%a@] " ( List . pp " ;@ " pp_eq ) pnd
in
Format . fprintf fs " @[{@[<hv>sat= %b;@ rep= %a%a@]}@] " sat
( pp_alist Trm . pp pp_term_v )
( Iter . to_list ( Subst . to_iter rep ) )
pp_pnd pnd
Format . fprintf fs " @[{@[<hv>sat= %b;@ rep= %a;@ cls= %a%a@]}@] " sat
( pp_alist Trm . pp pp_trm_v )
( Subst . to_list rep )
( pp_alist Trm . pp pp_cls_v )
( Trm . Map . to_list cls ) pp_pnd pnd
let pp_diff fs ( r , s ) =
let pp_sat fs =
@ -275,33 +280,35 @@ let pp_diff fs (r, s) =
if not ( Subst . is_empty r . rep ) then
Format . fprintf fs " rep= %a;@ " Subst . pp_diff ( r . rep , s . rep )
in
let pp_cls fs =
if not ( Trm . Map . equal Cls . equal r . cls s . cls ) then
Format . fprintf fs " cls= %a;@ "
( Trm . Map . pp_diff ~ eq : Cls . equal Trm . pp Cls . pp_raw Cls . pp_diff )
( r . cls , s . cls )
in
let pp_pnd fs =
Format . fprintf fs " pnd= @[%a@] "
( List . pp_diff ~ cmp : [ % compare : Trm . t * Trm . t ] " ;@ " pp_eq )
( r . pnd , s . pnd )
List . pp_diff ~ cmp : [ % compare : Trm . t * Trm . t ] ~ pre : " pnd= @[ " ~ suf : " @] "
" ;@ " pp_eq fs ( r . pnd , s . pnd )
in
Format . fprintf fs " @[{@[<hv>%t%t%t@]}@] " pp_sat pp_rep pp_pnd
Format . fprintf fs " @[{@[<hv>%t%t%t %t @]}@]" pp_sat pp_rep pp_cls pp_pnd
let ppx_classes x fs clss =
List . pp " @ @<2>∧ "
( fun fs ( rep , cls ) ->
if not ( Cls . is_empty cls ) then
Format . fprintf fs " @[%a@ = %a@] " ( Trm . ppx x ) rep ( Cls . ppx x ) cls )
fs
( Iter . to_list ( Trm . Map . to_iter clss ) )
fs ( Trm . Map . to_list clss )
let pp_classes fs r = ppx_classes ( fun _ -> None ) fs ( classes r )
let pp_diff_clss = Trm . Map . pp_diff ~ eq : Cls . equal Trm . pp Cls . pp Cls . pp_diff
let pp_classes fs r = ppx_classes ( fun _ -> None ) fs r . cls
let pp fs r =
let clss = classes r in
if Trm . Map . is_empty clss then
if Trm . Map . is_empty r . cls then
Format . fprintf fs ( if r . sat then " tt " else " ff " )
else pp x_classes ( fun _ -> None ) fs clss
else pp _classes fs r
let ppx var_strength fs clss noneqs =
let without_anon_vars =
Cls . filter ~ f : ( fun e ->
let without_anon_vars es =
Cls . filter es ~ f : ( fun e ->
match Var . of_trm e with
| Some v -> Poly . ( var_strength v < > Some ` Anonymous )
| None -> true )
@ -309,8 +316,7 @@ let ppx var_strength fs clss noneqs =
let clss =
Trm . Map . fold clss Trm . Map . empty ~ f : ( fun ~ key : rep ~ data : cls m ->
let cls = without_anon_vars cls in
if not ( Cls . is_empty cls ) then
Trm . Map . add ~ key : rep ~ data : ( Cls . sort cls ) m
if not ( Cls . is_empty cls ) then Trm . Map . add ~ key : rep ~ data : cls m
else m )
in
let first = Trm . Map . is_empty clss in
@ -321,6 +327,8 @@ let ppx var_strength fs clss noneqs =
" @ @<2>∧ " ( Fml . ppx var_strength ) fs noneqs ~ suf : " @] " ;
first && List . is_empty noneqs
let pp_diff_cls = Trm . Map . pp_diff ~ eq : Cls . equal Trm . pp Cls . pp Cls . pp_diff
(* Basic queries ========================================================== *)
(* * test membership in carrier *)
@ -355,12 +363,28 @@ let pre_invariant r =
let rep' = Subst . norm r . rep rep in
Trm . equal rep rep'
| | fail " not idempotent: %a != %a in@ %a " Trm . pp rep Trm . pp rep'
Subst . pp r . rep () ) )
Subst . pp r . rep () ) ;
(* every term is in the class of its rep *)
assert (
Trm . equal trm rep
| | Trm . Set . mem trm
( Cls . to_set
( Trm . Map . find rep r . cls | > Option . value ~ default : Cls . empty ) )
| | fail " %a not in cls of %a = {%a}@ %a " Trm . pp trm Trm . pp rep
Cls . pp
( Trm . Map . find rep r . cls | > Option . value ~ default : Cls . empty )
pp_raw r () ) ) ;
Trm . Map . iteri r . cls ~ f : ( fun ~ key : rep ~ data : cls ->
(* each class does not include its rep *)
assert ( not ( Trm . Set . mem rep ( Cls . to_set cls ) ) ) ;
(* representative of every element of [rep]'s class is [rep] *)
Iter . iter ( Cls . to_iter cls ) ~ f : ( fun elt ->
assert ( Option . exists ~ f : ( Trm . equal rep ) ( Subst . find elt r . rep ) ) ) )
let invariant r =
let @ () = Invariant . invariant [ % here ] r [ % sexp_of : t ] in
pre_invariant r ;
assert ( List . is_empty r . pnd ) ;
pre_invariant r ;
assert (
( not r . sat )
| | Subst . for_alli r . rep ~ f : ( fun ~ key : a ~ data : a' ->
@ -371,21 +395,67 @@ let invariant r =
| | fail " not congruent %a@ %a@ in@ %a " Trm . pp a Trm . pp b pp r
() ) ) )
(* Representation helpers ================================================= *)
(* Extending the carrier ================================================== *)
let rec extend_ a s =
[ % trace ]
~ call : ( fun { pf } -> pf " @ %a@ %a " Trm . pp a Subst . pp s )
~ retn : ( fun { pf } s' -> pf " %a " Subst . pp_diff ( s , s' ) )
@@ fun () ->
match ( a : Trm . t ) with
| Z _ | Q _ -> s
| _ -> (
if Theory . is_interpreted a then Iter . fold ~ f : extend_ ( Trm . trms a ) s
else
(* add uninterpreted terms *)
match Subst . extend a s with
(* and their subterms if newly added *)
| Some s -> Iter . fold ~ f : extend_ ( Trm . trms a ) s
| None -> s )
let add_to_pnd a a' x =
if Trm . equal a a' then x else { x with pnd = ( a , a' ) :: x . pnd }
(* * add a term to the carrier *)
let extend a x =
[ % trace ]
~ call : ( fun { pf } -> pf " @ %a@ %a " Trm . pp a pp x )
~ retn : ( fun { pf } x' ->
pf " %a " pp_diff ( x , x' ) ;
pre_invariant x' )
@@ fun () ->
let rep = extend_ a x . rep in
if rep = = x . rep then x else { x with rep }
(* Propagation ============================================================ *)
let propagate1 ( trm , rep ) x =
(* * add a=a' to x using a' as the representative *)
let propagate1 ( a , a' ) x =
[ % trace ]
~ call : ( fun { pf } ->
pf " @ @[%a ↦ %a@]@ %a " Trm . pp trm Trm . pp rep pp_raw x )
~ call : ( fun { pf } -> pf " @ @[%a ↦ %a@]@ %a " Trm . pp a Trm . pp a' pp_raw x )
~ retn : ( fun { pf } -> pf " %a " pp_raw )
@@ fun () ->
let rep = Subst . compose1 ~ key : trm ~ data : rep x . rep in
{ x with rep }
(* pending equations need not be between terms in the carrier *)
let x = extend a ( extend a' x ) in
let s = Trm . Map . singleton a a' in
Trm . Map . fold x . rep x ~ f : ( fun ~ key : _ ~ data : b0' x ->
let b' = Subst . norm s b0' in
if b' = = b0' then x
else
let b0'_cls , cls =
Trm . Map . find_and_remove b0' x . cls
| > Option . value ~ default : ( Cls . empty , x . cls )
in
let b0'_cls , pnd =
if Theory . is_interpreted b0' then ( b0'_cls , ( b0' , b' ) :: x . pnd )
else ( Cls . add b0' b0'_cls , x . pnd )
in
let rep =
Cls . fold b0'_cls x . rep ~ f : ( fun c rep ->
Trm . Map . add ~ key : c ~ data : b' rep )
in
let cls =
Trm . Map . update b' cls ~ f : ( fun b'_cls ->
Cls . union b0'_cls ( Option . value b'_cls ~ default : Cls . empty ) )
in
{ x with rep ; cls ; pnd } )
let solve ~ wrt ~ xs d e pending =
[ % trace ]
@ -416,7 +486,8 @@ let rec propagate ~wrt x =
let empty =
let rep = Subst . empty in
{ xs = Var . Set . empty ; sat = true ; rep ; pnd = [] } | > check invariant
{ xs = Var . Set . empty ; sat = true ; rep ; cls = Trm . Map . empty ; pnd = [] }
| > check invariant
let unsat = { empty with sat = false }
@ -453,30 +524,15 @@ let canon_f r b =
~ retn : ( fun { pf } -> pf " %a " Fml . pp )
@@ fun () -> Fml . map_trms ~ f : ( canon r ) b
let rec extend_ a r =
match ( a : Trm . t ) with
| Z _ | Q _ -> r
| _ -> (
if Theory . is_interpreted a then Iter . fold ~ f : extend_ ( Trm . trms a ) r
else
(* add uninterpreted terms *)
match Subst . extend a r with
(* and their subterms if newly added *)
| Some r -> Iter . fold ~ f : extend_ ( Trm . trms a ) r
| None -> r )
(* * add a term to the carrier *)
let extend a r =
let rep = extend_ a r . rep in
if rep = = r . rep then r else { r with rep } | > check pre_invariant
let merge ~ wrt a b x =
[ % trace ]
~ call : ( fun { pf } -> pf " @ %a@ %a@ %a " Trm . pp a Trm . pp b pp x )
~ retn : ( fun { pf } x' ->
pf " %a " pp_diff ( x , x' ) ;
pre_invariant x' )
@@ fun () -> propagate ~ wrt ( add_to_pnd a b x )
@@ fun () ->
let x = { x with pnd = ( a , b ) :: x . pnd } in
propagate ~ wrt x
(* * find an unproved equation between congruent terms *)
let find_missing r =
@ -510,21 +566,25 @@ let close ~wrt r =
pf " %a " pp_diff ( r , r' ) ;
invariant r' ]
let and_eq_ ~ wrt a b r =
if not r . sat then r
let and_eq_ ~ wrt a b x =
[ % trace ]
~ call : ( fun { pf } -> pf " @ @[%a = %a@]@ %a " Trm . pp a Trm . pp b pp x )
~ retn : ( fun { pf } x' ->
pf " %a " pp_diff ( x , x' ) ;
invariant x' )
@@ fun () ->
if not x . sat then x
else
let r0 = r in
let a' = canon r a in
let b' = canon r b in
let r = extend a' r in
let r = extend b' r in
if Trm . equal a' b' then r
let x0 = x in
let a' = canon x a in
let b' = canon x b in
if Trm . equal a' b' then extend a' ( extend b' x )
else
let r = merge ~ wrt a' b' r in
let x = merge ~ wrt a' b' x in
match ( a , b ) with
| ( Var _ as v ) , _ when not ( in_car r0 v ) -> r
| _ , ( Var _ as v ) when not ( in_car r0 v ) -> r
| _ -> close ~ wrt r
| ( Var _ as v ) , _ when not ( in_car x0 v ) -> x
| _ , ( Var _ as v ) when not ( in_car x0 v ) -> x
| _ -> close ~ wrt x
let extract_xs r = ( r . xs , { r with xs = Var . Set . empty } )
@ -545,6 +605,10 @@ let implies r b =
let refutes r b = Fml . equal Fml . ff ( canon_f r b )
let normalize r e = Term . map_trms ~ f : ( canon r ) e
let cls_of r e =
let e' = Subst . apply r . rep e in
Trm . Map . find e' r . cls | > Option . value ~ default : ( Cls . of_ e' )
let class_of r e =
match Term . get_trm ( normalize r e ) with
| Some e' ->
@ -552,7 +616,7 @@ let class_of r e =
| None -> []
let diff_classes r s =
Trm . Map . filter_mapi ( classes r ) ~ f : ( fun ~ key : rep ~ data : cls ->
Trm . Map . filter_mapi r . cls ~ f : ( fun ~ key : rep ~ data : cls ->
let cls' =
Cls . filter cls ~ f : ( fun exp -> not ( implies s ( Fml . eq rep exp ) ) )
in
@ -584,7 +648,7 @@ let apply_subst wrt s r =
;
( if Subst . is_empty s then r
else
Trm . Map . fold ( classes r ) { r with rep = Subst . empty }
Trm . Map . fold r . cls { r with rep = Subst . empty ; cls = Trm . Map . empty }
~ f : ( fun ~ key : rep ~ data : cls r ->
let rep' = Subst . subst_ s rep in
Cls . fold cls r ~ f : ( fun trm r ->
@ -619,7 +683,7 @@ let inter wrt r s =
else if not r . sat then s
else
let merge_mems rs r s =
Trm . Map . fold ( classes s ) rs ~ f : ( fun ~ key : rep ~ data : cls rs ->
Trm . Map . fold s . cls rs ~ f : ( fun ~ key : rep ~ data : cls rs ->
Cls . fold cls
( [ rep ] , rs )
~ f : ( fun exp ( reps , rs ) ->
@ -673,17 +737,23 @@ let dnf f =
let bot = Iter . empty in
Fml . fold_dnf ~ meet1 ~ join1 ~ top ~ bot f
let rename r sub =
[ % Trace . call fun { pf } -> pf " @ @[%a@]@ %a " Var . Subst . pp sub pp r ]
;
let rep =
Subst . map_entries ~ f : ( Trm . map_vars ~ f : ( Var . Subst . apply sub ) ) r . rep
let rename x sub =
[ % trace ]
~ call : ( fun { pf } -> pf " @ @[%a@]@ %a " Var . Subst . pp sub pp x )
~ retn : ( fun { pf } x' ->
pf " %a " pp_diff ( x , x' ) ;
invariant x' )
@@ fun () ->
let apply_sub = Trm . map_vars ~ f : ( Var . Subst . apply sub ) in
let rep = Subst . map_entries ~ f : apply_sub x . rep in
let cls =
Trm . Map . fold x . cls x . cls ~ f : ( fun ~ key : a0' ~ data : a0'_cls cls ->
let a' = apply_sub a0' in
let a'_cls = Cls . map ~ f : apply_sub a0'_cls in
Trm . Map . add ~ key : a' ~ data : a'_cls
( if a' = = a0' then cls else Trm . Map . remove a0' cls ) )
in
( if rep = = r . rep then r else { r with rep } )
| >
[ % Trace . retn fun { pf } r' ->
pf " %a " pp_diff ( r , r' ) ;
invariant r' ]
if rep = = x . rep && cls = = x . cls then x else { x with rep ; cls }
let trms r =
Iter . flat_map ~ f : ( fun ( k , v ) -> Iter . doubleton k v ) ( Subst . to_iter r . rep )
@ -904,6 +974,8 @@ let solve_uninterp_eqs us (cls, subst) =
let cls = Cls . add rep_xs cls_us in
let subst =
Cls . fold cls_xs subst ~ f : ( fun trm_xs subst ->
let trm_xs = Subst . subst_ subst trm_xs in
let rep_xs = Subst . subst_ subst rep_xs in
Subst . compose1 ~ key : trm_xs ~ data : rep_xs subst )
in
( cls , subst )
@ -941,7 +1013,7 @@ let solve_class us us_xs ~key:rep ~data:cls (classes, subst) =
| >
[ % Trace . retn fun { pf } ( classes' , subst' ) ->
pf " subst: @[%a@]@ classes: @[%a@] " Subst . pp_diff ( subst , subst' )
pp_diff_cls s ( classes0 , classes' ) ]
pp_diff_cls ( classes0 , classes' ) ]
let solve_concat_extracts_eq r x =
[ % Trace . call fun { pf } -> pf " @ %a@ %a " Trm . pp x pp r ]
@ -1022,7 +1094,7 @@ let solve_classes r xs (classes, subst, us) =
| >
[ % Trace . retn fun { pf } ( classes' , subst' , _ ) ->
pf " subst: @[%a@]@ classes: @[%a@] " Subst . pp_diff ( subst , subst' )
pp_diff_cls s ( classes , classes' ) ]
pp_diff_cls ( classes , classes' ) ]
let pp_vss fs vss =
Format . fprintf fs " [@[%a@]] "
@ -1042,7 +1114,7 @@ let solve_for_vars vss r =
let us , vss =
match vss with us :: vss -> ( us , vss ) | [] -> ( Var . Set . empty , vss )
in
List . fold ~ f : ( solve_classes r ) vss ( cla sses r , Subst . empty , us ) | > snd3
List . fold ~ f : ( solve_classes r ) vss ( r. cls, Subst . empty , us ) | > snd3
| >
[ % Trace . retn fun { pf } subst ->
pf " %a " Subst . pp subst ;
@ -1079,57 +1151,73 @@ let trivial vs r =
Var . Set . add v ks
| _ -> ks )
let trim ks r =
let trim ks x =
[ % trace ]
~ call : ( fun { pf } -> pf " @ %a@ %a " Var . Set . pp_xs ks pp_raw r )
~ retn : ( fun { pf } r' ->
pf " %a " pp_raw r' ;
assert ( Var . Set . disjoint ks ( fv r' ) ) )
~ call : ( fun { pf } -> pf " @ %a@ %a " Var . Set . pp_xs ks pp_raw x )
~ retn : ( fun { pf } x' ->
pf " %a " pp_raw x' ;
invariant x' ;
assert ( Var . Set . disjoint ks ( fv x' ) ) )
@@ fun () ->
let kills = Trm . Set . of_iter ( Iter . map ~ f : Trm . var ( Var . Set . to_iter ks ) ) in
(* compute classes including reps *)
(* expand classes to include reps *)
let reps =
Subst . fold r . rep Trm . Set . empty ~ f : ( fun ~ key : _ ~ data : rep reps ->
Subst . fold x . rep Trm . Set . empty ~ f : ( fun ~ key : _ ~ data : rep reps ->
Trm . Set . add rep reps )
in
let clss =
Trm . Set . fold reps ( classes r ) ~ f : ( fun rep clss ->
Trm . Set . fold reps x . cls ~ f : ( fun rep clss ->
Trm . Map . update rep clss ~ f : ( fun cls0 ->
Cls . add rep ( Option . value cls0 ~ default : Cls . empty ) ) )
in
(* trim classes to those that intersect kills *)
let clss =
Trm . Map . filter_mapi clss ~ f : ( fun ~ key : _ ~ data : cls ->
let cls = Cls . to_set cls in
if Trm . Set . disjoint kills cls then None else Some cls )
in
(* enumerate affected classes and update solution subst *)
let rep =
Trm . Map . fold clss r . rep ~ f : ( fun ~ key : rep ~ data : cls s ->
(* remove mappings for non-rep class elements to kill *)
let drop = Trm . Set . inter cls kills in
let s = Trm . Set . fold ~ f : Subst . remove drop s in
if not ( Trm . Set . mem rep kills ) then s
(* enumerate expanded classes and update solution subst *)
let kills = Trm . Set . of_vars ks in
Trm . Map . fold clss x ~ f : ( fun ~ key : a' ~ data : ecls x ->
(* remove mappings for non-rep class elements to kill *)
let keep , drop = Trm . Set . diff_inter ( Cls . to_set ecls ) kills in
if Trm . Set . is_empty drop then x
else
let rep = Trm . Set . fold ~ f : Subst . remove drop x . rep in
let x = { x with rep } in
(* new class is keepers without rep *)
let keep' = Trm . Set . remove a' keep in
let ecls = Cls . of_set keep' in
if keep' != keep then
(* a' is to be kept: continue to use it as rep *)
let cls =
if Cls . is_empty ecls then Trm . Map . remove a' x . cls
else Trm . Map . add ~ key : a' ~ data : ecls x . cls
in
{ x with cls }
else
(* if rep is to be removed, choose new one from the keepers *)
let keep = Trm . Set . diff cls drop in
(* a' is to be removed: choose new rep from the keepers *)
let cls = Trm . Map . remove a' x . cls in
let x = { x with cls } in
match
Trm . Set . reduce keep ~ f : ( fun x y ->
if Theory . prefer x y < 0 then x else y )
with
| Some rep ' ->
| Some b ' ->
(* add mappings from each keeper to the new representative *)
Trm . Set . fold keep s ~ f : ( fun elt s ->
Subst . add ~ key : elt ~ data : rep' s )
| None -> s )
in
{ r with rep }
let rep =
Trm . Set . fold keep x . rep ~ f : ( fun elt rep ->
Subst . add ~ key : elt ~ data : b' rep )
in
(* add trimmed class to new rep *)
let cls =
if Cls . is_empty ecls then x . cls
else Trm . Map . add ~ key : b' ~ data : ecls x . cls
in
{ x with rep ; cls }
| None ->
(* entire class removed *)
x )
let apply_and_elim ~ wrt xs s r =
[ % trace ]
~ call : ( fun { pf } -> pf " @ %a%a@ %a " Var . Set . pp_xs xs Subst . pp s pp_raw r )
~ retn : ( fun { pf } ( zs , r' , ks ) ->
pf " %a@ %a@ %a " Var . Set . pp_xs zs pp_raw r' Var . Set . pp_xs ks ;
invariant r' ;
assert ( Var . Set . subset ks ~ of_ : xs ) ;
assert ( Var . Set . disjoint ks ( fv r' ) ) )
@@ fun () ->