diff --git a/sledge/bin/sledge.ml b/sledge/bin/sledge.ml index b652bceaa..e935cefa6 100644 --- a/sledge/bin/sledge.ml +++ b/sledge/bin/sledge.ml @@ -81,7 +81,7 @@ let used_globals pgm preanalyze : Domain_used_globals.r = ; globals= Declared Reg.Set.empty } pgm in - Per_function (Map.map summary_table ~f:Reg.Set.union_list) + Per_function (Reg.Map.map summary_table ~f:Reg.Set.union_list) else Declared (Vector.fold pgm.globals ~init:Reg.Set.empty ~f:(fun acc g -> diff --git a/sledge/lib/control.ml b/sledge/lib/control.ml index 32ffb4c43..01782ea51 100644 --- a/sledge/lib/control.ml +++ b/sledge/lib/control.ml @@ -168,7 +168,6 @@ module Make (Dom : Domain_intf.Dom) = struct end include T - include Comparator.Make (T) let pp fs {dst; src} = Format.fprintf fs "#%i %%%s <--%a" dst.sort_index dst.lbl @@ -178,27 +177,27 @@ module Make (Dom : Domain_intf.Dom) = struct end module Depths = struct - type t = int Map.M(Edge).t + module M = Map.Make (Edge) - let empty = Map.empty (module Edge) - let find = Map.find - let set = Map.set + type t = int M.t + + let empty = M.empty + let find = M.find + let set = M.set let join x y = - Map.merge x y ~f:(fun ~key:_ -> function + M.merge x y ~f:(fun ~key:_ -> function | `Left d | `Right d -> Some d | `Both (d1, d2) -> Some (Int.max d1 d2) ) end type priority = int * Edge.t [@@deriving compare] type priority_queue = priority Fheap.t - type waiting_states = (Dom.t * Depths.t) list Map.M(Llair.Block).t + type waiting_states = (Dom.t * Depths.t) list Llair.Block.Map.t type t = priority_queue * waiting_states * int type x = Depths.t -> t -> t - let empty_waiting_states : waiting_states = - Map.empty (module Llair.Block) - + let empty_waiting_states : waiting_states = Llair.Block.Map.empty let pp_priority fs (n, e) = Format.fprintf fs "%i: %a" n Edge.pp e let pp fs pq = @@ -221,7 +220,9 @@ module Make (Dom : Domain_intf.Dom) = struct let pq = Fheap.add pq (depth, edge) in [%Trace.info "@[<6>enqueue %i: %a@ | %a@]" depth Edge.pp edge pp pq] ; let depths = Depths.set depths ~key:edge ~data:depth in - let ws = Map.add_multi ws ~key:curr ~data:(state, depths) in + let ws = + Llair.Block.Map.add_multi ws ~key:curr ~data:(state, depths) + in (pq, ws, bound) let init state curr bound = @@ -231,7 +232,7 @@ module Make (Dom : Domain_intf.Dom) = struct let rec run ~f (pq0, ws, bnd) = match Fheap.pop pq0 with | Some ((_, ({Edge.dst; stk} as edge)), pq) -> ( - match Map.find_and_remove ws dst with + match Llair.Block.Map.find_and_remove ws dst with | Some (q :: qs, ws) -> let join (qa, da) (q, d) = (Dom.join q qa, Depths.join d da) in let skipped, (qs, depths) = @@ -240,7 +241,7 @@ module Make (Dom : Domain_intf.Dom) = struct | Some joined, depths -> (skipped, (joined, depths)) | None, _ -> (curr :: skipped, joined) ) in - let ws = Map.add_exn ws ~key:dst ~data:skipped in + let ws = Llair.Block.Map.add_exn ws ~key:dst ~data:skipped in run ~f (f stk qs dst depths (pq, ws, bnd)) | _ -> [%Trace.info "done: %a" Edge.pp edge] ; @@ -489,5 +490,5 @@ module Make (Dom : Domain_intf.Dom) = struct assert opts.function_summaries ; exec_pgm opts pgm ; Hashtbl.fold summary_table ~init:Reg.Map.empty ~f:(fun ~key ~data map -> - match data with [] -> map | _ -> Map.set map ~key ~data ) + match data with [] -> map | _ -> Reg.Map.set map ~key ~data ) end diff --git a/sledge/lib/domain_relation.ml b/sledge/lib/domain_relation.ml index 8e11a20c3..a582a663b 100644 --- a/sledge/lib/domain_relation.ml +++ b/sledge/lib/domain_relation.ml @@ -20,7 +20,7 @@ module type State_domain_sig = sig end module Make (State_domain : State_domain_sig) = struct - type t = State_domain.t * State_domain.t [@@deriving sexp_of, equal] + type t = State_domain.t * State_domain.t [@@deriving equal, sexp_of] let embed b = (b, b) diff --git a/sledge/lib/domain_sh.ml b/sledge/lib/domain_sh.ml index f6a6f57b0..0ea8f28a5 100644 --- a/sledge/lib/domain_sh.ml +++ b/sledge/lib/domain_sh.ml @@ -7,7 +7,7 @@ (** Abstract domain *) -type t = Sh.t [@@deriving equal, sexp_of] +type t = Sh.t [@@deriving equal, sexp] let pp fs q = Format.fprintf fs "@[{ %a@ }@]" Sh.pp q let report_fmt_thunk = Fn.flip pp diff --git a/sledge/lib/domain_unit.ml b/sledge/lib/domain_unit.ml index 9edfc9eaf..236449b92 100644 --- a/sledge/lib/domain_unit.ml +++ b/sledge/lib/domain_unit.ml @@ -7,7 +7,7 @@ (** "Unit" abstract domain *) -type t = unit [@@deriving equal, sexp_of] +type t = unit [@@deriving equal, sexp] let pp fs () = Format.pp_print_string fs "()" let report_fmt_thunk () fs = pp fs () diff --git a/sledge/lib/domain_used_globals.ml b/sledge/lib/domain_used_globals.ml index 783babee7..106f7b839 100644 --- a/sledge/lib/domain_used_globals.ml +++ b/sledge/lib/domain_used_globals.ml @@ -7,7 +7,7 @@ (** Used-globals abstract domain *) -type t = Reg.Set.t [@@deriving equal, sexp_of] +type t = Reg.Set.t [@@deriving equal, sexp] let pp = Set.pp Reg.pp let report_fmt_thunk = Fn.flip pp @@ -60,7 +60,7 @@ let exec_intrinsic ~skip_throw:_ st _ intrinsic actuals = |> fun res -> Some (Some res) else None -type from_call = t [@@deriving sexp_of] +type from_call = t [@@deriving sexp] (* Set abstract state to bottom (i.e. empty set) at function entry *) let call ~summaries:_ ~globals:_ ~actuals ~areturn:_ ~formals:_ ~freturn:_ @@ -92,7 +92,7 @@ let by_function : r -> Reg.t -> t = ( match s with | Declared set -> set | Per_function map -> ( - match Map.find map fn with + match Reg.Map.find map fn with | Some gs -> gs | None -> fail diff --git a/sledge/lib/equality.ml b/sledge/lib/equality.ml index 38de2f3cb..d34f9b2e5 100644 --- a/sledge/lib/equality.ml +++ b/sledge/lib/equality.ml @@ -56,25 +56,26 @@ module Subst : sig val to_alist : t -> (Term.t * Term.t) list val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t end = struct - type t = Term.t Term.Map.t [@@deriving compare, equal, sexp] + type t = Term.t Term.Map.t [@@deriving compare, equal, sexp_of] - let pp = Map.pp Term.pp Term.pp + let t_of_sexp = Term.Map.t_of_sexp Term.t_of_sexp Term.t_of_sexp + let pp = Term.Map.pp Term.pp Term.pp let pp_diff = - Map.pp_diff ~data_equal:Term.equal Term.pp Term.pp Term.pp_diff + Term.Map.pp_diff ~data_equal:Term.equal Term.pp Term.pp Term.pp_diff let empty = Term.Map.empty - let is_empty = Map.is_empty - let length = Map.length - let mem = Map.mem - let find = Map.find - let fold = Map.fold - let iteri = Map.iteri - let for_alli = Map.for_alli - let to_alist = Map.to_alist ~key_order:`Increasing + let is_empty = Term.Map.is_empty + let length = Term.Map.length + let mem = Term.Map.mem + let find = Term.Map.find + let fold = Term.Map.fold + let iteri = Term.Map.iteri + let for_alli = Term.Map.for_alli + let to_alist = Term.Map.to_alist (** look up a term in a substitution *) - let apply s a = Map.find s a |> Option.value ~default:a + let apply s a = Term.Map.find s a |> Option.value ~default:a let rec subst s a = apply s (Term.map ~f:(subst s) a) @@ -87,21 +88,21 @@ end = struct (** compose two substitutions *) let compose r s = - let r' = Map.map ~f:(norm s) r in - Map.merge_skewed r' s ~combine:(fun ~key v1 v2 -> + let r' = Term.Map.map ~f:(norm s) r in + Term.Map.merge_skewed r' s ~combine:(fun ~key v1 v2 -> if Term.equal v1 v2 then v1 else fail "domains intersect: %a" Term.pp key () ) (** compose a substitution with a mapping *) let compose1 ~key ~data s = if Term.equal key data then s - else compose s (Map.set Term.Map.empty ~key ~data) + else compose s (Term.Map.set Term.Map.empty ~key ~data) (** add an identity entry if the term is not already present *) let extend e s = let exception Found in match - Map.update s e ~f:(function + Term.Map.update s e ~f:(function | Some _ -> Exn.raise_without_backtrace Found | None -> e ) with @@ -112,12 +113,14 @@ end = struct [f] is injective and for any set of terms [E], [f\[E\]] is disjoint from [E] *) let map_entries ~f s = - Map.fold s ~init:s ~f:(fun ~key ~data s -> + Term.Map.fold s ~init:s ~f:(fun ~key ~data s -> let key' = f key in let data' = f data in if Term.equal key' key then - if Term.equal data' data then s else Map.set s ~key ~data:data' - else Map.remove s key |> Map.add_exn ~key:key' ~data:data' ) + if Term.equal data' data then s + else Term.Map.set s ~key ~data:data' + else Term.Map.remove s key |> Term.Map.add_exn ~key:key' ~data:data' + ) (** Holds only if [true ⊢ ∃xs. e=f]. Clients assume [not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for @@ -141,12 +144,12 @@ end = struct valid, so loop until no change. *) let rec partition_valid_ t ks s = let t', ks', s' = - Map.fold s ~init:(t, ks, s) ~f:(fun ~key ~data (t, ks, s) -> + Term.Map.fold s ~init:(t, ks, s) ~f:(fun ~key ~data (t, ks, s) -> if is_valid_eq ks key data then (t, ks, s) else - let t = Map.set ~key ~data t + let t = Term.Map.set ~key ~data t and ks = Set.diff ks (Set.union (Term.fv key) (Term.fv data)) - and s = Map.remove s key in + and s = Term.Map.remove s key in (t, ks, s) ) in if s' != s then partition_valid_ t' ks' s' else (t', ks', s') @@ -327,7 +330,7 @@ type t = let classes r = let add key data cls = if Term.equal key data then cls - else Map.add_multi cls ~key:data ~data:key + else Term.Map.add_multi cls ~key:data ~data:key in Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls -> match classify key with @@ -337,7 +340,7 @@ let classes r = let cls_of r e = let e' = Subst.apply r.rep e in - Map.find (classes r) e' |> Option.value ~default:[e'] + Term.Map.find (classes r) e' |> Option.value ~default:[e'] (** Pretty-printing *) @@ -373,12 +376,13 @@ let ppx_clss x fs cs = (fun fs (key, data) -> Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) key (ppx_cls x) (List.sort ~compare:Term.compare data) ) - fs (Map.to_alist cs) + fs (Term.Map.to_alist cs) let pp_clss fs cs = ppx_clss (fun _ -> None) fs cs let pp_diff_clss = - Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls pp_diff_cls + Term.Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls + pp_diff_cls (** Invariant *) @@ -525,7 +529,7 @@ let normalize = canon let class_of r e = let e' = normalize r e in - e' :: Map.find_multi (classes r) e' + e' :: Term.Map.find_multi (classes r) e' let fold_uses_of r t ~init ~f = let rec fold_ e ~init:s ~f = @@ -558,7 +562,7 @@ let difference r a b = let apply_subst us s r = [%Trace.call fun {pf} -> pf "%a@ %a" Subst.pp s pp r] ; - Map.fold (classes r) ~init:true_ ~f:(fun ~key:rep ~data:cls r -> + Term.Map.fold (classes r) ~init:true_ ~f:(fun ~key:rep ~data:cls r -> let rep' = Subst.subst s rep in List.fold cls ~init:r ~f:(fun r trm -> let trm' = Subst.subst s trm in @@ -585,7 +589,7 @@ let or_ us r s = else if not r.sat then s else let merge_mems rs r s = - Map.fold (classes s) ~init:rs ~f:(fun ~key:rep ~data:cls rs -> + Term.Map.fold (classes s) ~init:rs ~f:(fun ~key:rep ~data:cls rs -> List.fold cls ~init:([rep], rs) ~f:(fun (reps, rs) exp -> @@ -651,7 +655,7 @@ let ppx_classes x fs r = ppx_clss x fs (classes r) let ppx_classes_diff x fs (r, s) = let clss = classes s in let clss = - Map.filter_mapi clss ~f:(fun ~key:rep ~data:cls -> + Term.Map.filter_mapi clss ~f:(fun ~key:rep ~data:cls -> match List.filter cls ~f:(fun exp -> not (entails_eq r rep exp)) with @@ -663,7 +667,7 @@ let ppx_classes_diff x fs (r, s) = Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) rep (List.pp "@ = " (Term.ppx x)) (List.dedup_and_sort ~compare:Term.compare cls) ) - fs (Map.to_alist clss) + fs (Term.Map.to_alist clss) (** Existential Witnessing and Elimination *) @@ -876,8 +880,8 @@ let solve_class us us_xs ~key:rep ~data:cls (classes, subst) = |> Option.value ~default:cls in let classes = - if List.is_empty cls then Map.remove classes rep - else Map.set classes ~key:rep ~data:cls + if List.is_empty cls then Term.Map.remove classes rep + else Term.Map.set classes ~key:rep ~data:cls in (classes, subst) |> @@ -954,7 +958,8 @@ let solve_classes r (classes, subst, us) xs = ; let rec solve_classes_ (classes0, subst0, us_xs) = let classes, subst = - Map.fold ~f:(solve_class us us_xs) classes0 ~init:(classes0, subst0) + Term.Map.fold ~f:(solve_class us us_xs) classes0 + ~init:(classes0, subst0) in if subst != subst0 then solve_classes_ (classes, subst, us_xs) else (classes, subst, us_xs) diff --git a/sledge/lib/exp.ml b/sledge/lib/exp.ml index 767b8a449..b5ea865da 100644 --- a/sledge/lib/exp.ml +++ b/sledge/lib/exp.ml @@ -84,6 +84,7 @@ module T = struct end include T +module Map = Map.Make (T) let term e = e.term @@ -328,16 +329,7 @@ module Reg = struct let vars = Set.fold ~init:Var.Set.empty ~f:(fun s r -> add s (var r)) end - module Map = struct - include ( - Map : - module type of Map - with type ('key, 'value, 'cmp) t := ('key, 'value, 'cmp) Map.t ) - - type 'v t = 'v Map.M(T).t [@@deriving compare, equal, sexp] - - let empty = Map.empty (module T) - end + module Map = Map let demangle = ref (fun _ -> None) diff --git a/sledge/lib/exp.mli b/sledge/lib/exp.mli index 5adba08a3..d0012f347 100644 --- a/sledge/lib/exp.mli +++ b/sledge/lib/exp.mli @@ -116,14 +116,7 @@ module Reg : sig val vars : t -> Var.Set.t end - module Map : sig - type reg := t - - type 'a t = (reg, 'a, comparator_witness) Map.t - [@@deriving compare, equal, sexp] - - val empty : 'a t - end + module Map : Map.S with type key := t val demangle : (string -> string option) ref val pp : t pp diff --git a/sledge/lib/import/import.ml b/sledge/lib/import/import.ml index c27f11c01..ac0304eea 100644 --- a/sledge/lib/import/import.ml +++ b/sledge/lib/import/import.ml @@ -253,52 +253,201 @@ module List = struct pp sep pp_diff_elt fs (symmetric_diff ~compare xs ys) end +module type OrderedType = sig + type t + + val compare : t -> t -> int + val sexp_of_t : t -> Sexp.t +end + +exception Duplicate + module Map = struct - include Base.Map - - let pp pp_k pp_v fs m = - Format.fprintf fs "@[<1>[%a]@]" - (List.pp ",@ " (fun fs (k, v) -> - Format.fprintf fs "@[%a @<2>↦ %a@]" pp_k k pp_v v )) - (to_alist m) - - let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) = - let pp_diff_elt fs = function - | k, `Left v -> - Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v - | k, `Right v -> - Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v - | k, `Unequal vv -> - Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_diff_val vv - in - let sd = Sequence.to_list (symmetric_diff ~data_equal x y) in - if not (List.is_empty sd) then - Format.fprintf fs "[@[%a@]];@ " (List.pp ";@ " pp_diff_elt) sd - - let equal_m__t (module Elt : Compare_m) equal_v = equal equal_v - - let find_and_remove m k = - let found = ref None in - let m = - change m k ~f:(fun v -> - found := v ; - None ) - in - let+ v = !found in - (v, m) - - let find_or_add (type data) map key ~(default : data) ~if_found ~if_added - = - let exception Found of data in - match - update map key ~f:(function - | Some old_data -> Exn.raise_without_backtrace (Found old_data) - | None -> default ) - with - | exception Found old_data -> if_found old_data - | map -> if_added map + module type S = sig + type key + type +'a t + + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + val sexp_of_t : ('a -> Sexp.t) -> 'a t -> Sexp.t + val t_of_sexp : (Sexp.t -> key) -> (Sexp.t -> 'a) -> Sexp.t -> 'a t + val pp : key pp -> 'a pp -> 'a t pp + + val pp_diff : + data_equal:('a -> 'a -> bool) + -> key pp + -> 'a pp + -> ('a * 'a) pp + -> ('a t * 'a t) pp + + (* initial constructors *) + val empty : 'a t + + (* constructors *) + val set : 'a t -> key:key -> data:'a -> 'a t + val add_exn : 'a t -> key:key -> data:'a -> 'a t + val add_multi : 'a list t -> key:key -> data:'a -> 'a list t + val remove : 'a t -> key -> 'a t + val update : 'a t -> key -> f:('a option -> 'a) -> 'a t + + val merge : + 'a t + -> 'b t + -> f: + ( key:key + -> [`Both of 'a * 'b | `Left of 'a | `Right of 'b] + -> 'c option) + -> 'c t + + val merge_skewed : + 'a t -> 'a t -> combine:(key:key -> 'a -> 'a -> 'a) -> 'a t + + val map : 'a t -> f:('a -> 'b) -> 'b t + val filter_keys : 'a t -> f:(key -> bool) -> 'a t + val filter_mapi : 'a t -> f:(key:key -> data:'a -> 'b option) -> 'b t + + (* queries *) + val is_empty : 'b t -> bool + val length : 'b t -> int + val mem : 'a t -> key -> bool + val find : 'a t -> key -> 'a option + val find_and_remove : 'a t -> key -> ('a * 'a t) option + val find_multi : 'a list t -> key -> 'a list + val data : 'a t -> 'a list + val to_alist : 'a t -> (key * 'a) list + + (* traversals *) + val iter : 'a t -> f:('a -> unit) -> unit + val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit + val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool + val fold : 'a t -> init:'s -> f:(key:key -> data:'a -> 's -> 's) -> 's + end - let map_preserving_phys_equal t ~f = map_preserving_phys_equal map t ~f + module Make (Key : OrderedType) : S with type key = Key.t = struct + module M = Caml.Map.Make (Key) + + type key = Key.t + type 'a t = 'a M.t + + let compare = M.compare + let equal = M.equal + + let sexp_of_t sexp_of_val m = + List.sexp_of_t + (Sexplib.Conv.sexp_of_pair Key.sexp_of_t sexp_of_val) + (M.bindings m) + + let t_of_sexp key_of_sexp val_of_sexp sexp = + Caml.List.fold_left + (fun m (k, v) -> M.add k v m) + M.empty + (List.t_of_sexp + (Sexplib.Conv.pair_of_sexp key_of_sexp val_of_sexp) + sexp) + + let pp pp_k pp_v fs m = + Format.fprintf fs "@[<1>[%a]@]" + (List.pp ",@ " (fun fs (k, v) -> + Format.fprintf fs "@[%a @<2>↦ %a@]" pp_k k pp_v v )) + (M.bindings m) + + let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) = + let pp_diff_val fs = function + | k, `Left v -> + Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v + | k, `Right v -> + Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v + | k, `Unequal vv -> + Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_diff_val vv + in + let sd = + M.merge + (fun _ v1o v2o -> + match (v1o, v2o) with + | Some v1, Some v2 when not (data_equal v1 v2) -> + Some (`Unequal (v1, v2)) + | Some v1, None -> Some (`Left v1) + | None, Some v2 -> Some (`Right v2) + | _ -> None ) + x y + in + if not (M.is_empty sd) then + Format.fprintf fs "[@[%a@]];@ " + (List.pp ";@ " pp_diff_val) + (M.bindings sd) + + exception Duplicate + + let empty = M.empty + let set m ~key ~data = M.add key data m + + let add_exn m ~key ~data = + M.update key + (function None -> Some data | Some _ -> raise Duplicate) + m + + let add_multi m ~key ~data = + M.update key + (function None -> Some [data] | Some vs -> Some (data :: vs)) + m + + let remove m k = M.remove k m + let update m k ~f = M.update k (fun vo -> Some (f vo)) m + + let merge m n ~f = + M.merge + (fun k v1o v2o -> + match (v1o, v2o) with + | Some v1, Some v2 -> f ~key:k (`Both (v1, v2)) + | Some v1, None -> f ~key:k (`Left v1) + | None, Some v2 -> f ~key:k (`Right v2) + | None, None -> None ) + m n + + let merge_skewed m n ~combine = + M.merge + (fun k v1o v2o -> + match (v1o, v2o) with + | Some v1, Some v2 -> Some (combine ~key:k v1 v2) + | Some _, None -> v1o + | None, Some _ -> v2o + | None, None -> None ) + m n + + let map m ~f = M.map f m + let filter_keys m ~f = M.filter (fun k _ -> f k) m + + let filter_mapi m ~f = + M.fold + (fun k v m -> + match f ~key:k ~data:v with Some v' -> M.add k v' m | None -> m ) + m M.empty + + let is_empty = M.is_empty + let length = M.cardinal + let mem m k = M.mem k m + let find m k = M.find_opt k m + + let find_and_remove m k = + let found = ref None in + let m = + M.update k + (fun v -> + found := v ; + None ) + m + in + let+ v = !found in + (v, m) + + let find_multi m k = try M.find k m with Not_found -> [] + let data m = M.fold (fun _ v s -> v :: s) m [] + let to_alist = M.bindings + let iter m ~f = M.iter (fun _ v -> f v) m + let iteri m ~f = M.iter (fun k v -> f ~key:k ~data:v) m + let for_alli m ~f = M.for_all (fun key data -> f ~key ~data) m + let fold m ~init ~f = M.fold (fun key data s -> f ~key ~data s) m init + end end module Result = struct @@ -358,6 +507,15 @@ module Array = struct let pp sep pp_elt fs a = List.pp sep pp_elt fs (to_list a) end +module String = struct + include String + + let t_of_sexp = Sexplib.Conv.string_of_sexp + let sexp_of_t = Sexplib.Conv.sexp_of_string + + module Map = Map.Make (String) +end + module Q = struct let pp = Q.pp_print let hash = Hashtbl.hash diff --git a/sledge/lib/import/import.mli b/sledge/lib/import/import.mli index 05edbba6f..e9a2fcf53 100644 --- a/sledge/lib/import/import.mli +++ b/sledge/lib/import/import.mli @@ -192,39 +192,77 @@ module List : sig compare:('a -> 'a -> int) -> 'a t -> 'a t -> ('a, 'a) Either.t t end -module Map : sig - include module type of Base.Map +module type OrderedType = sig + type t - val pp : 'k pp -> 'v pp -> ('k, 'v, 'c) t pp + val compare : t -> t -> int + val sexp_of_t : t -> Sexp.t +end - val pp_diff : - data_equal:('v -> 'v -> bool) - -> 'k pp - -> 'v pp - -> ('v * 'v) pp - -> (('k, 'v, 'c) t * ('k, 'v, 'c) t) pp +exception Duplicate - val equal_m__t : - (module Compare_m) - -> ('v -> 'v -> bool) - -> ('k, 'v, 'c) t - -> ('k, 'v, 'c) t - -> bool - - val find_and_remove : ('k, 'v, 'c) t -> 'k -> ('v * ('k, 'v, 'c) t) option - - val find_or_add : - ('k, 'v, 'c) t - -> 'k - -> default:'v - -> if_found:('v -> 'a) - -> if_added:(('k, 'v, 'c) t -> 'a) - -> 'a - - val map_preserving_phys_equal : - ('k, 'v, 'c) t -> f:('v -> 'v) -> ('k, 'v, 'c) t - (** Like map, but preserves [phys_equal] if [f] preserves [phys_equal] of - every element. *) +module Map : sig + module type S = sig + type key + type +'a t + + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + val sexp_of_t : ('a -> Sexp.t) -> 'a t -> Sexp.t + val t_of_sexp : (Sexp.t -> key) -> (Sexp.t -> 'a) -> Sexp.t -> 'a t + val pp : key pp -> 'a pp -> 'a t pp + + val pp_diff : + data_equal:('a -> 'a -> bool) + -> key pp + -> 'a pp + -> ('a * 'a) pp + -> ('a t * 'a t) pp + + (* initial constructors *) + val empty : 'a t + + (* constructors *) + val set : 'a t -> key:key -> data:'a -> 'a t + val add_exn : 'a t -> key:key -> data:'a -> 'a t + val add_multi : 'a list t -> key:key -> data:'a -> 'a list t + val remove : 'a t -> key -> 'a t + val update : 'a t -> key -> f:('a option -> 'a) -> 'a t + + val merge : + 'a t + -> 'b t + -> f: + ( key:key + -> [`Both of 'a * 'b | `Left of 'a | `Right of 'b] + -> 'c option) + -> 'c t + + val merge_skewed : + 'a t -> 'a t -> combine:(key:key -> 'a -> 'a -> 'a) -> 'a t + + val map : 'a t -> f:('a -> 'b) -> 'b t + val filter_keys : 'a t -> f:(key -> bool) -> 'a t + val filter_mapi : 'a t -> f:(key:key -> data:'a -> 'b option) -> 'b t + + (* queries *) + val is_empty : 'b t -> bool + val length : 'b t -> int + val mem : 'a t -> key -> bool + val find : 'a t -> key -> 'a option + val find_and_remove : 'a t -> key -> ('a * 'a t) option + val find_multi : 'a list t -> key -> 'a list + val data : 'a t -> 'a list + val to_alist : 'a t -> (key * 'a) list + + (* traversals *) + val iter : 'a t -> f:('a -> unit) -> unit + val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit + val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool + val fold : 'a t -> init:'s -> f:(key:key -> data:'a -> 's -> 's) -> 's + end + + module Make (Key : OrderedType) : S with type key = Key.t end module Result : sig @@ -277,6 +315,15 @@ module Array : sig val pp : (unit, unit) fmt -> 'a pp -> 'a array pp end +module String : sig + include module type of String + + val t_of_sexp : Sexp.t -> t + val sexp_of_t : t -> Sexp.t + + module Map : Map.S with type key = string +end + module Q : sig include module type of struct include Q end diff --git a/sledge/lib/llair.ml b/sledge/lib/llair.ml index b8c643cb1..4051545fd 100644 --- a/sledge/lib/llair.ml +++ b/sledge/lib/llair.ml @@ -114,7 +114,7 @@ let sexp_of_func {name; formals; freturn; fthrow; locals; entry} = let compare_block x y = Int.compare x.sort_index y.sort_index let equal_block x y = Int.equal x.sort_index y.sort_index -type functions = func Map.M(String).t [@@deriving sexp_of] +type functions = func String.Map.t [@@deriving sexp_of] type t = {globals: Global.t vector; functions: functions} [@@deriving sexp_of] @@ -358,7 +358,7 @@ end module Block = struct module T = struct type t = block [@@deriving compare, equal, sexp_of] end include T - include Comparator.Make (T) + module Map = Map.Make (T) let pp = pp_block @@ -471,7 +471,7 @@ module Func = struct iter_term func ~f:(fun term -> Term.invariant ~parent:func term) | _ -> assert false - let find functions name = Map.find functions name + let find functions name = String.Map.find functions name let mk ~(name : Global.t) ~formals ~freturn ~fthrow ~entry ~cfg = let locals = @@ -518,9 +518,9 @@ end let set_derived_metadata functions = let compute_roots functions = let roots = FuncQ.create () in - Map.iter functions ~f:(fun func -> + String.Map.iter functions ~f:(fun func -> FuncQ.enqueue_back_exn roots func.name.reg func ) ; - Map.iter functions ~f:(fun func -> + String.Map.iter functions ~f:(fun func -> Func.fold_term func ~init:() ~f:(fun () -> function | Call {callee; _} -> ( match Reg.of_exp callee with @@ -571,10 +571,8 @@ let set_derived_metadata functions = index := !index - 1 ) in let functions = - List.fold functions - ~init:(Map.empty (module String)) - ~f:(fun m func -> - Map.add_exn m ~key:(Reg.name func.name.reg) ~data:func ) + List.fold functions ~init:String.Map.empty ~f:(fun m func -> + String.Map.add_exn m ~key:(Reg.name func.name.reg) ~data:func ) in let roots = compute_roots functions in let tips_to_roots = topsort functions roots in @@ -599,5 +597,5 @@ let pp fs {globals; functions} = (Vector.pp "@\n@\n" Global.pp_defn) globals (List.pp "@\n@\n" Func.pp) - ( Map.data functions + ( String.Map.data functions |> List.sort ~compare:(fun x y -> compare_block x.entry y.entry) ) diff --git a/sledge/lib/llair.mli b/sledge/lib/llair.mli index faa137191..20363dd53 100644 --- a/sledge/lib/llair.mli +++ b/sledge/lib/llair.mli @@ -166,10 +166,10 @@ end module Block : sig type t = block [@@deriving compare, equal, sexp_of] - include Comparator.S with type t := t - val pp : t pp val mk : lbl:label -> cmnd:cmnd -> term:term -> block + + module Map : Map.S with type key := t end module Func : sig diff --git a/sledge/lib/sh.ml b/sledge/lib/sh.ml index 251928e06..395bc0600 100644 --- a/sledge/lib/sh.ml +++ b/sledge/lib/sh.ml @@ -98,22 +98,22 @@ let fold_vars ?ignore_cong fold_vars q ~init ~f = let rec var_strength_ xs m q = let add m v = - match Map.find m v with - | None -> Map.set m ~key:v ~data:`Anonymous - | Some `Anonymous -> Map.set m ~key:v ~data:`Existential + match Var.Map.find m v with + | None -> Var.Map.set m ~key:v ~data:`Anonymous + | Some `Anonymous -> Var.Map.set m ~key:v ~data:`Existential | Some _ -> m in let xs = Set.union xs q.xs in let m_stem = fold_vars_stem ~ignore_cong:() q ~init:m ~f:(fun m var -> - if not (Set.mem xs var) then Map.set m ~key:var ~data:`Universal + if not (Set.mem xs var) then Var.Map.set m ~key:var ~data:`Universal else add m var ) in let m = List.fold ~init:m_stem q.djns ~f:(fun m djn -> let ms = List.map ~f:(fun dj -> snd (var_strength_ xs m dj)) djn in List.reduce_balanced ms ~f:(fun m1 m2 -> - Map.merge_skewed m1 m2 ~combine:(fun ~key:_ s1 s2 -> + Var.Map.merge_skewed m1 m2 ~combine:(fun ~key:_ s1 s2 -> match (s1, s2) with | `Anonymous, `Anonymous -> `Anonymous | `Universal, _ | _, `Universal -> `Universal @@ -125,7 +125,7 @@ let rec var_strength_ xs m q = let var_strength_full ?(xs = Var.Set.empty) q = let m = Set.fold xs ~init:Var.Map.empty ~f:(fun m x -> - Map.set m ~key:x ~data:`Existential ) + Var.Map.set m ~key:x ~data:`Existential ) in var_strength_ xs m q @@ -212,7 +212,7 @@ let pp_us ?(pre = ("" : _ fmt)) ?vs () fs us = let rec pp_ ?var_strength vs parent_xs parent_cong fs {us; xs; cong; pure; heap; djns} = Format.pp_open_hvbox fs 0 ; - let x v = Option.bind ~f:(fun (_, m) -> Map.find m v) var_strength in + let x v = Option.bind ~f:(fun (_, m) -> Var.Map.find m v) var_strength in pp_us ~vs () fs us ; let xs_d_vs, xs_i_vs = Set.diff_inter diff --git a/sledge/lib/sh.mli b/sledge/lib/sh.mli index 386cb72ee..b0578c035 100644 --- a/sledge/lib/sh.mli +++ b/sledge/lib/sh.mli @@ -23,7 +23,7 @@ type starjunction = private and disjunction = starjunction list -type t = starjunction [@@deriving equal, compare, sexp] +type t = starjunction [@@deriving compare, equal, sexp] val pp_seg_norm : Equality.t -> seg pp val pp_us : ?pre:('a, 'a) fmt -> ?vs:Var.Set.t -> unit -> Var.Set.t pp diff --git a/sledge/lib/term.ml b/sledge/lib/term.ml index 3e5deea84..499131931 100644 --- a/sledge/lib/term.ml +++ b/sledge/lib/term.ml @@ -116,17 +116,7 @@ end type _t = T0.t include T - -module Map = struct - include ( - Map : - module type of Map - with type ('key, 'value, 'cmp) t := ('key, 'value, 'cmp) Map.t ) - - type 'v t = 'v Map.M(T).t [@@deriving compare, equal, sexp] - - let empty = empty (module T) -end +module Map = Map.Make (T) let empty_qset = Qset.empty (module T) @@ -370,7 +360,9 @@ module Var = struct (** Variable renaming substitutions *) module Subst = struct - type t = T.t Map.M(T).t [@@deriving compare, equal, sexp] + type t = T.t Map.t [@@deriving compare, equal, sexp_of] + + let t_of_sexp = Map.t_of_sexp T.t_of_sexp T.t_of_sexp let invariant s = Invariant.invariant [%here] s [%sexp_of: t] diff --git a/sledge/lib/term.mli b/sledge/lib/term.mli index 49977f386..a60722291 100644 --- a/sledge/lib/term.mli +++ b/sledge/lib/term.mli @@ -110,14 +110,7 @@ module Var : sig val union_list : t list -> t end - module Map : sig - type var := t - - type 'a t = (var, 'a, comparator_witness) Map.t - [@@deriving compare, equal, sexp] - - val empty : 'a t - end + module Map : Map.S with type key := t val pp : t pp @@ -147,16 +140,8 @@ module Var : sig end end -module Map : sig - type term := t - - type 'a t = (term, 'a, comparator_witness) Map.t - [@@deriving compare, equal, sexp] - - val empty : 'a t -end +module Map : Map.S with type key := t -val comparator : (t, comparator_witness) Comparator.t val ppx : Var.strength -> t pp val pp : t pp val pp_diff : (t * t) pp