[sledge] Generalize Multiset over type of multiplicities

Reviewed By: jvillard

Differential Revision: D24306070

fbshipit-source-id: f4df5aafa
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent bd49ad84a8
commit df35f9702a

@ -5,61 +5,58 @@
* LICENSE file in the root directory of this source tree.
*)
(** Multiset - Set with (signed) rational multiplicity for each element *)
(** Multiset - Set with multiplicity for each element *)
open NS0
include Multiset_intf
module Make (Elt : sig
type t [@@deriving compare, sexp_of]
end) =
module Make
(Mul : MULTIPLICITY) (Elt : sig
type t [@@deriving compare, sexp_of]
end) =
struct
module M = Map.Make (Elt)
type mul = Mul.t
type elt = Elt.t
type t = Q.t M.t
type t = Mul.t M.t
let compare = M.compare Q.compare
let equal = M.equal Q.equal
let compare = M.compare Mul.compare
let equal = M.equal Mul.equal
let hash_fold_t hash_fold_elt s m =
let hash_fold_q s q = Hash.fold_int s (Hashtbl.hash q) in
let hash_fold_mul s i = Hash.fold_int s (Mul.hash i) in
M.fold m
~init:(Hash.fold_int s (M.length m))
~f:(fun ~key ~data state -> hash_fold_q (hash_fold_elt state key) data)
~f:(fun ~key ~data state ->
hash_fold_mul (hash_fold_elt state key) data )
let sexp_of_t s =
let sexp_of_q q = Sexp.Atom (Q.to_string q) in
List.sexp_of_t
(Sexplib.Conv.sexp_of_pair Elt.sexp_of_t sexp_of_q)
(Sexplib.Conv.sexp_of_pair Elt.sexp_of_t Mul.sexp_of_t)
(M.to_alist s)
let t_of_sexp elt_of_sexp sexp =
let q_of_sexp = function
| Sexp.Atom s -> Q.of_string s
| _ -> assert false
in
List.fold_left
~f:(fun m (key, data) -> M.add_exn m ~key ~data)
~init:M.empty
(List.t_of_sexp
(Sexplib.Conv.pair_of_sexp elt_of_sexp q_of_sexp)
(Sexplib.Conv.pair_of_sexp elt_of_sexp Mul.t_of_sexp)
sexp)
let pp sep pp_elt fs s = List.pp sep pp_elt fs (M.to_alist s)
let empty = M.empty
let of_ = M.singleton
let if_nz q = if Q.equal Q.zero q then None else Some q
let if_nz q = if Mul.equal Mul.zero q then None else Some q
let add m x i =
M.change m x ~f:(function Some j -> if_nz Q.(i + j) | None -> if_nz i)
M.change m x ~f:(function
| Some j -> if_nz (Mul.add i j)
| None -> if_nz i )
let remove m x = M.remove m x
let find_and_remove = M.find_and_remove
let union m n =
M.merge m n ~f:(fun ~key:_ -> function
| `Both (i, j) -> if_nz Q.(i + j) | `Left i | `Right i -> Some i )
let union m n = M.union m n ~f:(fun _ i j -> if_nz (Mul.add i j))
let map m ~f =
let m' = empty in
@ -67,7 +64,7 @@ struct
M.fold m ~init:(m, m') ~f:(fun ~key:x ~data:i (m, m') ->
let x', i' = f x i in
if x' == x then
if Q.equal i' i then (m, m') else (M.set m ~key:x ~data:i', m')
if Mul.equal i' i then (m, m') else (M.set m ~key:x ~data:i', m')
else (M.remove m x, add m' x' i') )
in
M.fold m' ~init:m ~f:(fun ~key:x ~data:i m -> add m x i)
@ -76,7 +73,7 @@ struct
let is_empty = M.is_empty
let is_singleton = M.is_singleton
let length m = M.length m
let count m x = match M.find m x with Some q -> q | None -> Q.zero
let count m x = match M.find m x with Some q -> q | None -> Mul.zero
let choose = M.choose
let choose_exn = M.choose_exn
let pop = M.pop

@ -9,6 +9,7 @@
include module type of Multiset_intf
module Make (Elt : sig
type t [@@deriving compare, sexp_of]
end) : S with type elt = Elt.t
module Make
(Mul : MULTIPLICITY) (Elt : sig
type t [@@deriving compare, sexp_of]
end) : S with type mul = Mul.t with type elt = Elt.t

@ -9,7 +9,17 @@
open NS0
module type MULTIPLICITY = sig
type t [@@deriving compare, equal, hash, sexp]
val zero : t
val add : t -> t -> t
val sub : t -> t -> t
val neg : t -> t
end
module type S = sig
type mul
type elt
type t
@ -18,16 +28,16 @@ module type S = sig
val hash_fold_t : elt Hash.folder -> t Hash.folder
val sexp_of_t : t -> Sexp.t
val t_of_sexp : (Sexp.t -> elt) -> Sexp.t -> t
val pp : (unit, unit) fmt -> (elt * Q.t) pp -> t pp
val pp : (unit, unit) fmt -> (elt * mul) pp -> t pp
(* constructors *)
val empty : t
(** The empty multiset over the provided order. *)
val of_ : elt -> Q.t -> t
val of_ : elt -> mul -> t
val add : t -> elt -> Q.t -> t
val add : t -> elt -> mul -> t
(** Add to multiplicity of single element. [O(log n)] *)
val remove : t -> elt -> t
@ -36,11 +46,11 @@ module type S = sig
val union : t -> t -> t
(** Sum multiplicities pointwise. [O(n + m)] *)
val map : t -> f:(elt -> Q.t -> elt * Q.t) -> t
val map : t -> f:(elt -> mul -> elt * mul) -> t
(** Map over the elements in ascending order. Preserves physical equality
if [f] does. *)
val map_counts : t -> f:(elt -> Q.t -> Q.t) -> t
val map_counts : t -> f:(elt -> mul -> mul) -> t
(** Map over the multiplicities of the elements in ascending order. *)
(* queries *)
@ -51,44 +61,44 @@ module type S = sig
val length : t -> int
(** Number of elements with non-zero multiplicity. [O(1)]. *)
val count : t -> elt -> Q.t
val count : t -> elt -> mul
(** Multiplicity of an element. [O(log n)]. *)
val choose_exn : t -> elt * Q.t
val choose_exn : t -> elt * mul
(** Find an unspecified element. [O(1)]. *)
val choose : t -> (elt * Q.t) option
val choose : t -> (elt * mul) option
(** Find an unspecified element. [O(1)]. *)
val pop : t -> (elt * Q.t * t) option
val pop : t -> (elt * mul * t) option
(** Find and remove an unspecified element. [O(1)]. *)
val min_elt : t -> (elt * Q.t) option
val min_elt : t -> (elt * mul) option
(** Minimum element. [O(log n)]. *)
val pop_min_elt : t -> (elt * Q.t * t) option
val pop_min_elt : t -> (elt * mul * t) option
(** Find and remove minimum element. [O(log n)]. *)
val classify : t -> [`Zero | `One of elt * Q.t | `Many]
val classify : t -> [`Zero | `One of elt * mul | `Many]
(** Classify a set as either empty, singleton, or otherwise. *)
val find_and_remove : t -> elt -> (Q.t * t) option
val find_and_remove : t -> elt -> (mul * t) option
(** Find and remove an element. *)
val to_list : t -> (elt * Q.t) list
val to_list : t -> (elt * mul) list
(** Convert to a list of elements in ascending order. *)
(* traversals *)
val iter : t -> f:(elt -> Q.t -> unit) -> unit
val iter : t -> f:(elt -> mul -> unit) -> unit
(** Iterate over the elements in ascending order. *)
val exists : t -> f:(elt -> Q.t -> bool) -> bool
val exists : t -> f:(elt -> mul -> bool) -> bool
(** Search for an element satisfying a predicate. *)
val for_all : t -> f:(elt -> Q.t -> bool) -> bool
val for_all : t -> f:(elt -> mul -> bool) -> bool
(** Test whether all elements satisfy a predicate. *)
val fold : t -> f:(elt -> Q.t -> 's -> 's) -> init:'s -> 's
val fold : t -> f:(elt -> mul -> 's -> 's) -> init:'s -> 's
(** Fold over the elements in ascending order. *)
end

@ -44,11 +44,11 @@ end = struct
end
and Qset : sig
include NS.Multiset.S with type elt := T.t
include NS.Multiset.S with type mul := Q.t with type elt := T.t
val t_of_sexp : Sexp.t -> t
end = struct
include NS.Multiset.Make (T)
include NS.Multiset.Make (Q) (T)
let t_of_sexp = t_of_sexp T.t_of_sexp
end

@ -57,7 +57,7 @@ module rec Set : sig
end
and Qset : sig
include NS.Multiset.S with type elt := T.t
include NS.Multiset.S with type mul := Q.t with type elt := T.t
val t_of_sexp : Sexp.t -> t
end

Loading…
Cancel
Save