@ -38,8 +38,8 @@ module Subst : sig
val empty : t
val empty : t
val is_empty : t -> bool
val is_empty : t -> bool
val length : t -> int
val length : t -> int
val mem : t -> Term . t -> bool
val mem : Term . t -> t -> bool
val find : t -> Term . t -> Term . t option
val find : Term . t -> t -> Term . t option
val fold : t -> init : ' a -> f : ( key : Term . t -> data : Term . t -> ' a -> ' a ) -> ' a
val fold : t -> init : ' a -> f : ( key : Term . t -> data : Term . t -> ' a -> ' a ) -> ' a
val iteri : t -> f : ( key : Term . t -> data : Term . t -> unit ) -> unit
val iteri : t -> f : ( key : Term . t -> data : Term . t -> unit ) -> unit
val for_alli : t -> f : ( key : Term . t -> data : Term . t -> bool ) -> bool
val for_alli : t -> f : ( key : Term . t -> data : Term . t -> bool ) -> bool
@ -52,17 +52,13 @@ module Subst : sig
val remove : Var . Set . t -> t -> t
val remove : Var . Set . t -> t -> t
val map_entries : f : ( Term . t -> Term . t ) -> t -> t
val map_entries : f : ( Term . t -> Term . t ) -> t -> t
val to_iter : t -> ( Term . t * Term . t ) iter
val to_iter : t -> ( Term . t * Term . t ) iter
val to_alist : t -> ( Term . t * Term . t ) list
val partition_valid : Var . Set . t -> t -> t * Var . Set . t * t
val partition_valid : Var . Set . t -> t -> t * Var . Set . t * t
end = struct
end = struct
type t = Term . t Term . Map . t [ @@ deriving compare , equal , sexp_of ]
type t = Term . t Term . Map . t [ @@ deriving compare , equal , sexp_of ]
let t_of_sexp = Term . Map . t_of_sexp Term . t_of_sexp
let t_of_sexp = Term . Map . t_of_sexp Term . t_of_sexp
let pp = Term . Map . pp Term . pp Term . pp
let pp = Term . Map . pp Term . pp Term . pp
let pp_diff = Term . Map . pp_diff ~ eq : Term . equal Term . pp Term . pp Term . pp_diff
let pp_diff =
Term . Map . pp_diff ~ data_equal : Term . equal Term . pp Term . pp Term . pp_diff
let empty = Term . Map . empty
let empty = Term . Map . empty
let is_empty = Term . Map . is_empty
let is_empty = Term . Map . is_empty
let length = Term . Map . length
let length = Term . Map . length
@ -72,10 +68,9 @@ end = struct
let iteri = Term . Map . iteri
let iteri = Term . Map . iteri
let for_alli = Term . Map . for_alli
let for_alli = Term . Map . for_alli
let to_iter = Term . Map . to_iter
let to_iter = Term . Map . to_iter
let to_alist = Term . Map . to_alist ~ key_order : ` Increasing
(* * look up a term in a substitution *)
(* * look up a term in a substitution *)
let apply s a = Term . Map . find s a | > Option . value ~ default : a
let apply s a = Term . Map . find a s | > Option . value ~ default : a
let rec subst s a = apply s ( Term . map ~ f : ( subst s ) a )
let rec subst s a = apply s ( Term . map ~ f : ( subst s ) a )
@ -88,7 +83,7 @@ end = struct
[ % Trace . call fun { pf } -> pf " %a@ %a " pp r pp s ]
[ % Trace . call fun { pf } -> pf " %a@ %a " pp r pp s ]
;
;
let r' = Term . Map . map_endo ~ f : ( norm s ) r in
let r' = Term . Map . map_endo ~ f : ( norm s ) r in
Term . Map . merge_endo r' s ~ f : ( fun ~ key -> function
Term . Map . merge_endo r' s ~ f : ( fun key -> function
| ` Both ( data_r , data_s ) ->
| ` Both ( data_r , data_s ) ->
assert (
assert (
Term . equal data_s data_r
Term . equal data_s data_r
@ -112,7 +107,7 @@ end = struct
let extend e s =
let extend e s =
let exception Found in
let exception Found in
match
match
Term . Map . update s e ~ f : ( function
Term . Map . update e s ~ f : ( function
| Some _ -> raise_notrace Found
| Some _ -> raise_notrace Found
| None -> e )
| None -> e )
with
with
@ -121,7 +116,7 @@ end = struct
(* * remove entries for vars *)
(* * remove entries for vars *)
let remove xs s =
let remove xs s =
Var . Set . fold ~ f : ( fun s x -> Term . Map . remove s ( Term . var x ) ) ~ init : s xs
Var . Set . fold ~ f : ( fun s x -> Term . Map . remove ( Term . var x ) s ) ~ init : s xs
(* * map over a subst, applying [f] to both domain and range, requires that
(* * map over a subst, applying [f] to both domain and range, requires that
[ f ] is injective and for any set of terms [ E ] , [ f \ [ E \ ] ] is disjoint
[ f ] is injective and for any set of terms [ E ] , [ f \ [ E \ ] ] is disjoint
@ -132,9 +127,9 @@ end = struct
let data' = f data in
let data' = f data in
if Term . equal key' key then
if Term . equal key' key then
if Term . equal data' data then s
if Term . equal data' data then s
else Term . Map . set s ~ key ~ data : data'
else Term . Map . add ~ key ~ data : data' s
else
else
let s = Term . Map . remove s key in
let s = Term . Map . remove key s in
match ( key : Term . t ) with
match ( key : Term . t ) with
| Integer _ | Rational _ -> s
| Integer _ | Rational _ -> s
| _ -> Term . Map . add_exn ~ key : key' ~ data : data' s )
| _ -> Term . Map . add_exn ~ key : key' ~ data : data' s )
@ -167,10 +162,10 @@ end = struct
Term . Map . fold s ~ init : ( t , ks , s ) ~ f : ( fun ~ key ~ data ( t , ks , s ) ->
Term . Map . fold s ~ init : ( t , ks , s ) ~ f : ( fun ~ key ~ data ( t , ks , s ) ->
if is_valid_eq ks key data then ( t , ks , s )
if is_valid_eq ks key data then ( t , ks , s )
else
else
let t = Term . Map . set ~ key ~ data t
let t = Term . Map . add ~ key ~ data t
and ks =
and ks =
Var . Set . diff ks ( Var . Set . union ( Term . fv key ) ( Term . fv data ) )
Var . Set . diff ks ( Var . Set . union ( Term . fv key ) ( Term . fv data ) )
and s = Term . Map . remove s key in
and s = Term . Map . remove key s in
( t , ks , s ) )
( t , ks , s ) )
in
in
if s' != s then partition_valid_ t' ks' s' else ( t' , ks' , s' )
if s' != s then partition_valid_ t' ks' s' else ( t' , ks' , s' )
@ -347,7 +342,7 @@ type classes = Term.t list Term.Map.t
let classes r =
let classes r =
let add key data cls =
let add key data cls =
if Term . equal key data then cls
if Term . equal key data then cls
else Term . Map . add_multi cls ~ key : data ~ data : key
else Term . Map . add_multi ~ key : data ~ data : key cls
in
in
Subst . fold r . rep ~ init : Term . Map . empty ~ f : ( fun ~ key ~ data cls ->
Subst . fold r . rep ~ init : Term . Map . empty ~ f : ( fun ~ key ~ data cls ->
match classify key with
match classify key with
@ -356,7 +351,7 @@ let classes r =
let cls_of r e =
let cls_of r e =
let e' = Subst . apply r . rep e in
let e' = Subst . apply r . rep e in
Term . Map . find ( classes r ) e' | > Option . value ~ default : [ e' ]
Term . Map . find e' ( classes r ) | > Option . value ~ default : [ e' ]
(* * Pretty-printing *)
(* * Pretty-printing *)
@ -370,7 +365,7 @@ let pp fs {sat; rep} =
let pp_term_v fs ( k , v ) = if not ( Term . equal k v ) then Term . pp fs v in
let pp_term_v fs ( k , v ) = if not ( Term . equal k v ) then Term . pp fs v in
Format . fprintf fs " @[{@[<hv>sat= %b;@ rep= %a@]}@] " sat
Format . fprintf fs " @[{@[<hv>sat= %b;@ rep= %a@]}@] " sat
( pp_alist Term . pp pp_term_v )
( pp_alist Term . pp pp_term_v )
( Subst. to_ al is t rep )
( Iter. to_list ( Subst. to_ iter rep ) )
let pp_diff fs ( r , s ) =
let pp_diff fs ( r , s ) =
let pp_sat fs =
let pp_sat fs =
@ -392,18 +387,18 @@ let ppx_classes x fs clss =
( fun fs ( rep , cls ) ->
( fun fs ( rep , cls ) ->
Format . fprintf fs " @[%a@ = %a@] " ( Term . ppx x ) rep ( ppx_cls x )
Format . fprintf fs " @[%a@ = %a@] " ( Term . ppx x ) rep ( ppx_cls x )
( List . sort ~ cmp : Term . compare cls ) )
( List . sort ~ cmp : Term . compare cls ) )
fs ( Term . Map . to_alist clss )
fs
( Iter . to_list ( Term . Map . to_iter clss ) )
let pp_classes fs r = ppx_classes ( fun _ -> None ) fs ( classes r )
let pp_classes fs r = ppx_classes ( fun _ -> None ) fs ( classes r )
let pp_diff_clss =
let pp_diff_clss =
Term . Map . pp_diff ~ data_equal : ( List . equal Term . equal ) Term . pp pp_cls
Term . Map . pp_diff ~ eq : ( List . equal Term . equal ) Term . pp pp_cls pp_diff_cls
pp_diff_cls
(* * Basic queries *)
(* * Basic queries *)
(* * test membership in carrier *)
(* * test membership in carrier *)
let in_car r e = Subst . mem r. rep e
let in_car r e = Subst . mem e r. rep
(* * congruent specialized to assume subterms of [a'] are [Subst.norm]alized
(* * congruent specialized to assume subterms of [a'] are [Subst.norm]alized
wrt [ r ] ( or canonized ) * )
wrt [ r ] ( or canonized ) * )
@ -575,7 +570,7 @@ let normalize = canon
let class_of r e =
let class_of r e =
let e' = normalize r e in
let e' = normalize r e in
e' :: Term . Map . find_multi ( classes r ) e'
e' :: Term . Map . find_multi e' ( classes r )
let fold_uses_of r t ~ init ~ f =
let fold_uses_of r t ~ init ~ f =
let rec fold_ e ~ init : s ~ f =
let rec fold_ e ~ init : s ~ f =
@ -707,7 +702,7 @@ let subst_invariant us s0 s =
Subst . iteri s ~ f : ( fun ~ key ~ data ->
Subst . iteri s ~ f : ( fun ~ key ~ data ->
(* dom of new entries not ito us *)
(* dom of new entries not ito us *)
assert (
assert (
Option . for_all ~ f : ( Term . equal data ) ( Subst . find s0 key)
Option . for_all ~ f : ( Term . equal data ) ( Subst . find key s0 )
| | not ( Var . Set . is_subset ( Term . fv key ) ~ of_ : us ) ) ;
| | not ( Var . Set . is_subset ( Term . fv key ) ~ of_ : us ) ) ;
(* rep not ito us implies trm not ito us *)
(* rep not ito us implies trm not ito us *)
assert (
assert (
@ -912,8 +907,8 @@ let solve_class us us_xs ~key:rep ~data:cls (classes, subst) =
let cls = List . rev_append cls_not_ito_us_xs cls in
let cls = List . rev_append cls_not_ito_us_xs cls in
let cls = List . remove ~ eq : Term . equal ( Subst . norm subst rep ) cls in
let cls = List . remove ~ eq : Term . equal ( Subst . norm subst rep ) cls in
let classes =
let classes =
if List . is_empty cls then Term . Map . remove classes rep
if List . is_empty cls then Term . Map . remove rep classes
else Term . Map . set classes ~ key : rep ~ data : cls
else Term . Map . add ~ key : rep ~ data : cls classes
in
in
( classes , subst )
( classes , subst )
| >
| >
@ -980,7 +975,7 @@ let solve_for_xs r us xs (classes, subst, us_xs) =
Var . Set . fold xs ~ init : ( classes , subst , us_xs )
Var . Set . fold xs ~ init : ( classes , subst , us_xs )
~ f : ( fun ( classes , subst , us_xs ) x ->
~ f : ( fun ( classes , subst , us_xs ) x ->
let x = Term . var x in
let x = Term . var x in
if Subst . mem subst x then ( classes , subst , us_xs )
if Subst . mem x subst then ( classes , subst , us_xs )
else solve_concat_extracts r us x ( classes , subst , us_xs ) )
else solve_concat_extracts r us x ( classes , subst , us_xs ) )
(* * move equations from [classes] to [subst] which can be expressed, after
(* * move equations from [classes] to [subst] which can be expressed, after