@ -56,25 +56,26 @@ module Subst : sig
val to_alist : t -> ( Term . t * Term . t ) list
val to_alist : t -> ( Term . t * Term . t ) list
val partition_valid : Var . Set . t -> t -> t * Var . Set . t * t
val partition_valid : Var . Set . t -> t -> t * Var . Set . t * t
end = struct
end = struct
type t = Term . t Term . Map . t [ @@ deriving compare , equal , sexp ]
type t = Term . t Term . Map . t [ @@ deriving compare , equal , sexp _of ]
let pp = Map . pp Term . pp Term . pp
let t_of_sexp = Term . Map . t_of_sexp Term . t_of_sexp Term . t_of_sexp
let pp = Term . Map . pp Term . pp Term . pp
let pp_diff =
let pp_diff =
Map. pp_diff ~ data_equal : Term . equal Term . pp Term . pp Term . pp_diff
Term. Map. pp_diff ~ data_equal : Term . equal Term . pp Term . pp Term . pp_diff
let empty = Term . Map . empty
let empty = Term . Map . empty
let is_empty = Map. is_empty
let is_empty = Term. Map. is_empty
let length = Map. length
let length = Term. Map. length
let mem = Map. mem
let mem = Term. Map. mem
let find = Map. find
let find = Term. Map. find
let fold = Map. fold
let fold = Term. Map. fold
let iteri = Map. iteri
let iteri = Term. Map. iteri
let for_alli = Map. for_alli
let for_alli = Term. Map. for_alli
let to_alist = Map. to_alist ~ key_order : ` Increasing
let to_alist = Term. Map. to_alist
(* * look up a term in a substitution *)
(* * look up a term in a substitution *)
let apply s a = Map. find s a | > Option . value ~ default : a
let apply s a = Term. Map. find s a | > Option . value ~ default : a
let rec subst s a = apply s ( Term . map ~ f : ( subst s ) a )
let rec subst s a = apply s ( Term . map ~ f : ( subst s ) a )
@ -87,21 +88,21 @@ end = struct
(* * compose two substitutions *)
(* * compose two substitutions *)
let compose r s =
let compose r s =
let r' = Map. map ~ f : ( norm s ) r in
let r' = Term. Map. map ~ f : ( norm s ) r in
Map. merge_skewed r' s ~ combine : ( fun ~ key v1 v2 ->
Term. Map. merge_skewed r' s ~ combine : ( fun ~ key v1 v2 ->
if Term . equal v1 v2 then v1
if Term . equal v1 v2 then v1
else fail " domains intersect: %a " Term . pp key () )
else fail " domains intersect: %a " Term . pp key () )
(* * compose a substitution with a mapping *)
(* * compose a substitution with a mapping *)
let compose1 ~ key ~ data s =
let compose1 ~ key ~ data s =
if Term . equal key data then s
if Term . equal key data then s
else compose s ( Map. set Term . Map . empty ~ key ~ data )
else compose s ( Term. Map. set Term . Map . empty ~ key ~ data )
(* * add an identity entry if the term is not already present *)
(* * add an identity entry if the term is not already present *)
let extend e s =
let extend e s =
let exception Found in
let exception Found in
match
match
Map. update s e ~ f : ( function
Term. Map. update s e ~ f : ( function
| Some _ -> Exn . raise_without_backtrace Found
| Some _ -> Exn . raise_without_backtrace Found
| None -> e )
| None -> e )
with
with
@ -112,12 +113,14 @@ end = struct
[ f ] is injective and for any set of terms [ E ] , [ f \ [ E \ ] ] is disjoint
[ f ] is injective and for any set of terms [ E ] , [ f \ [ E \ ] ] is disjoint
from [ E ] * )
from [ E ] * )
let map_entries ~ f s =
let map_entries ~ f s =
Map. fold s ~ init : s ~ f : ( fun ~ key ~ data s ->
Term. Map. fold s ~ init : s ~ f : ( fun ~ key ~ data s ->
let key' = f key in
let key' = f key in
let data' = f data in
let data' = f data in
if Term . equal key' key then
if Term . equal key' key then
if Term . equal data' data then s else Map . set s ~ key ~ data : data'
if Term . equal data' data then s
else Map . remove s key | > Map . add_exn ~ key : key' ~ data : data' )
else Term . Map . set s ~ key ~ data : data'
else Term . Map . remove s key | > Term . Map . add_exn ~ key : key' ~ data : data'
)
(* * Holds only if [true ⊢ ∃xs. e=f]. Clients assume
(* * Holds only if [true ⊢ ∃xs. e=f]. Clients assume
[ not ( is_valid_eq xs e f ) ] implies [ not ( is_valid_eq ys e f ) ] for
[ not ( is_valid_eq xs e f ) ] implies [ not ( is_valid_eq ys e f ) ] for
@ -141,12 +144,12 @@ end = struct
valid , so loop until no change . * )
valid , so loop until no change . * )
let rec partition_valid_ t ks s =
let rec partition_valid_ t ks s =
let t' , ks' , s' =
let t' , ks' , s' =
Map. fold s ~ init : ( t , ks , s ) ~ f : ( fun ~ key ~ data ( t , ks , s ) ->
Term. Map. fold s ~ init : ( t , ks , s ) ~ f : ( fun ~ key ~ data ( t , ks , s ) ->
if is_valid_eq ks key data then ( t , ks , s )
if is_valid_eq ks key data then ( t , ks , s )
else
else
let t = Map. set ~ key ~ data t
let t = Term. Map. set ~ key ~ data t
and ks = Set . diff ks ( Set . union ( Term . fv key ) ( Term . fv data ) )
and ks = Set . diff ks ( Set . union ( Term . fv key ) ( Term . fv data ) )
and s = Map. remove s key in
and s = Term. Map. remove s key in
( t , ks , s ) )
( t , ks , s ) )
in
in
if s' != s then partition_valid_ t' ks' s' else ( t' , ks' , s' )
if s' != s then partition_valid_ t' ks' s' else ( t' , ks' , s' )
@ -327,7 +330,7 @@ type t =
let classes r =
let classes r =
let add key data cls =
let add key data cls =
if Term . equal key data then cls
if Term . equal key data then cls
else Map. add_multi cls ~ key : data ~ data : key
else Term. Map. add_multi cls ~ key : data ~ data : key
in
in
Subst . fold r . rep ~ init : Term . Map . empty ~ f : ( fun ~ key ~ data cls ->
Subst . fold r . rep ~ init : Term . Map . empty ~ f : ( fun ~ key ~ data cls ->
match classify key with
match classify key with
@ -337,7 +340,7 @@ let classes r =
let cls_of r e =
let cls_of r e =
let e' = Subst . apply r . rep e in
let e' = Subst . apply r . rep e in
Map. find ( classes r ) e' | > Option . value ~ default : [ e' ]
Term. Map. find ( classes r ) e' | > Option . value ~ default : [ e' ]
(* * Pretty-printing *)
(* * Pretty-printing *)
@ -373,12 +376,13 @@ let ppx_clss x fs cs =
( fun fs ( key , data ) ->
( fun fs ( key , data ) ->
Format . fprintf fs " @[%a@ = %a@] " ( Term . ppx x ) key ( ppx_cls x )
Format . fprintf fs " @[%a@ = %a@] " ( Term . ppx x ) key ( ppx_cls x )
( List . sort ~ compare : Term . compare data ) )
( List . sort ~ compare : Term . compare data ) )
fs ( Map. to_alist cs )
fs ( Term. Map. to_alist cs )
let pp_clss fs cs = ppx_clss ( fun _ -> None ) fs cs
let pp_clss fs cs = ppx_clss ( fun _ -> None ) fs cs
let pp_diff_clss =
let pp_diff_clss =
Map . pp_diff ~ data_equal : ( List . equal Term . equal ) Term . pp pp_cls pp_diff_cls
Term . Map . pp_diff ~ data_equal : ( List . equal Term . equal ) Term . pp pp_cls
pp_diff_cls
(* * Invariant *)
(* * Invariant *)
@ -525,7 +529,7 @@ let normalize = canon
let class_of r e =
let class_of r e =
let e' = normalize r e in
let e' = normalize r e in
e' :: Map. find_multi ( classes r ) e'
e' :: Term. Map. find_multi ( classes r ) e'
let fold_uses_of r t ~ init ~ f =
let fold_uses_of r t ~ init ~ f =
let rec fold_ e ~ init : s ~ f =
let rec fold_ e ~ init : s ~ f =
@ -558,7 +562,7 @@ let difference r a b =
let apply_subst us s r =
let apply_subst us s r =
[ % Trace . call fun { pf } -> pf " %a@ %a " Subst . pp s pp r ]
[ % Trace . call fun { pf } -> pf " %a@ %a " Subst . pp s pp r ]
;
;
Map. fold ( classes r ) ~ init : true _ ~ f : ( fun ~ key : rep ~ data : cls r ->
Term. Map. fold ( classes r ) ~ init : true _ ~ f : ( fun ~ key : rep ~ data : cls r ->
let rep' = Subst . subst s rep in
let rep' = Subst . subst s rep in
List . fold cls ~ init : r ~ f : ( fun r trm ->
List . fold cls ~ init : r ~ f : ( fun r trm ->
let trm' = Subst . subst s trm in
let trm' = Subst . subst s trm in
@ -585,7 +589,7 @@ let or_ us r s =
else if not r . sat then s
else if not r . sat then s
else
else
let merge_mems rs r s =
let merge_mems rs r s =
Map. fold ( classes s ) ~ init : rs ~ f : ( fun ~ key : rep ~ data : cls rs ->
Term. Map. fold ( classes s ) ~ init : rs ~ f : ( fun ~ key : rep ~ data : cls rs ->
List . fold cls
List . fold cls
~ init : ( [ rep ] , rs )
~ init : ( [ rep ] , rs )
~ f : ( fun ( reps , rs ) exp ->
~ f : ( fun ( reps , rs ) exp ->
@ -651,7 +655,7 @@ let ppx_classes x fs r = ppx_clss x fs (classes r)
let ppx_classes_diff x fs ( r , s ) =
let ppx_classes_diff x fs ( r , s ) =
let clss = classes s in
let clss = classes s in
let clss =
let clss =
Map. filter_mapi clss ~ f : ( fun ~ key : rep ~ data : cls ->
Term. Map. filter_mapi clss ~ f : ( fun ~ key : rep ~ data : cls ->
match
match
List . filter cls ~ f : ( fun exp -> not ( entails_eq r rep exp ) )
List . filter cls ~ f : ( fun exp -> not ( entails_eq r rep exp ) )
with
with
@ -663,7 +667,7 @@ let ppx_classes_diff x fs (r, s) =
Format . fprintf fs " @[%a@ = %a@] " ( Term . ppx x ) rep
Format . fprintf fs " @[%a@ = %a@] " ( Term . ppx x ) rep
( List . pp " @ = " ( Term . ppx x ) )
( List . pp " @ = " ( Term . ppx x ) )
( List . dedup_and_sort ~ compare : Term . compare cls ) )
( List . dedup_and_sort ~ compare : Term . compare cls ) )
fs ( Map. to_alist clss )
fs ( Term. Map. to_alist clss )
(* * Existential Witnessing and Elimination *)
(* * Existential Witnessing and Elimination *)
@ -876,8 +880,8 @@ let solve_class us us_xs ~key:rep ~data:cls (classes, subst) =
| > Option . value ~ default : cls
| > Option . value ~ default : cls
in
in
let classes =
let classes =
if List . is_empty cls then Map. remove classes rep
if List . is_empty cls then Term. Map. remove classes rep
else Map. set classes ~ key : rep ~ data : cls
else Term. Map. set classes ~ key : rep ~ data : cls
in
in
( classes , subst )
( classes , subst )
| >
| >
@ -954,7 +958,8 @@ let solve_classes r (classes, subst, us) xs =
;
;
let rec solve_classes_ ( classes0 , subst0 , us_xs ) =
let rec solve_classes_ ( classes0 , subst0 , us_xs ) =
let classes , subst =
let classes , subst =
Map . fold ~ f : ( solve_class us us_xs ) classes0 ~ init : ( classes0 , subst0 )
Term . Map . fold ~ f : ( solve_class us us_xs ) classes0
~ init : ( classes0 , subst0 )
in
in
if subst != subst0 then solve_classes_ ( classes , subst , us_xs )
if subst != subst0 then solve_classes_ ( classes , subst , us_xs )
else ( classes , subst , us_xs )
else ( classes , subst , us_xs )