From 32c89e6b688a746688f0afeabe6ad404f108b0de Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Sun, 21 Feb 2021 13:17:05 -0800 Subject: [PATCH] [sledge] Change ocaml/{set,map} to use Comparer interface Summary: Also add support for deriving compare, equal, and sexp. Reviewed By: ngorogiannis Differential Revision: D26250524 fbshipit-source-id: b47787a9c --- sledge/nonstdlib/ocaml/map.ml | 199 ++++++++++++++++++++++++--------- sledge/nonstdlib/ocaml/map.mli | 29 ++++- sledge/nonstdlib/ocaml/set.ml | 192 +++++++++++++++++++++---------- sledge/nonstdlib/ocaml/set.mli | 29 ++++- 4 files changed, 339 insertions(+), 110 deletions(-) diff --git a/sledge/nonstdlib/ocaml/map.ml b/sledge/nonstdlib/ocaml/map.ml index 9fbcd6c31..edc2ef0e6 100644 --- a/sledge/nonstdlib/ocaml/map.ml +++ b/sledge/nonstdlib/ocaml/map.ml @@ -23,6 +23,9 @@ module type S = sig type key type +'a t + + include Comparer.S1 with type 'a t := 'a t + val empty: 'a t val is_empty: 'a t -> bool val mem: key -> 'a t -> bool @@ -34,7 +37,14 @@ module type S = (key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t val union: (key -> 'a -> 'a -> 'a option) -> 'a t -> 'a t -> 'a t val compare: ('a -> 'a -> int) -> 'a t -> 'a t -> int - val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + + module Provide_equal (_ : sig + type t = key [@@deriving equal] + end) : sig + type 'a t [@@deriving equal] + end + with type 'a t := 'a t + val iter: (key -> 'a -> unit) -> 'a t -> unit val fold: (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b val for_all: (key -> 'a -> bool) -> 'a t -> bool @@ -63,62 +73,145 @@ module type S = val to_seq_from : key -> 'a t -> (key * 'a) Seq.t val add_seq : (key * 'a) Seq.t -> 'a t -> 'a t val of_seq : (key * 'a) Seq.t -> 'a t - end - -module Make(Ord: OrderedType) = struct - - type key = Ord.t - - type 'a t = - Empty - | Node of {l:'a t; v:key; d:'a; r:'a t; h:int} - type 'a enumeration = End | More of key * 'a * 'a t * 'a enumeration + module Provide_sexp_of (_ : sig + type t = key [@@deriving sexp_of] + end) : sig + type 'a t [@@deriving sexp_of] + end + with type 'a t := 'a t + + module Provide_of_sexp (_ : sig + type t = key [@@deriving of_sexp] + end) : sig + type 'a t [@@deriving of_sexp] + end + with type 'a t := 'a t + end - let rec cons_enum m e = - match m with - Empty -> e - | Node {l; v; d; r} -> cons_enum l (More(v, d, r, e)) - - let compare cmp m1 m2 = - let rec compare_aux e1 e2 = - match (e1, e2) with - (End, End) -> 0 - | (End, _) -> -1 - | (_, End) -> 1 - | (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) -> - let c = Ord.compare v1 v2 in - if c <> 0 then c else - let c = cmp d1 d2 in - if c <> 0 then c else - compare_aux (cons_enum r1 e1) (cons_enum r2 e2) - in compare_aux (cons_enum m1 End) (cons_enum m2 End) - - let equal cmp m1 m2 = - let rec equal_aux e1 e2 = - match (e1, e2) with - (End, End) -> true - | (End, _) -> false - | (_, End) -> false - | (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) -> - Ord.compare v1 v2 = 0 && cmp d1 d2 && - equal_aux (cons_enum r1 e1) (cons_enum r2 e2) - in equal_aux (cons_enum m1 End) (cons_enum m2 End) - - let rec bindings_aux accu = function - Empty -> accu - | Node {l; v; d; r} -> bindings_aux ((v, d) :: bindings_aux accu r) l +module T = struct + type ('key, 'a, 'cmp) t = + Empty + | Node of {l:('key, 'a, 'cmp) t; v:'key; d:'a; r:('key, 'a, 'cmp) t; h:int} + + type ('key, 'a, 'cmp) enumeration = + End + | More of 'key * 'a * ('key, 'a, 'cmp) t * ('key, 'a, 'cmp) enumeration + + let rec cons_enum m e = + match m with + Empty -> e + | Node {l; v; d; r} -> cons_enum l (More(v, d, r, e)) + + let compare compare_key compare_a _ m1 m2 = + let rec compare_aux e1 e2 = + match (e1, e2) with + (End, End) -> 0 + | (End, _) -> -1 + | (_, End) -> 1 + | (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) -> + let c = compare_key v1 v2 in + if c <> 0 then c else + let c = compare_a d1 d2 in + if c <> 0 then c else + compare_aux (cons_enum r1 e1) (cons_enum r2 e2) + in compare_aux (cons_enum m1 End) (cons_enum m2 End) + + type ('compare_key, 'compare_a) compare [@@deriving compare, equal, sexp] +end - let bindings s = - bindings_aux [] s +include T + +let equal equal_key equal_a _ m1 m2 = + let rec equal_aux e1 e2 = + match (e1, e2) with + (End, End) -> true + | (End, _) -> false + | (_, End) -> false + | (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) -> + equal_key v1 v2 && equal_a d1 d2 && + equal_aux (cons_enum r1 e1) (cons_enum r2 e2) + in equal_aux (cons_enum m1 End) (cons_enum m2 End) + +let rec bindings_aux accu = function + Empty -> accu + | Node {l; v; d; r} -> bindings_aux ((v, d) :: bindings_aux accu r) l + +let bindings s = + bindings_aux [] s + +let sexp_of_t sexp_of_key sexp_of_data _ m = + m + |> bindings + |> Sexplib.Conv.sexp_of_list + (Sexplib.Conv.sexp_of_pair sexp_of_key sexp_of_data) + +let height = function + Empty -> 0 + | Node {h} -> h + +let create l x d r = + let hl = height l and hr = height r in + Node{l; v=x; d; r; h=(if hl >= hr then hl + 1 else hr + 1)} + +let of_sorted_list l = + let rec sub n l = + match n, l with + | 0, l -> Empty, l + | 1, (v0,d0) :: l -> Node {l=Empty; v=v0; d=d0; r=Empty; h=1}, l + | 2, (v0,d0) :: (v1,d1) :: l -> + Node{l=Node{l=Empty; v=v0; d=d0; r=Empty; h=1}; v=v1; d=d1; + r=Empty; h=2}, l + | 3, (v0,d0) :: (v1,d1) :: (v2,d2) :: l -> + Node{l=Node{l=Empty; v=v0; d=d0; r=Empty; h=1}; v=v1; d=d1; + r=Node{l=Empty; v=v2; d=d2; r=Empty; h=1}; h=2}, l + | n, l -> + let nl = n / 2 in + let left, l = sub nl l in + match l with + | [] -> assert false + | (v,d) :: l -> + let right, l = sub (n - nl - 1) l in + create left v d right, l + in + fst (sub (List.length l) l) + +let t_of_sexp key_of_sexp data_of_sexp _ m = + m + |> Sexplib.Conv.list_of_sexp + (Sexplib.Conv.pair_of_sexp key_of_sexp data_of_sexp) + |> of_sorted_list + +module Make (Ord : Comparer.S) = struct + module Ord = struct + include Ord + let compare = (comparer :> t -> t -> int) + end - let height = function - Empty -> 0 - | Node {h} -> h + type key = Ord.t - let create l x d r = - let hl = height l and hr = height r in - Node{l; v=x; d; r; h=(if hl >= hr then hl + 1 else hr + 1)} + include (Comparer.Apply1 (T) (Ord)) + + module Provide_equal (Key : sig + type t = Ord.t [@@deriving equal] + end) = struct + let equal equal_data = + equal Key.equal equal_data Ord.equal_compare + end + + module Provide_sexp_of (Key : sig + type t = Ord.t [@@deriving sexp_of] + end) = struct + let sexp_of_t sexp_of_data m = + sexp_of_t Key.sexp_of_t sexp_of_data Ord.sexp_of_compare m + end + + module Provide_of_sexp (Key : sig + type t = Ord.t [@@deriving of_sexp] + end) = struct + let t_of_sexp data_of_sexp s = + t_of_sexp Key.t_of_sexp data_of_sexp Ord.compare_of_sexp s + end let singleton x d = Node{l=Empty; v=x; d; r=Empty; h=1} @@ -492,6 +585,8 @@ module Make(Ord: OrderedType) = struct Empty -> 0 | Node {l; r} -> cardinal l + 1 + cardinal r + let bindings = bindings + let choose = min_binding let choose_opt = min_binding_opt diff --git a/sledge/nonstdlib/ocaml/map.mli b/sledge/nonstdlib/ocaml/map.mli index 6ec8249ab..c7ecb77e0 100644 --- a/sledge/nonstdlib/ocaml/map.mli +++ b/sledge/nonstdlib/ocaml/map.mli @@ -67,6 +67,8 @@ module type S = type (+'a) t (** The type of maps from type [key] to type ['a]. *) + include Comparer.S1 with type 'a t := 'a t + val empty: 'a t (** The empty map. *) @@ -141,11 +143,15 @@ module type S = (** Total ordering between maps. The first argument is a total ordering used to compare data associated with equal keys in the two maps. *) + module Provide_equal (_ : sig + type t = key [@@deriving equal] + end) : sig val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool (** [equal cmp m1 m2] tests whether the maps [m1] and [m2] are equal, that is, contain equal keys and associate them with equal data. [cmp] is the equality predicate used to compare the data associated with the keys. *) + end val iter: (key -> 'a -> unit) -> 'a t -> unit (** [iter f m] applies [f] to all bindings in map [m]. @@ -344,9 +350,30 @@ module type S = val of_seq : (key * 'a) Seq.t -> 'a t (** Build a map from the given bindings @since 4.07 *) + + module Provide_sexp_of (_ : sig + type t = key [@@deriving sexp_of] + end) : sig + type 'a t [@@deriving sexp_of] + end + with type 'a t := 'a t + + module Provide_of_sexp (_ : sig + type t = key [@@deriving of_sexp] + end) : sig + type 'a t [@@deriving of_sexp] + end + with type 'a t := 'a t end (** Output signature of the functor {!Map.Make}. *) -module Make (Ord : OrderedType) : S with type key = Ord.t +type ('key, +'a, 'compare_key) t [@@deriving compare, equal, sexp] + +type ('compare_key, 'compare_a) compare [@@deriving compare, equal, sexp] + +module Make (Ord : Comparer.S) : + S with type key = Ord.t + with type +'a t = (Ord.t, 'a, Ord.compare) t + with type 'compare_a compare = (Ord.compare, 'compare_a) compare (** Functor building an implementation of the map structure given a totally ordered type. *) diff --git a/sledge/nonstdlib/ocaml/set.ml b/sledge/nonstdlib/ocaml/set.ml index 427b67163..036ae264e 100644 --- a/sledge/nonstdlib/ocaml/set.ml +++ b/sledge/nonstdlib/ocaml/set.ml @@ -25,6 +25,8 @@ module type S = sig type elt type t + include Comparer.S with type t := t + val empty: t val is_empty: t -> bool val mem: elt -> t -> bool @@ -36,7 +38,14 @@ module type S = val disjoint: t -> t -> bool val diff: t -> t -> t val compare: t -> t -> int - val equal: t -> t -> bool + + module Provide_equal (_ : sig + type t = elt [@@deriving equal] + end) : sig + type t [@@deriving equal] + end + with type t := t + val subset: t -> t -> bool val iter: (elt -> unit) -> t -> unit val map: (elt -> elt) -> t -> t @@ -66,81 +75,150 @@ module type S = val to_seq : t -> elt Seq.t val add_seq : elt Seq.t -> t -> t val of_seq : elt Seq.t -> t + + module Provide_sexp_of (_ : sig + type t = elt [@@deriving sexp_of] + end) : sig + type t [@@deriving sexp_of] + end + with type t := t + + module Provide_of_sexp (_ : sig + type t = elt [@@deriving of_sexp] + end) : sig + type t [@@deriving of_sexp] + end + with type t := t end -module Make(Ord: OrderedType) = - struct - type elt = Ord.t - type t = Empty | Node of {l:t; v:elt; r:t; h:int} +module T = struct + type ('elt, 'cmp) t = + | Empty + | Node of {l: ('elt, 'cmp) t; v: 'elt; r: ('elt, 'cmp) t; h: int} - (* Sets are represented by balanced binary trees (the heights of the - children differ by at most 2 *) + (* Sets are represented by balanced binary trees (the heights of the + children differ by at most 2 *) - type enumeration = End | More of elt * t * enumeration + type ('elt, 'cmp) enumeration = + | End + | More of 'elt * ('elt, 'cmp) t * ('elt, 'cmp) enumeration - let rec cons_enum s e = - match s with - Empty -> e - | Node{l; v; r} -> cons_enum l (More(v, r, e)) + let rec cons_enum s e = + match s with + Empty -> e + | Node{l; v; r} -> cons_enum l (More(v, r, e)) + let compare compare_elt _ s1 s2 = let rec compare_aux e1 e2 = match (e1, e2) with (End, End) -> 0 | (End, _) -> -1 | (_, End) -> 1 | (More(v1, r1, e1), More(v2, r2, e2)) -> - let c = Ord.compare v1 v2 in + let c = compare_elt v1 v2 in if c <> 0 then c else compare_aux (cons_enum r1 e1) (cons_enum r2 e2) + in + compare_aux (cons_enum s1 End) (cons_enum s2 End) + + type 'compare_elt compare [@@deriving compare, equal, sexp] +end + +include T + +let equal equal_elt _ s1 s2 = + let rec equal_aux e1 e2 = + match (e1, e2) with + (End, End) -> true + | (End, _) -> false + | (_, End) -> false + | (More(v1, r1, e1), More(v2, r2, e2)) -> + equal_elt v1 v2 && + equal_aux (cons_enum r1 e1) (cons_enum r2 e2) + in + equal_aux (cons_enum s1 End) (cons_enum s2 End) + +let rec elements_aux accu = function + Empty -> accu + | Node{l; v; r} -> elements_aux (v :: elements_aux accu r) l + +let elements s = + elements_aux [] s + +let sexp_of_t sexp_of_elt _ s = + elements s + |> Sexplib.Conv.sexp_of_list sexp_of_elt + +let height = function + Empty -> 0 + | Node {h} -> h + +(* Creates a new node with left son l, value v and right son r. + We must have all elements of l < v < all elements of r. + l and r must be balanced and | height l - height r | <= 2. + Inline expansion of height for better speed. *) + +let create l v r = + let hl = match l with Empty -> 0 | Node {h} -> h in + let hr = match r with Empty -> 0 | Node {h} -> h in + Node{l; v; r; h=(if hl >= hr then hl + 1 else hr + 1)} + +let of_sorted_list l = + let rec sub n l = + match n, l with + | 0, l -> Empty, l + | 1, x0 :: l -> Node {l=Empty; v=x0; r=Empty; h=1}, l + | 2, x0 :: x1 :: l -> + Node{l=Node{l=Empty; v=x0; r=Empty; h=1}; v=x1; r=Empty; h=2}, l + | 3, x0 :: x1 :: x2 :: l -> + Node{l=Node{l=Empty; v=x0; r=Empty; h=1}; v=x1; + r=Node{l=Empty; v=x2; r=Empty; h=1}; h=2}, l + | n, l -> + let nl = n / 2 in + let left, l = sub nl l in + match l with + | [] -> assert false + | mid :: l -> + let right, l = sub (n - nl - 1) l in + create left mid right, l + in + fst (sub (List.length l) l) + +let t_of_sexp elt_of_sexp _ s = + Sexplib.Conv.list_of_sexp elt_of_sexp s + |> of_sorted_list + +module Make(Ord: Comparer.S) = + struct + module Ord = struct + include Ord + let compare = (comparer :> t -> t -> int) + end - let compare s1 s2 = - compare_aux (cons_enum s1 End) (cons_enum s2 End) - - let equal s1 s2 = - compare s1 s2 = 0 - - let rec elements_aux accu = function - Empty -> accu - | Node{l; v; r} -> elements_aux (v :: elements_aux accu r) l + type elt = Ord.t - let elements s = - elements_aux [] s + include (Comparer.Apply (T) (Ord)) - let height = function - Empty -> 0 - | Node {h} -> h + module Provide_equal (Elt : sig + type t = Ord.t [@@deriving equal] + end) = struct + let equal l r = equal Elt.equal Ord.equal_compare l r + end - (* Creates a new node with left son l, value v and right son r. - We must have all elements of l < v < all elements of r. - l and r must be balanced and | height l - height r | <= 2. - Inline expansion of height for better speed. *) + module Provide_sexp_of (Elt : sig + type t = Ord.t [@@deriving sexp_of] + end) = struct + let sexp_of_t s = + sexp_of_t Elt.sexp_of_t Ord.sexp_of_compare s + end - let create l v r = - let hl = match l with Empty -> 0 | Node {h} -> h in - let hr = match r with Empty -> 0 | Node {h} -> h in - Node{l; v; r; h=(if hl >= hr then hl + 1 else hr + 1)} - - let of_sorted_list l = - let rec sub n l = - match n, l with - | 0, l -> Empty, l - | 1, x0 :: l -> Node {l=Empty; v=x0; r=Empty; h=1}, l - | 2, x0 :: x1 :: l -> - Node{l=Node{l=Empty; v=x0; r=Empty; h=1}; v=x1; r=Empty; h=2}, l - | 3, x0 :: x1 :: x2 :: l -> - Node{l=Node{l=Empty; v=x0; r=Empty; h=1}; v=x1; - r=Node{l=Empty; v=x2; r=Empty; h=1}; h=2}, l - | n, l -> - let nl = n / 2 in - let left, l = sub nl l in - match l with - | [] -> assert false - | mid :: l -> - let right, l = sub (n - nl - 1) l in - create left mid right, l - in - fst (sub (List.length l) l) + module Provide_of_sexp (Elt : sig + type t = Ord.t [@@deriving of_sexp] + end) = struct + let t_of_sexp s = + t_of_sexp Elt.t_of_sexp Ord.compare_of_sexp s + end (* Same as create, but performs one step of rebalancing if necessary. Assumes l and r balanced and | height l - height r | <= 3. @@ -443,6 +521,8 @@ module Make(Ord: OrderedType) = Empty -> 0 | Node{l; r} -> cardinal l + 1 + cardinal r + let elements = elements + let choose = min_elt let choose_opt = min_elt_opt diff --git a/sledge/nonstdlib/ocaml/set.mli b/sledge/nonstdlib/ocaml/set.mli index 91e392386..669a344fe 100644 --- a/sledge/nonstdlib/ocaml/set.mli +++ b/sledge/nonstdlib/ocaml/set.mli @@ -68,6 +68,8 @@ module type S = type t (** The type of sets. *) + include Comparer.S with type t := t + val empty: t (** The empty set. *) @@ -110,9 +112,13 @@ module type S = (** Total ordering between sets. Can be used as the ordering function for doing sets of sets. *) + module Provide_equal (_ : sig + type t = elt [@@deriving equal] + end) : sig val equal: t -> t -> bool (** [equal s1 s2] tests whether the sets [s1] and [s2] are equal, that is, contain equal elements. *) + end val subset: t -> t -> bool (** [subset s1 s2] tests whether the set [s1] is a subset of @@ -298,9 +304,30 @@ module type S = val of_seq : elt Seq.t -> t (** Build a set from the given bindings @since 4.07 *) + + module Provide_sexp_of (_ : sig + type t = elt [@@deriving sexp_of] + end) : sig + type t [@@deriving sexp_of] + end + with type t := t + + module Provide_of_sexp (_ : sig + type t = elt [@@deriving of_sexp] + end) : sig + type t [@@deriving of_sexp] + end + with type t := t end (** Output signature of the functor {!Set.Make}. *) -module Make (Ord : OrderedType) : S with type elt = Ord.t +type ('elt, 'cmp) t [@@deriving compare, equal, sexp] + +type 'compare_elt compare [@@deriving compare, equal, sexp] + +module Make (Ord : Comparer.S) : + S with type elt = Ord.t + with type t = (Ord.t, Ord.compare) t + with type compare = Ord.compare compare (** Functor building an implementation of the set structure given a totally ordered type. *)