From 57a8748e9f4a519fa7aa9135600bea33426169aa Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Tue, 17 Mar 2020 09:04:31 -0700 Subject: [PATCH] [sledge] Do not use Base.Map Summary: The base containers have inconvenient interfaces due to lacking support for functors, which also leads to the representation of values of containers including closures for the comparison functions. This causes problems when `Marshal`ing these values. This diff is one step toward not using the base containers. Reviewed By: ngorogiannis Differential Revision: D20482763 fbshipit-source-id: f55f91bf2 --- sledge/bin/sledge.ml | 2 +- sledge/lib/control.ml | 29 ++-- sledge/lib/domain_relation.ml | 2 +- sledge/lib/domain_sh.ml | 2 +- sledge/lib/domain_unit.ml | 2 +- sledge/lib/domain_used_globals.ml | 6 +- sledge/lib/equality.ml | 73 ++++----- sledge/lib/exp.ml | 12 +- sledge/lib/exp.mli | 9 +- sledge/lib/import/import.ml | 246 ++++++++++++++++++++++++------ sledge/lib/import/import.mli | 107 +++++++++---- sledge/lib/llair.ml | 18 +-- sledge/lib/llair.mli | 4 +- sledge/lib/sh.ml | 14 +- sledge/lib/sh.mli | 2 +- sledge/lib/term.ml | 16 +- sledge/lib/term.mli | 19 +-- 17 files changed, 367 insertions(+), 196 deletions(-) diff --git a/sledge/bin/sledge.ml b/sledge/bin/sledge.ml index b652bceaa..e935cefa6 100644 --- a/sledge/bin/sledge.ml +++ b/sledge/bin/sledge.ml @@ -81,7 +81,7 @@ let used_globals pgm preanalyze : Domain_used_globals.r = ; globals= Declared Reg.Set.empty } pgm in - Per_function (Map.map summary_table ~f:Reg.Set.union_list) + Per_function (Reg.Map.map summary_table ~f:Reg.Set.union_list) else Declared (Vector.fold pgm.globals ~init:Reg.Set.empty ~f:(fun acc g -> diff --git a/sledge/lib/control.ml b/sledge/lib/control.ml index 32ffb4c43..01782ea51 100644 --- a/sledge/lib/control.ml +++ b/sledge/lib/control.ml @@ -168,7 +168,6 @@ module Make (Dom : Domain_intf.Dom) = struct end include T - include Comparator.Make (T) let pp fs {dst; src} = Format.fprintf fs "#%i %%%s <--%a" dst.sort_index dst.lbl @@ -178,27 +177,27 @@ module Make (Dom : Domain_intf.Dom) = struct end module Depths = struct - type t = int Map.M(Edge).t + module M = Map.Make (Edge) - let empty = Map.empty (module Edge) - let find = Map.find - let set = Map.set + type t = int M.t + + let empty = M.empty + let find = M.find + let set = M.set let join x y = - Map.merge x y ~f:(fun ~key:_ -> function + M.merge x y ~f:(fun ~key:_ -> function | `Left d | `Right d -> Some d | `Both (d1, d2) -> Some (Int.max d1 d2) ) end type priority = int * Edge.t [@@deriving compare] type priority_queue = priority Fheap.t - type waiting_states = (Dom.t * Depths.t) list Map.M(Llair.Block).t + type waiting_states = (Dom.t * Depths.t) list Llair.Block.Map.t type t = priority_queue * waiting_states * int type x = Depths.t -> t -> t - let empty_waiting_states : waiting_states = - Map.empty (module Llair.Block) - + let empty_waiting_states : waiting_states = Llair.Block.Map.empty let pp_priority fs (n, e) = Format.fprintf fs "%i: %a" n Edge.pp e let pp fs pq = @@ -221,7 +220,9 @@ module Make (Dom : Domain_intf.Dom) = struct let pq = Fheap.add pq (depth, edge) in [%Trace.info "@[<6>enqueue %i: %a@ | %a@]" depth Edge.pp edge pp pq] ; let depths = Depths.set depths ~key:edge ~data:depth in - let ws = Map.add_multi ws ~key:curr ~data:(state, depths) in + let ws = + Llair.Block.Map.add_multi ws ~key:curr ~data:(state, depths) + in (pq, ws, bound) let init state curr bound = @@ -231,7 +232,7 @@ module Make (Dom : Domain_intf.Dom) = struct let rec run ~f (pq0, ws, bnd) = match Fheap.pop pq0 with | Some ((_, ({Edge.dst; stk} as edge)), pq) -> ( - match Map.find_and_remove ws dst with + match Llair.Block.Map.find_and_remove ws dst with | Some (q :: qs, ws) -> let join (qa, da) (q, d) = (Dom.join q qa, Depths.join d da) in let skipped, (qs, depths) = @@ -240,7 +241,7 @@ module Make (Dom : Domain_intf.Dom) = struct | Some joined, depths -> (skipped, (joined, depths)) | None, _ -> (curr :: skipped, joined) ) in - let ws = Map.add_exn ws ~key:dst ~data:skipped in + let ws = Llair.Block.Map.add_exn ws ~key:dst ~data:skipped in run ~f (f stk qs dst depths (pq, ws, bnd)) | _ -> [%Trace.info "done: %a" Edge.pp edge] ; @@ -489,5 +490,5 @@ module Make (Dom : Domain_intf.Dom) = struct assert opts.function_summaries ; exec_pgm opts pgm ; Hashtbl.fold summary_table ~init:Reg.Map.empty ~f:(fun ~key ~data map -> - match data with [] -> map | _ -> Map.set map ~key ~data ) + match data with [] -> map | _ -> Reg.Map.set map ~key ~data ) end diff --git a/sledge/lib/domain_relation.ml b/sledge/lib/domain_relation.ml index 8e11a20c3..a582a663b 100644 --- a/sledge/lib/domain_relation.ml +++ b/sledge/lib/domain_relation.ml @@ -20,7 +20,7 @@ module type State_domain_sig = sig end module Make (State_domain : State_domain_sig) = struct - type t = State_domain.t * State_domain.t [@@deriving sexp_of, equal] + type t = State_domain.t * State_domain.t [@@deriving equal, sexp_of] let embed b = (b, b) diff --git a/sledge/lib/domain_sh.ml b/sledge/lib/domain_sh.ml index f6a6f57b0..0ea8f28a5 100644 --- a/sledge/lib/domain_sh.ml +++ b/sledge/lib/domain_sh.ml @@ -7,7 +7,7 @@ (** Abstract domain *) -type t = Sh.t [@@deriving equal, sexp_of] +type t = Sh.t [@@deriving equal, sexp] let pp fs q = Format.fprintf fs "@[{ %a@ }@]" Sh.pp q let report_fmt_thunk = Fn.flip pp diff --git a/sledge/lib/domain_unit.ml b/sledge/lib/domain_unit.ml index 9edfc9eaf..236449b92 100644 --- a/sledge/lib/domain_unit.ml +++ b/sledge/lib/domain_unit.ml @@ -7,7 +7,7 @@ (** "Unit" abstract domain *) -type t = unit [@@deriving equal, sexp_of] +type t = unit [@@deriving equal, sexp] let pp fs () = Format.pp_print_string fs "()" let report_fmt_thunk () fs = pp fs () diff --git a/sledge/lib/domain_used_globals.ml b/sledge/lib/domain_used_globals.ml index 783babee7..106f7b839 100644 --- a/sledge/lib/domain_used_globals.ml +++ b/sledge/lib/domain_used_globals.ml @@ -7,7 +7,7 @@ (** Used-globals abstract domain *) -type t = Reg.Set.t [@@deriving equal, sexp_of] +type t = Reg.Set.t [@@deriving equal, sexp] let pp = Set.pp Reg.pp let report_fmt_thunk = Fn.flip pp @@ -60,7 +60,7 @@ let exec_intrinsic ~skip_throw:_ st _ intrinsic actuals = |> fun res -> Some (Some res) else None -type from_call = t [@@deriving sexp_of] +type from_call = t [@@deriving sexp] (* Set abstract state to bottom (i.e. empty set) at function entry *) let call ~summaries:_ ~globals:_ ~actuals ~areturn:_ ~formals:_ ~freturn:_ @@ -92,7 +92,7 @@ let by_function : r -> Reg.t -> t = ( match s with | Declared set -> set | Per_function map -> ( - match Map.find map fn with + match Reg.Map.find map fn with | Some gs -> gs | None -> fail diff --git a/sledge/lib/equality.ml b/sledge/lib/equality.ml index 38de2f3cb..d34f9b2e5 100644 --- a/sledge/lib/equality.ml +++ b/sledge/lib/equality.ml @@ -56,25 +56,26 @@ module Subst : sig val to_alist : t -> (Term.t * Term.t) list val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t end = struct - type t = Term.t Term.Map.t [@@deriving compare, equal, sexp] + type t = Term.t Term.Map.t [@@deriving compare, equal, sexp_of] - let pp = Map.pp Term.pp Term.pp + let t_of_sexp = Term.Map.t_of_sexp Term.t_of_sexp Term.t_of_sexp + let pp = Term.Map.pp Term.pp Term.pp let pp_diff = - Map.pp_diff ~data_equal:Term.equal Term.pp Term.pp Term.pp_diff + Term.Map.pp_diff ~data_equal:Term.equal Term.pp Term.pp Term.pp_diff let empty = Term.Map.empty - let is_empty = Map.is_empty - let length = Map.length - let mem = Map.mem - let find = Map.find - let fold = Map.fold - let iteri = Map.iteri - let for_alli = Map.for_alli - let to_alist = Map.to_alist ~key_order:`Increasing + let is_empty = Term.Map.is_empty + let length = Term.Map.length + let mem = Term.Map.mem + let find = Term.Map.find + let fold = Term.Map.fold + let iteri = Term.Map.iteri + let for_alli = Term.Map.for_alli + let to_alist = Term.Map.to_alist (** look up a term in a substitution *) - let apply s a = Map.find s a |> Option.value ~default:a + let apply s a = Term.Map.find s a |> Option.value ~default:a let rec subst s a = apply s (Term.map ~f:(subst s) a) @@ -87,21 +88,21 @@ end = struct (** compose two substitutions *) let compose r s = - let r' = Map.map ~f:(norm s) r in - Map.merge_skewed r' s ~combine:(fun ~key v1 v2 -> + let r' = Term.Map.map ~f:(norm s) r in + Term.Map.merge_skewed r' s ~combine:(fun ~key v1 v2 -> if Term.equal v1 v2 then v1 else fail "domains intersect: %a" Term.pp key () ) (** compose a substitution with a mapping *) let compose1 ~key ~data s = if Term.equal key data then s - else compose s (Map.set Term.Map.empty ~key ~data) + else compose s (Term.Map.set Term.Map.empty ~key ~data) (** add an identity entry if the term is not already present *) let extend e s = let exception Found in match - Map.update s e ~f:(function + Term.Map.update s e ~f:(function | Some _ -> Exn.raise_without_backtrace Found | None -> e ) with @@ -112,12 +113,14 @@ end = struct [f] is injective and for any set of terms [E], [f\[E\]] is disjoint from [E] *) let map_entries ~f s = - Map.fold s ~init:s ~f:(fun ~key ~data s -> + Term.Map.fold s ~init:s ~f:(fun ~key ~data s -> let key' = f key in let data' = f data in if Term.equal key' key then - if Term.equal data' data then s else Map.set s ~key ~data:data' - else Map.remove s key |> Map.add_exn ~key:key' ~data:data' ) + if Term.equal data' data then s + else Term.Map.set s ~key ~data:data' + else Term.Map.remove s key |> Term.Map.add_exn ~key:key' ~data:data' + ) (** Holds only if [true ⊢ ∃xs. e=f]. Clients assume [not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for @@ -141,12 +144,12 @@ end = struct valid, so loop until no change. *) let rec partition_valid_ t ks s = let t', ks', s' = - Map.fold s ~init:(t, ks, s) ~f:(fun ~key ~data (t, ks, s) -> + Term.Map.fold s ~init:(t, ks, s) ~f:(fun ~key ~data (t, ks, s) -> if is_valid_eq ks key data then (t, ks, s) else - let t = Map.set ~key ~data t + let t = Term.Map.set ~key ~data t and ks = Set.diff ks (Set.union (Term.fv key) (Term.fv data)) - and s = Map.remove s key in + and s = Term.Map.remove s key in (t, ks, s) ) in if s' != s then partition_valid_ t' ks' s' else (t', ks', s') @@ -327,7 +330,7 @@ type t = let classes r = let add key data cls = if Term.equal key data then cls - else Map.add_multi cls ~key:data ~data:key + else Term.Map.add_multi cls ~key:data ~data:key in Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls -> match classify key with @@ -337,7 +340,7 @@ let classes r = let cls_of r e = let e' = Subst.apply r.rep e in - Map.find (classes r) e' |> Option.value ~default:[e'] + Term.Map.find (classes r) e' |> Option.value ~default:[e'] (** Pretty-printing *) @@ -373,12 +376,13 @@ let ppx_clss x fs cs = (fun fs (key, data) -> Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) key (ppx_cls x) (List.sort ~compare:Term.compare data) ) - fs (Map.to_alist cs) + fs (Term.Map.to_alist cs) let pp_clss fs cs = ppx_clss (fun _ -> None) fs cs let pp_diff_clss = - Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls pp_diff_cls + Term.Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls + pp_diff_cls (** Invariant *) @@ -525,7 +529,7 @@ let normalize = canon let class_of r e = let e' = normalize r e in - e' :: Map.find_multi (classes r) e' + e' :: Term.Map.find_multi (classes r) e' let fold_uses_of r t ~init ~f = let rec fold_ e ~init:s ~f = @@ -558,7 +562,7 @@ let difference r a b = let apply_subst us s r = [%Trace.call fun {pf} -> pf "%a@ %a" Subst.pp s pp r] ; - Map.fold (classes r) ~init:true_ ~f:(fun ~key:rep ~data:cls r -> + Term.Map.fold (classes r) ~init:true_ ~f:(fun ~key:rep ~data:cls r -> let rep' = Subst.subst s rep in List.fold cls ~init:r ~f:(fun r trm -> let trm' = Subst.subst s trm in @@ -585,7 +589,7 @@ let or_ us r s = else if not r.sat then s else let merge_mems rs r s = - Map.fold (classes s) ~init:rs ~f:(fun ~key:rep ~data:cls rs -> + Term.Map.fold (classes s) ~init:rs ~f:(fun ~key:rep ~data:cls rs -> List.fold cls ~init:([rep], rs) ~f:(fun (reps, rs) exp -> @@ -651,7 +655,7 @@ let ppx_classes x fs r = ppx_clss x fs (classes r) let ppx_classes_diff x fs (r, s) = let clss = classes s in let clss = - Map.filter_mapi clss ~f:(fun ~key:rep ~data:cls -> + Term.Map.filter_mapi clss ~f:(fun ~key:rep ~data:cls -> match List.filter cls ~f:(fun exp -> not (entails_eq r rep exp)) with @@ -663,7 +667,7 @@ let ppx_classes_diff x fs (r, s) = Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) rep (List.pp "@ = " (Term.ppx x)) (List.dedup_and_sort ~compare:Term.compare cls) ) - fs (Map.to_alist clss) + fs (Term.Map.to_alist clss) (** Existential Witnessing and Elimination *) @@ -876,8 +880,8 @@ let solve_class us us_xs ~key:rep ~data:cls (classes, subst) = |> Option.value ~default:cls in let classes = - if List.is_empty cls then Map.remove classes rep - else Map.set classes ~key:rep ~data:cls + if List.is_empty cls then Term.Map.remove classes rep + else Term.Map.set classes ~key:rep ~data:cls in (classes, subst) |> @@ -954,7 +958,8 @@ let solve_classes r (classes, subst, us) xs = ; let rec solve_classes_ (classes0, subst0, us_xs) = let classes, subst = - Map.fold ~f:(solve_class us us_xs) classes0 ~init:(classes0, subst0) + Term.Map.fold ~f:(solve_class us us_xs) classes0 + ~init:(classes0, subst0) in if subst != subst0 then solve_classes_ (classes, subst, us_xs) else (classes, subst, us_xs) diff --git a/sledge/lib/exp.ml b/sledge/lib/exp.ml index 767b8a449..b5ea865da 100644 --- a/sledge/lib/exp.ml +++ b/sledge/lib/exp.ml @@ -84,6 +84,7 @@ module T = struct end include T +module Map = Map.Make (T) let term e = e.term @@ -328,16 +329,7 @@ module Reg = struct let vars = Set.fold ~init:Var.Set.empty ~f:(fun s r -> add s (var r)) end - module Map = struct - include ( - Map : - module type of Map - with type ('key, 'value, 'cmp) t := ('key, 'value, 'cmp) Map.t ) - - type 'v t = 'v Map.M(T).t [@@deriving compare, equal, sexp] - - let empty = Map.empty (module T) - end + module Map = Map let demangle = ref (fun _ -> None) diff --git a/sledge/lib/exp.mli b/sledge/lib/exp.mli index 5adba08a3..d0012f347 100644 --- a/sledge/lib/exp.mli +++ b/sledge/lib/exp.mli @@ -116,14 +116,7 @@ module Reg : sig val vars : t -> Var.Set.t end - module Map : sig - type reg := t - - type 'a t = (reg, 'a, comparator_witness) Map.t - [@@deriving compare, equal, sexp] - - val empty : 'a t - end + module Map : Map.S with type key := t val demangle : (string -> string option) ref val pp : t pp diff --git a/sledge/lib/import/import.ml b/sledge/lib/import/import.ml index c27f11c01..ac0304eea 100644 --- a/sledge/lib/import/import.ml +++ b/sledge/lib/import/import.ml @@ -253,52 +253,201 @@ module List = struct pp sep pp_diff_elt fs (symmetric_diff ~compare xs ys) end +module type OrderedType = sig + type t + + val compare : t -> t -> int + val sexp_of_t : t -> Sexp.t +end + +exception Duplicate + module Map = struct - include Base.Map - - let pp pp_k pp_v fs m = - Format.fprintf fs "@[<1>[%a]@]" - (List.pp ",@ " (fun fs (k, v) -> - Format.fprintf fs "@[%a @<2>↦ %a@]" pp_k k pp_v v )) - (to_alist m) - - let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) = - let pp_diff_elt fs = function - | k, `Left v -> - Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v - | k, `Right v -> - Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v - | k, `Unequal vv -> - Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_diff_val vv - in - let sd = Sequence.to_list (symmetric_diff ~data_equal x y) in - if not (List.is_empty sd) then - Format.fprintf fs "[@[%a@]];@ " (List.pp ";@ " pp_diff_elt) sd - - let equal_m__t (module Elt : Compare_m) equal_v = equal equal_v - - let find_and_remove m k = - let found = ref None in - let m = - change m k ~f:(fun v -> - found := v ; - None ) - in - let+ v = !found in - (v, m) - - let find_or_add (type data) map key ~(default : data) ~if_found ~if_added - = - let exception Found of data in - match - update map key ~f:(function - | Some old_data -> Exn.raise_without_backtrace (Found old_data) - | None -> default ) - with - | exception Found old_data -> if_found old_data - | map -> if_added map + module type S = sig + type key + type +'a t + + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + val sexp_of_t : ('a -> Sexp.t) -> 'a t -> Sexp.t + val t_of_sexp : (Sexp.t -> key) -> (Sexp.t -> 'a) -> Sexp.t -> 'a t + val pp : key pp -> 'a pp -> 'a t pp + + val pp_diff : + data_equal:('a -> 'a -> bool) + -> key pp + -> 'a pp + -> ('a * 'a) pp + -> ('a t * 'a t) pp + + (* initial constructors *) + val empty : 'a t + + (* constructors *) + val set : 'a t -> key:key -> data:'a -> 'a t + val add_exn : 'a t -> key:key -> data:'a -> 'a t + val add_multi : 'a list t -> key:key -> data:'a -> 'a list t + val remove : 'a t -> key -> 'a t + val update : 'a t -> key -> f:('a option -> 'a) -> 'a t + + val merge : + 'a t + -> 'b t + -> f: + ( key:key + -> [`Both of 'a * 'b | `Left of 'a | `Right of 'b] + -> 'c option) + -> 'c t + + val merge_skewed : + 'a t -> 'a t -> combine:(key:key -> 'a -> 'a -> 'a) -> 'a t + + val map : 'a t -> f:('a -> 'b) -> 'b t + val filter_keys : 'a t -> f:(key -> bool) -> 'a t + val filter_mapi : 'a t -> f:(key:key -> data:'a -> 'b option) -> 'b t + + (* queries *) + val is_empty : 'b t -> bool + val length : 'b t -> int + val mem : 'a t -> key -> bool + val find : 'a t -> key -> 'a option + val find_and_remove : 'a t -> key -> ('a * 'a t) option + val find_multi : 'a list t -> key -> 'a list + val data : 'a t -> 'a list + val to_alist : 'a t -> (key * 'a) list + + (* traversals *) + val iter : 'a t -> f:('a -> unit) -> unit + val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit + val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool + val fold : 'a t -> init:'s -> f:(key:key -> data:'a -> 's -> 's) -> 's + end - let map_preserving_phys_equal t ~f = map_preserving_phys_equal map t ~f + module Make (Key : OrderedType) : S with type key = Key.t = struct + module M = Caml.Map.Make (Key) + + type key = Key.t + type 'a t = 'a M.t + + let compare = M.compare + let equal = M.equal + + let sexp_of_t sexp_of_val m = + List.sexp_of_t + (Sexplib.Conv.sexp_of_pair Key.sexp_of_t sexp_of_val) + (M.bindings m) + + let t_of_sexp key_of_sexp val_of_sexp sexp = + Caml.List.fold_left + (fun m (k, v) -> M.add k v m) + M.empty + (List.t_of_sexp + (Sexplib.Conv.pair_of_sexp key_of_sexp val_of_sexp) + sexp) + + let pp pp_k pp_v fs m = + Format.fprintf fs "@[<1>[%a]@]" + (List.pp ",@ " (fun fs (k, v) -> + Format.fprintf fs "@[%a @<2>↦ %a@]" pp_k k pp_v v )) + (M.bindings m) + + let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) = + let pp_diff_val fs = function + | k, `Left v -> + Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v + | k, `Right v -> + Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v + | k, `Unequal vv -> + Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_diff_val vv + in + let sd = + M.merge + (fun _ v1o v2o -> + match (v1o, v2o) with + | Some v1, Some v2 when not (data_equal v1 v2) -> + Some (`Unequal (v1, v2)) + | Some v1, None -> Some (`Left v1) + | None, Some v2 -> Some (`Right v2) + | _ -> None ) + x y + in + if not (M.is_empty sd) then + Format.fprintf fs "[@[%a@]];@ " + (List.pp ";@ " pp_diff_val) + (M.bindings sd) + + exception Duplicate + + let empty = M.empty + let set m ~key ~data = M.add key data m + + let add_exn m ~key ~data = + M.update key + (function None -> Some data | Some _ -> raise Duplicate) + m + + let add_multi m ~key ~data = + M.update key + (function None -> Some [data] | Some vs -> Some (data :: vs)) + m + + let remove m k = M.remove k m + let update m k ~f = M.update k (fun vo -> Some (f vo)) m + + let merge m n ~f = + M.merge + (fun k v1o v2o -> + match (v1o, v2o) with + | Some v1, Some v2 -> f ~key:k (`Both (v1, v2)) + | Some v1, None -> f ~key:k (`Left v1) + | None, Some v2 -> f ~key:k (`Right v2) + | None, None -> None ) + m n + + let merge_skewed m n ~combine = + M.merge + (fun k v1o v2o -> + match (v1o, v2o) with + | Some v1, Some v2 -> Some (combine ~key:k v1 v2) + | Some _, None -> v1o + | None, Some _ -> v2o + | None, None -> None ) + m n + + let map m ~f = M.map f m + let filter_keys m ~f = M.filter (fun k _ -> f k) m + + let filter_mapi m ~f = + M.fold + (fun k v m -> + match f ~key:k ~data:v with Some v' -> M.add k v' m | None -> m ) + m M.empty + + let is_empty = M.is_empty + let length = M.cardinal + let mem m k = M.mem k m + let find m k = M.find_opt k m + + let find_and_remove m k = + let found = ref None in + let m = + M.update k + (fun v -> + found := v ; + None ) + m + in + let+ v = !found in + (v, m) + + let find_multi m k = try M.find k m with Not_found -> [] + let data m = M.fold (fun _ v s -> v :: s) m [] + let to_alist = M.bindings + let iter m ~f = M.iter (fun _ v -> f v) m + let iteri m ~f = M.iter (fun k v -> f ~key:k ~data:v) m + let for_alli m ~f = M.for_all (fun key data -> f ~key ~data) m + let fold m ~init ~f = M.fold (fun key data s -> f ~key ~data s) m init + end end module Result = struct @@ -358,6 +507,15 @@ module Array = struct let pp sep pp_elt fs a = List.pp sep pp_elt fs (to_list a) end +module String = struct + include String + + let t_of_sexp = Sexplib.Conv.string_of_sexp + let sexp_of_t = Sexplib.Conv.sexp_of_string + + module Map = Map.Make (String) +end + module Q = struct let pp = Q.pp_print let hash = Hashtbl.hash diff --git a/sledge/lib/import/import.mli b/sledge/lib/import/import.mli index 05edbba6f..e9a2fcf53 100644 --- a/sledge/lib/import/import.mli +++ b/sledge/lib/import/import.mli @@ -192,39 +192,77 @@ module List : sig compare:('a -> 'a -> int) -> 'a t -> 'a t -> ('a, 'a) Either.t t end -module Map : sig - include module type of Base.Map +module type OrderedType = sig + type t - val pp : 'k pp -> 'v pp -> ('k, 'v, 'c) t pp + val compare : t -> t -> int + val sexp_of_t : t -> Sexp.t +end - val pp_diff : - data_equal:('v -> 'v -> bool) - -> 'k pp - -> 'v pp - -> ('v * 'v) pp - -> (('k, 'v, 'c) t * ('k, 'v, 'c) t) pp +exception Duplicate - val equal_m__t : - (module Compare_m) - -> ('v -> 'v -> bool) - -> ('k, 'v, 'c) t - -> ('k, 'v, 'c) t - -> bool - - val find_and_remove : ('k, 'v, 'c) t -> 'k -> ('v * ('k, 'v, 'c) t) option - - val find_or_add : - ('k, 'v, 'c) t - -> 'k - -> default:'v - -> if_found:('v -> 'a) - -> if_added:(('k, 'v, 'c) t -> 'a) - -> 'a - - val map_preserving_phys_equal : - ('k, 'v, 'c) t -> f:('v -> 'v) -> ('k, 'v, 'c) t - (** Like map, but preserves [phys_equal] if [f] preserves [phys_equal] of - every element. *) +module Map : sig + module type S = sig + type key + type +'a t + + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + val sexp_of_t : ('a -> Sexp.t) -> 'a t -> Sexp.t + val t_of_sexp : (Sexp.t -> key) -> (Sexp.t -> 'a) -> Sexp.t -> 'a t + val pp : key pp -> 'a pp -> 'a t pp + + val pp_diff : + data_equal:('a -> 'a -> bool) + -> key pp + -> 'a pp + -> ('a * 'a) pp + -> ('a t * 'a t) pp + + (* initial constructors *) + val empty : 'a t + + (* constructors *) + val set : 'a t -> key:key -> data:'a -> 'a t + val add_exn : 'a t -> key:key -> data:'a -> 'a t + val add_multi : 'a list t -> key:key -> data:'a -> 'a list t + val remove : 'a t -> key -> 'a t + val update : 'a t -> key -> f:('a option -> 'a) -> 'a t + + val merge : + 'a t + -> 'b t + -> f: + ( key:key + -> [`Both of 'a * 'b | `Left of 'a | `Right of 'b] + -> 'c option) + -> 'c t + + val merge_skewed : + 'a t -> 'a t -> combine:(key:key -> 'a -> 'a -> 'a) -> 'a t + + val map : 'a t -> f:('a -> 'b) -> 'b t + val filter_keys : 'a t -> f:(key -> bool) -> 'a t + val filter_mapi : 'a t -> f:(key:key -> data:'a -> 'b option) -> 'b t + + (* queries *) + val is_empty : 'b t -> bool + val length : 'b t -> int + val mem : 'a t -> key -> bool + val find : 'a t -> key -> 'a option + val find_and_remove : 'a t -> key -> ('a * 'a t) option + val find_multi : 'a list t -> key -> 'a list + val data : 'a t -> 'a list + val to_alist : 'a t -> (key * 'a) list + + (* traversals *) + val iter : 'a t -> f:('a -> unit) -> unit + val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit + val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool + val fold : 'a t -> init:'s -> f:(key:key -> data:'a -> 's -> 's) -> 's + end + + module Make (Key : OrderedType) : S with type key = Key.t end module Result : sig @@ -277,6 +315,15 @@ module Array : sig val pp : (unit, unit) fmt -> 'a pp -> 'a array pp end +module String : sig + include module type of String + + val t_of_sexp : Sexp.t -> t + val sexp_of_t : t -> Sexp.t + + module Map : Map.S with type key = string +end + module Q : sig include module type of struct include Q end diff --git a/sledge/lib/llair.ml b/sledge/lib/llair.ml index b8c643cb1..4051545fd 100644 --- a/sledge/lib/llair.ml +++ b/sledge/lib/llair.ml @@ -114,7 +114,7 @@ let sexp_of_func {name; formals; freturn; fthrow; locals; entry} = let compare_block x y = Int.compare x.sort_index y.sort_index let equal_block x y = Int.equal x.sort_index y.sort_index -type functions = func Map.M(String).t [@@deriving sexp_of] +type functions = func String.Map.t [@@deriving sexp_of] type t = {globals: Global.t vector; functions: functions} [@@deriving sexp_of] @@ -358,7 +358,7 @@ end module Block = struct module T = struct type t = block [@@deriving compare, equal, sexp_of] end include T - include Comparator.Make (T) + module Map = Map.Make (T) let pp = pp_block @@ -471,7 +471,7 @@ module Func = struct iter_term func ~f:(fun term -> Term.invariant ~parent:func term) | _ -> assert false - let find functions name = Map.find functions name + let find functions name = String.Map.find functions name let mk ~(name : Global.t) ~formals ~freturn ~fthrow ~entry ~cfg = let locals = @@ -518,9 +518,9 @@ end let set_derived_metadata functions = let compute_roots functions = let roots = FuncQ.create () in - Map.iter functions ~f:(fun func -> + String.Map.iter functions ~f:(fun func -> FuncQ.enqueue_back_exn roots func.name.reg func ) ; - Map.iter functions ~f:(fun func -> + String.Map.iter functions ~f:(fun func -> Func.fold_term func ~init:() ~f:(fun () -> function | Call {callee; _} -> ( match Reg.of_exp callee with @@ -571,10 +571,8 @@ let set_derived_metadata functions = index := !index - 1 ) in let functions = - List.fold functions - ~init:(Map.empty (module String)) - ~f:(fun m func -> - Map.add_exn m ~key:(Reg.name func.name.reg) ~data:func ) + List.fold functions ~init:String.Map.empty ~f:(fun m func -> + String.Map.add_exn m ~key:(Reg.name func.name.reg) ~data:func ) in let roots = compute_roots functions in let tips_to_roots = topsort functions roots in @@ -599,5 +597,5 @@ let pp fs {globals; functions} = (Vector.pp "@\n@\n" Global.pp_defn) globals (List.pp "@\n@\n" Func.pp) - ( Map.data functions + ( String.Map.data functions |> List.sort ~compare:(fun x y -> compare_block x.entry y.entry) ) diff --git a/sledge/lib/llair.mli b/sledge/lib/llair.mli index faa137191..20363dd53 100644 --- a/sledge/lib/llair.mli +++ b/sledge/lib/llair.mli @@ -166,10 +166,10 @@ end module Block : sig type t = block [@@deriving compare, equal, sexp_of] - include Comparator.S with type t := t - val pp : t pp val mk : lbl:label -> cmnd:cmnd -> term:term -> block + + module Map : Map.S with type key := t end module Func : sig diff --git a/sledge/lib/sh.ml b/sledge/lib/sh.ml index 251928e06..395bc0600 100644 --- a/sledge/lib/sh.ml +++ b/sledge/lib/sh.ml @@ -98,22 +98,22 @@ let fold_vars ?ignore_cong fold_vars q ~init ~f = let rec var_strength_ xs m q = let add m v = - match Map.find m v with - | None -> Map.set m ~key:v ~data:`Anonymous - | Some `Anonymous -> Map.set m ~key:v ~data:`Existential + match Var.Map.find m v with + | None -> Var.Map.set m ~key:v ~data:`Anonymous + | Some `Anonymous -> Var.Map.set m ~key:v ~data:`Existential | Some _ -> m in let xs = Set.union xs q.xs in let m_stem = fold_vars_stem ~ignore_cong:() q ~init:m ~f:(fun m var -> - if not (Set.mem xs var) then Map.set m ~key:var ~data:`Universal + if not (Set.mem xs var) then Var.Map.set m ~key:var ~data:`Universal else add m var ) in let m = List.fold ~init:m_stem q.djns ~f:(fun m djn -> let ms = List.map ~f:(fun dj -> snd (var_strength_ xs m dj)) djn in List.reduce_balanced ms ~f:(fun m1 m2 -> - Map.merge_skewed m1 m2 ~combine:(fun ~key:_ s1 s2 -> + Var.Map.merge_skewed m1 m2 ~combine:(fun ~key:_ s1 s2 -> match (s1, s2) with | `Anonymous, `Anonymous -> `Anonymous | `Universal, _ | _, `Universal -> `Universal @@ -125,7 +125,7 @@ let rec var_strength_ xs m q = let var_strength_full ?(xs = Var.Set.empty) q = let m = Set.fold xs ~init:Var.Map.empty ~f:(fun m x -> - Map.set m ~key:x ~data:`Existential ) + Var.Map.set m ~key:x ~data:`Existential ) in var_strength_ xs m q @@ -212,7 +212,7 @@ let pp_us ?(pre = ("" : _ fmt)) ?vs () fs us = let rec pp_ ?var_strength vs parent_xs parent_cong fs {us; xs; cong; pure; heap; djns} = Format.pp_open_hvbox fs 0 ; - let x v = Option.bind ~f:(fun (_, m) -> Map.find m v) var_strength in + let x v = Option.bind ~f:(fun (_, m) -> Var.Map.find m v) var_strength in pp_us ~vs () fs us ; let xs_d_vs, xs_i_vs = Set.diff_inter diff --git a/sledge/lib/sh.mli b/sledge/lib/sh.mli index 386cb72ee..b0578c035 100644 --- a/sledge/lib/sh.mli +++ b/sledge/lib/sh.mli @@ -23,7 +23,7 @@ type starjunction = private and disjunction = starjunction list -type t = starjunction [@@deriving equal, compare, sexp] +type t = starjunction [@@deriving compare, equal, sexp] val pp_seg_norm : Equality.t -> seg pp val pp_us : ?pre:('a, 'a) fmt -> ?vs:Var.Set.t -> unit -> Var.Set.t pp diff --git a/sledge/lib/term.ml b/sledge/lib/term.ml index 3e5deea84..499131931 100644 --- a/sledge/lib/term.ml +++ b/sledge/lib/term.ml @@ -116,17 +116,7 @@ end type _t = T0.t include T - -module Map = struct - include ( - Map : - module type of Map - with type ('key, 'value, 'cmp) t := ('key, 'value, 'cmp) Map.t ) - - type 'v t = 'v Map.M(T).t [@@deriving compare, equal, sexp] - - let empty = empty (module T) -end +module Map = Map.Make (T) let empty_qset = Qset.empty (module T) @@ -370,7 +360,9 @@ module Var = struct (** Variable renaming substitutions *) module Subst = struct - type t = T.t Map.M(T).t [@@deriving compare, equal, sexp] + type t = T.t Map.t [@@deriving compare, equal, sexp_of] + + let t_of_sexp = Map.t_of_sexp T.t_of_sexp T.t_of_sexp let invariant s = Invariant.invariant [%here] s [%sexp_of: t] diff --git a/sledge/lib/term.mli b/sledge/lib/term.mli index 49977f386..a60722291 100644 --- a/sledge/lib/term.mli +++ b/sledge/lib/term.mli @@ -110,14 +110,7 @@ module Var : sig val union_list : t list -> t end - module Map : sig - type var := t - - type 'a t = (var, 'a, comparator_witness) Map.t - [@@deriving compare, equal, sexp] - - val empty : 'a t - end + module Map : Map.S with type key := t val pp : t pp @@ -147,16 +140,8 @@ module Var : sig end end -module Map : sig - type term := t - - type 'a t = (term, 'a, comparator_witness) Map.t - [@@deriving compare, equal, sexp] - - val empty : 'a t -end +module Map : Map.S with type key := t -val comparator : (t, comparator_witness) Comparator.t val ppx : Var.strength -> t pp val pp : t pp val pp_diff : (t * t) pp