[sledge] Change ocaml/{set,map} to use Comparer interface

Summary: Also add support for deriving compare, equal, and sexp.

Reviewed By: ngorogiannis

Differential Revision: D26250524

fbshipit-source-id: b47787a9c
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 7cf6e17403
commit 32c89e6b68

@ -23,6 +23,9 @@ module type S =
sig sig
type key type key
type +'a t type +'a t
include Comparer.S1 with type 'a t := 'a t
val empty: 'a t val empty: 'a t
val is_empty: 'a t -> bool val is_empty: 'a t -> bool
val mem: key -> 'a t -> bool val mem: key -> 'a t -> bool
@ -34,7 +37,14 @@ module type S =
(key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t (key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t
val union: (key -> 'a -> 'a -> 'a option) -> 'a t -> 'a t -> 'a t val union: (key -> 'a -> 'a -> 'a option) -> 'a t -> 'a t -> 'a t
val compare: ('a -> 'a -> int) -> 'a t -> 'a t -> int val compare: ('a -> 'a -> int) -> 'a t -> 'a t -> int
val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
module Provide_equal (_ : sig
type t = key [@@deriving equal]
end) : sig
type 'a t [@@deriving equal]
end
with type 'a t := 'a t
val iter: (key -> 'a -> unit) -> 'a t -> unit val iter: (key -> 'a -> unit) -> 'a t -> unit
val fold: (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b val fold: (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b
val for_all: (key -> 'a -> bool) -> 'a t -> bool val for_all: (key -> 'a -> bool) -> 'a t -> bool
@ -63,63 +73,146 @@ module type S =
val to_seq_from : key -> 'a t -> (key * 'a) Seq.t val to_seq_from : key -> 'a t -> (key * 'a) Seq.t
val add_seq : (key * 'a) Seq.t -> 'a t -> 'a t val add_seq : (key * 'a) Seq.t -> 'a t -> 'a t
val of_seq : (key * 'a) Seq.t -> 'a t val of_seq : (key * 'a) Seq.t -> 'a t
end
module Make(Ord: OrderedType) = struct module Provide_sexp_of (_ : sig
type t = key [@@deriving sexp_of]
end) : sig
type 'a t [@@deriving sexp_of]
end
with type 'a t := 'a t
type key = Ord.t module Provide_of_sexp (_ : sig
type t = key [@@deriving of_sexp]
end) : sig
type 'a t [@@deriving of_sexp]
end
with type 'a t := 'a t
end
type 'a t = module T = struct
type ('key, 'a, 'cmp) t =
Empty Empty
| Node of {l:'a t; v:key; d:'a; r:'a t; h:int} | Node of {l:('key, 'a, 'cmp) t; v:'key; d:'a; r:('key, 'a, 'cmp) t; h:int}
type 'a enumeration = End | More of key * 'a * 'a t * 'a enumeration type ('key, 'a, 'cmp) enumeration =
End
| More of 'key * 'a * ('key, 'a, 'cmp) t * ('key, 'a, 'cmp) enumeration
let rec cons_enum m e = let rec cons_enum m e =
match m with match m with
Empty -> e Empty -> e
| Node {l; v; d; r} -> cons_enum l (More(v, d, r, e)) | Node {l; v; d; r} -> cons_enum l (More(v, d, r, e))
let compare cmp m1 m2 = let compare compare_key compare_a _ m1 m2 =
let rec compare_aux e1 e2 = let rec compare_aux e1 e2 =
match (e1, e2) with match (e1, e2) with
(End, End) -> 0 (End, End) -> 0
| (End, _) -> -1 | (End, _) -> -1
| (_, End) -> 1 | (_, End) -> 1
| (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) -> | (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) ->
let c = Ord.compare v1 v2 in let c = compare_key v1 v2 in
if c <> 0 then c else if c <> 0 then c else
let c = cmp d1 d2 in let c = compare_a d1 d2 in
if c <> 0 then c else if c <> 0 then c else
compare_aux (cons_enum r1 e1) (cons_enum r2 e2) compare_aux (cons_enum r1 e1) (cons_enum r2 e2)
in compare_aux (cons_enum m1 End) (cons_enum m2 End) in compare_aux (cons_enum m1 End) (cons_enum m2 End)
let equal cmp m1 m2 = type ('compare_key, 'compare_a) compare [@@deriving compare, equal, sexp]
end
include T
let equal equal_key equal_a _ m1 m2 =
let rec equal_aux e1 e2 = let rec equal_aux e1 e2 =
match (e1, e2) with match (e1, e2) with
(End, End) -> true (End, End) -> true
| (End, _) -> false | (End, _) -> false
| (_, End) -> false | (_, End) -> false
| (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) -> | (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) ->
Ord.compare v1 v2 = 0 && cmp d1 d2 && equal_key v1 v2 && equal_a d1 d2 &&
equal_aux (cons_enum r1 e1) (cons_enum r2 e2) equal_aux (cons_enum r1 e1) (cons_enum r2 e2)
in equal_aux (cons_enum m1 End) (cons_enum m2 End) in equal_aux (cons_enum m1 End) (cons_enum m2 End)
let rec bindings_aux accu = function let rec bindings_aux accu = function
Empty -> accu Empty -> accu
| Node {l; v; d; r} -> bindings_aux ((v, d) :: bindings_aux accu r) l | Node {l; v; d; r} -> bindings_aux ((v, d) :: bindings_aux accu r) l
let bindings s = let bindings s =
bindings_aux [] s bindings_aux [] s
let height = function let sexp_of_t sexp_of_key sexp_of_data _ m =
m
|> bindings
|> Sexplib.Conv.sexp_of_list
(Sexplib.Conv.sexp_of_pair sexp_of_key sexp_of_data)
let height = function
Empty -> 0 Empty -> 0
| Node {h} -> h | Node {h} -> h
let create l x d r = let create l x d r =
let hl = height l and hr = height r in let hl = height l and hr = height r in
Node{l; v=x; d; r; h=(if hl >= hr then hl + 1 else hr + 1)} Node{l; v=x; d; r; h=(if hl >= hr then hl + 1 else hr + 1)}
let of_sorted_list l =
let rec sub n l =
match n, l with
| 0, l -> Empty, l
| 1, (v0,d0) :: l -> Node {l=Empty; v=v0; d=d0; r=Empty; h=1}, l
| 2, (v0,d0) :: (v1,d1) :: l ->
Node{l=Node{l=Empty; v=v0; d=d0; r=Empty; h=1}; v=v1; d=d1;
r=Empty; h=2}, l
| 3, (v0,d0) :: (v1,d1) :: (v2,d2) :: l ->
Node{l=Node{l=Empty; v=v0; d=d0; r=Empty; h=1}; v=v1; d=d1;
r=Node{l=Empty; v=v2; d=d2; r=Empty; h=1}; h=2}, l
| n, l ->
let nl = n / 2 in
let left, l = sub nl l in
match l with
| [] -> assert false
| (v,d) :: l ->
let right, l = sub (n - nl - 1) l in
create left v d right, l
in
fst (sub (List.length l) l)
let t_of_sexp key_of_sexp data_of_sexp _ m =
m
|> Sexplib.Conv.list_of_sexp
(Sexplib.Conv.pair_of_sexp key_of_sexp data_of_sexp)
|> of_sorted_list
module Make (Ord : Comparer.S) = struct
module Ord = struct
include Ord
let compare = (comparer :> t -> t -> int)
end
type key = Ord.t
include (Comparer.Apply1 (T) (Ord))
module Provide_equal (Key : sig
type t = Ord.t [@@deriving equal]
end) = struct
let equal equal_data =
equal Key.equal equal_data Ord.equal_compare
end
module Provide_sexp_of (Key : sig
type t = Ord.t [@@deriving sexp_of]
end) = struct
let sexp_of_t sexp_of_data m =
sexp_of_t Key.sexp_of_t sexp_of_data Ord.sexp_of_compare m
end
module Provide_of_sexp (Key : sig
type t = Ord.t [@@deriving of_sexp]
end) = struct
let t_of_sexp data_of_sexp s =
t_of_sexp Key.t_of_sexp data_of_sexp Ord.compare_of_sexp s
end
let singleton x d = Node{l=Empty; v=x; d; r=Empty; h=1} let singleton x d = Node{l=Empty; v=x; d; r=Empty; h=1}
let bal l x d r = let bal l x d r =
@ -492,6 +585,8 @@ module Make(Ord: OrderedType) = struct
Empty -> 0 Empty -> 0
| Node {l; r} -> cardinal l + 1 + cardinal r | Node {l; r} -> cardinal l + 1 + cardinal r
let bindings = bindings
let choose = min_binding let choose = min_binding
let choose_opt = min_binding_opt let choose_opt = min_binding_opt

@ -67,6 +67,8 @@ module type S =
type (+'a) t type (+'a) t
(** The type of maps from type [key] to type ['a]. *) (** The type of maps from type [key] to type ['a]. *)
include Comparer.S1 with type 'a t := 'a t
val empty: 'a t val empty: 'a t
(** The empty map. *) (** The empty map. *)
@ -141,11 +143,15 @@ module type S =
(** Total ordering between maps. The first argument is a total ordering (** Total ordering between maps. The first argument is a total ordering
used to compare data associated with equal keys in the two maps. *) used to compare data associated with equal keys in the two maps. *)
module Provide_equal (_ : sig
type t = key [@@deriving equal]
end) : sig
val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
(** [equal cmp m1 m2] tests whether the maps [m1] and [m2] are (** [equal cmp m1 m2] tests whether the maps [m1] and [m2] are
equal, that is, contain equal keys and associate them with equal, that is, contain equal keys and associate them with
equal data. [cmp] is the equality predicate used to compare equal data. [cmp] is the equality predicate used to compare
the data associated with the keys. *) the data associated with the keys. *)
end
val iter: (key -> 'a -> unit) -> 'a t -> unit val iter: (key -> 'a -> unit) -> 'a t -> unit
(** [iter f m] applies [f] to all bindings in map [m]. (** [iter f m] applies [f] to all bindings in map [m].
@ -344,9 +350,30 @@ module type S =
val of_seq : (key * 'a) Seq.t -> 'a t val of_seq : (key * 'a) Seq.t -> 'a t
(** Build a map from the given bindings (** Build a map from the given bindings
@since 4.07 *) @since 4.07 *)
module Provide_sexp_of (_ : sig
type t = key [@@deriving sexp_of]
end) : sig
type 'a t [@@deriving sexp_of]
end
with type 'a t := 'a t
module Provide_of_sexp (_ : sig
type t = key [@@deriving of_sexp]
end) : sig
type 'a t [@@deriving of_sexp]
end
with type 'a t := 'a t
end end
(** Output signature of the functor {!Map.Make}. *) (** Output signature of the functor {!Map.Make}. *)
module Make (Ord : OrderedType) : S with type key = Ord.t type ('key, +'a, 'compare_key) t [@@deriving compare, equal, sexp]
type ('compare_key, 'compare_a) compare [@@deriving compare, equal, sexp]
module Make (Ord : Comparer.S) :
S with type key = Ord.t
with type +'a t = (Ord.t, 'a, Ord.compare) t
with type 'compare_a compare = (Ord.compare, 'compare_a) compare
(** Functor building an implementation of the map structure (** Functor building an implementation of the map structure
given a totally ordered type. *) given a totally ordered type. *)

@ -25,6 +25,8 @@ module type S =
sig sig
type elt type elt
type t type t
include Comparer.S with type t := t
val empty: t val empty: t
val is_empty: t -> bool val is_empty: t -> bool
val mem: elt -> t -> bool val mem: elt -> t -> bool
@ -36,7 +38,14 @@ module type S =
val disjoint: t -> t -> bool val disjoint: t -> t -> bool
val diff: t -> t -> t val diff: t -> t -> t
val compare: t -> t -> int val compare: t -> t -> int
val equal: t -> t -> bool
module Provide_equal (_ : sig
type t = elt [@@deriving equal]
end) : sig
type t [@@deriving equal]
end
with type t := t
val subset: t -> t -> bool val subset: t -> t -> bool
val iter: (elt -> unit) -> t -> unit val iter: (elt -> unit) -> t -> unit
val map: (elt -> elt) -> t -> t val map: (elt -> elt) -> t -> t
@ -66,62 +75,96 @@ module type S =
val to_seq : t -> elt Seq.t val to_seq : t -> elt Seq.t
val add_seq : elt Seq.t -> t -> t val add_seq : elt Seq.t -> t -> t
val of_seq : elt Seq.t -> t val of_seq : elt Seq.t -> t
module Provide_sexp_of (_ : sig
type t = elt [@@deriving sexp_of]
end) : sig
type t [@@deriving sexp_of]
end end
with type t := t
module Make(Ord: OrderedType) = module Provide_of_sexp (_ : sig
struct type t = elt [@@deriving of_sexp]
type elt = Ord.t end) : sig
type t = Empty | Node of {l:t; v:elt; r:t; h:int} type t [@@deriving of_sexp]
end
with type t := t
end
module T = struct
type ('elt, 'cmp) t =
| Empty
| Node of {l: ('elt, 'cmp) t; v: 'elt; r: ('elt, 'cmp) t; h: int}
(* Sets are represented by balanced binary trees (the heights of the (* Sets are represented by balanced binary trees (the heights of the
children differ by at most 2 *) children differ by at most 2 *)
type enumeration = End | More of elt * t * enumeration type ('elt, 'cmp) enumeration =
| End
| More of 'elt * ('elt, 'cmp) t * ('elt, 'cmp) enumeration
let rec cons_enum s e = let rec cons_enum s e =
match s with match s with
Empty -> e Empty -> e
| Node{l; v; r} -> cons_enum l (More(v, r, e)) | Node{l; v; r} -> cons_enum l (More(v, r, e))
let compare compare_elt _ s1 s2 =
let rec compare_aux e1 e2 = let rec compare_aux e1 e2 =
match (e1, e2) with match (e1, e2) with
(End, End) -> 0 (End, End) -> 0
| (End, _) -> -1 | (End, _) -> -1
| (_, End) -> 1 | (_, End) -> 1
| (More(v1, r1, e1), More(v2, r2, e2)) -> | (More(v1, r1, e1), More(v2, r2, e2)) ->
let c = Ord.compare v1 v2 in let c = compare_elt v1 v2 in
if c <> 0 if c <> 0
then c then c
else compare_aux (cons_enum r1 e1) (cons_enum r2 e2) else compare_aux (cons_enum r1 e1) (cons_enum r2 e2)
in
let compare s1 s2 =
compare_aux (cons_enum s1 End) (cons_enum s2 End) compare_aux (cons_enum s1 End) (cons_enum s2 End)
let equal s1 s2 = type 'compare_elt compare [@@deriving compare, equal, sexp]
compare s1 s2 = 0 end
let rec elements_aux accu = function include T
let equal equal_elt _ s1 s2 =
let rec equal_aux e1 e2 =
match (e1, e2) with
(End, End) -> true
| (End, _) -> false
| (_, End) -> false
| (More(v1, r1, e1), More(v2, r2, e2)) ->
equal_elt v1 v2 &&
equal_aux (cons_enum r1 e1) (cons_enum r2 e2)
in
equal_aux (cons_enum s1 End) (cons_enum s2 End)
let rec elements_aux accu = function
Empty -> accu Empty -> accu
| Node{l; v; r} -> elements_aux (v :: elements_aux accu r) l | Node{l; v; r} -> elements_aux (v :: elements_aux accu r) l
let elements s = let elements s =
elements_aux [] s elements_aux [] s
let height = function let sexp_of_t sexp_of_elt _ s =
elements s
|> Sexplib.Conv.sexp_of_list sexp_of_elt
let height = function
Empty -> 0 Empty -> 0
| Node {h} -> h | Node {h} -> h
(* Creates a new node with left son l, value v and right son r. (* Creates a new node with left son l, value v and right son r.
We must have all elements of l < v < all elements of r. We must have all elements of l < v < all elements of r.
l and r must be balanced and | height l - height r | <= 2. l and r must be balanced and | height l - height r | <= 2.
Inline expansion of height for better speed. *) Inline expansion of height for better speed. *)
let create l v r = let create l v r =
let hl = match l with Empty -> 0 | Node {h} -> h in let hl = match l with Empty -> 0 | Node {h} -> h in
let hr = match r with Empty -> 0 | Node {h} -> h in let hr = match r with Empty -> 0 | Node {h} -> h in
Node{l; v; r; h=(if hl >= hr then hl + 1 else hr + 1)} Node{l; v; r; h=(if hl >= hr then hl + 1 else hr + 1)}
let of_sorted_list l = let of_sorted_list l =
let rec sub n l = let rec sub n l =
match n, l with match n, l with
| 0, l -> Empty, l | 0, l -> Empty, l
@ -142,6 +185,41 @@ module Make(Ord: OrderedType) =
in in
fst (sub (List.length l) l) fst (sub (List.length l) l)
let t_of_sexp elt_of_sexp _ s =
Sexplib.Conv.list_of_sexp elt_of_sexp s
|> of_sorted_list
module Make(Ord: Comparer.S) =
struct
module Ord = struct
include Ord
let compare = (comparer :> t -> t -> int)
end
type elt = Ord.t
include (Comparer.Apply (T) (Ord))
module Provide_equal (Elt : sig
type t = Ord.t [@@deriving equal]
end) = struct
let equal l r = equal Elt.equal Ord.equal_compare l r
end
module Provide_sexp_of (Elt : sig
type t = Ord.t [@@deriving sexp_of]
end) = struct
let sexp_of_t s =
sexp_of_t Elt.sexp_of_t Ord.sexp_of_compare s
end
module Provide_of_sexp (Elt : sig
type t = Ord.t [@@deriving of_sexp]
end) = struct
let t_of_sexp s =
t_of_sexp Elt.t_of_sexp Ord.compare_of_sexp s
end
(* Same as create, but performs one step of rebalancing if necessary. (* Same as create, but performs one step of rebalancing if necessary.
Assumes l and r balanced and | height l - height r | <= 3. Assumes l and r balanced and | height l - height r | <= 3.
Inline expansion of create for better speed in the most frequent case Inline expansion of create for better speed in the most frequent case
@ -443,6 +521,8 @@ module Make(Ord: OrderedType) =
Empty -> 0 Empty -> 0
| Node{l; r} -> cardinal l + 1 + cardinal r | Node{l; r} -> cardinal l + 1 + cardinal r
let elements = elements
let choose = min_elt let choose = min_elt
let choose_opt = min_elt_opt let choose_opt = min_elt_opt

@ -68,6 +68,8 @@ module type S =
type t type t
(** The type of sets. *) (** The type of sets. *)
include Comparer.S with type t := t
val empty: t val empty: t
(** The empty set. *) (** The empty set. *)
@ -110,9 +112,13 @@ module type S =
(** Total ordering between sets. Can be used as the ordering function (** Total ordering between sets. Can be used as the ordering function
for doing sets of sets. *) for doing sets of sets. *)
module Provide_equal (_ : sig
type t = elt [@@deriving equal]
end) : sig
val equal: t -> t -> bool val equal: t -> t -> bool
(** [equal s1 s2] tests whether the sets [s1] and [s2] are (** [equal s1 s2] tests whether the sets [s1] and [s2] are
equal, that is, contain equal elements. *) equal, that is, contain equal elements. *)
end
val subset: t -> t -> bool val subset: t -> t -> bool
(** [subset s1 s2] tests whether the set [s1] is a subset of (** [subset s1 s2] tests whether the set [s1] is a subset of
@ -298,9 +304,30 @@ module type S =
val of_seq : elt Seq.t -> t val of_seq : elt Seq.t -> t
(** Build a set from the given bindings (** Build a set from the given bindings
@since 4.07 *) @since 4.07 *)
module Provide_sexp_of (_ : sig
type t = elt [@@deriving sexp_of]
end) : sig
type t [@@deriving sexp_of]
end
with type t := t
module Provide_of_sexp (_ : sig
type t = elt [@@deriving of_sexp]
end) : sig
type t [@@deriving of_sexp]
end
with type t := t
end end
(** Output signature of the functor {!Set.Make}. *) (** Output signature of the functor {!Set.Make}. *)
module Make (Ord : OrderedType) : S with type elt = Ord.t type ('elt, 'cmp) t [@@deriving compare, equal, sexp]
type 'compare_elt compare [@@deriving compare, equal, sexp]
module Make (Ord : Comparer.S) :
S with type elt = Ord.t
with type t = (Ord.t, Ord.compare) t
with type compare = Ord.compare compare
(** Functor building an implementation of the set structure (** Functor building an implementation of the set structure
given a totally ordered type. *) given a totally ordered type. *)

Loading…
Cancel
Save