[sledge] Do not use Base.Map

Summary:
The base containers have inconvenient interfaces due to lacking
support for functors, which also leads to the representation of values
of containers including closures for the comparison functions. This
causes problems when `Marshal`ing these values.

This diff is one step toward not using the base containers.

Reviewed By: ngorogiannis

Differential Revision: D20482763

fbshipit-source-id: f55f91bf2
master
Josh Berdine 5 years ago committed by Facebook GitHub Bot
parent 86ab2d7d10
commit 57a8748e9f

@ -81,7 +81,7 @@ let used_globals pgm preanalyze : Domain_used_globals.r =
; globals= Declared Reg.Set.empty } ; globals= Declared Reg.Set.empty }
pgm pgm
in in
Per_function (Map.map summary_table ~f:Reg.Set.union_list) Per_function (Reg.Map.map summary_table ~f:Reg.Set.union_list)
else else
Declared Declared
(Vector.fold pgm.globals ~init:Reg.Set.empty ~f:(fun acc g -> (Vector.fold pgm.globals ~init:Reg.Set.empty ~f:(fun acc g ->

@ -168,7 +168,6 @@ module Make (Dom : Domain_intf.Dom) = struct
end end
include T include T
include Comparator.Make (T)
let pp fs {dst; src} = let pp fs {dst; src} =
Format.fprintf fs "#%i %%%s <--%a" dst.sort_index dst.lbl Format.fprintf fs "#%i %%%s <--%a" dst.sort_index dst.lbl
@ -178,27 +177,27 @@ module Make (Dom : Domain_intf.Dom) = struct
end end
module Depths = struct module Depths = struct
type t = int Map.M(Edge).t module M = Map.Make (Edge)
let empty = Map.empty (module Edge) type t = int M.t
let find = Map.find
let set = Map.set let empty = M.empty
let find = M.find
let set = M.set
let join x y = let join x y =
Map.merge x y ~f:(fun ~key:_ -> function M.merge x y ~f:(fun ~key:_ -> function
| `Left d | `Right d -> Some d | `Left d | `Right d -> Some d
| `Both (d1, d2) -> Some (Int.max d1 d2) ) | `Both (d1, d2) -> Some (Int.max d1 d2) )
end end
type priority = int * Edge.t [@@deriving compare] type priority = int * Edge.t [@@deriving compare]
type priority_queue = priority Fheap.t type priority_queue = priority Fheap.t
type waiting_states = (Dom.t * Depths.t) list Map.M(Llair.Block).t type waiting_states = (Dom.t * Depths.t) list Llair.Block.Map.t
type t = priority_queue * waiting_states * int type t = priority_queue * waiting_states * int
type x = Depths.t -> t -> t type x = Depths.t -> t -> t
let empty_waiting_states : waiting_states = let empty_waiting_states : waiting_states = Llair.Block.Map.empty
Map.empty (module Llair.Block)
let pp_priority fs (n, e) = Format.fprintf fs "%i: %a" n Edge.pp e let pp_priority fs (n, e) = Format.fprintf fs "%i: %a" n Edge.pp e
let pp fs pq = let pp fs pq =
@ -221,7 +220,9 @@ module Make (Dom : Domain_intf.Dom) = struct
let pq = Fheap.add pq (depth, edge) in let pq = Fheap.add pq (depth, edge) in
[%Trace.info "@[<6>enqueue %i: %a@ | %a@]" depth Edge.pp edge pp pq] ; [%Trace.info "@[<6>enqueue %i: %a@ | %a@]" depth Edge.pp edge pp pq] ;
let depths = Depths.set depths ~key:edge ~data:depth in let depths = Depths.set depths ~key:edge ~data:depth in
let ws = Map.add_multi ws ~key:curr ~data:(state, depths) in let ws =
Llair.Block.Map.add_multi ws ~key:curr ~data:(state, depths)
in
(pq, ws, bound) (pq, ws, bound)
let init state curr bound = let init state curr bound =
@ -231,7 +232,7 @@ module Make (Dom : Domain_intf.Dom) = struct
let rec run ~f (pq0, ws, bnd) = let rec run ~f (pq0, ws, bnd) =
match Fheap.pop pq0 with match Fheap.pop pq0 with
| Some ((_, ({Edge.dst; stk} as edge)), pq) -> ( | Some ((_, ({Edge.dst; stk} as edge)), pq) -> (
match Map.find_and_remove ws dst with match Llair.Block.Map.find_and_remove ws dst with
| Some (q :: qs, ws) -> | Some (q :: qs, ws) ->
let join (qa, da) (q, d) = (Dom.join q qa, Depths.join d da) in let join (qa, da) (q, d) = (Dom.join q qa, Depths.join d da) in
let skipped, (qs, depths) = let skipped, (qs, depths) =
@ -240,7 +241,7 @@ module Make (Dom : Domain_intf.Dom) = struct
| Some joined, depths -> (skipped, (joined, depths)) | Some joined, depths -> (skipped, (joined, depths))
| None, _ -> (curr :: skipped, joined) ) | None, _ -> (curr :: skipped, joined) )
in in
let ws = Map.add_exn ws ~key:dst ~data:skipped in let ws = Llair.Block.Map.add_exn ws ~key:dst ~data:skipped in
run ~f (f stk qs dst depths (pq, ws, bnd)) run ~f (f stk qs dst depths (pq, ws, bnd))
| _ -> | _ ->
[%Trace.info "done: %a" Edge.pp edge] ; [%Trace.info "done: %a" Edge.pp edge] ;
@ -489,5 +490,5 @@ module Make (Dom : Domain_intf.Dom) = struct
assert opts.function_summaries ; assert opts.function_summaries ;
exec_pgm opts pgm ; exec_pgm opts pgm ;
Hashtbl.fold summary_table ~init:Reg.Map.empty ~f:(fun ~key ~data map -> Hashtbl.fold summary_table ~init:Reg.Map.empty ~f:(fun ~key ~data map ->
match data with [] -> map | _ -> Map.set map ~key ~data ) match data with [] -> map | _ -> Reg.Map.set map ~key ~data )
end end

@ -20,7 +20,7 @@ module type State_domain_sig = sig
end end
module Make (State_domain : State_domain_sig) = struct module Make (State_domain : State_domain_sig) = struct
type t = State_domain.t * State_domain.t [@@deriving sexp_of, equal] type t = State_domain.t * State_domain.t [@@deriving equal, sexp_of]
let embed b = (b, b) let embed b = (b, b)

@ -7,7 +7,7 @@
(** Abstract domain *) (** Abstract domain *)
type t = Sh.t [@@deriving equal, sexp_of] type t = Sh.t [@@deriving equal, sexp]
let pp fs q = Format.fprintf fs "@[{ %a@ }@]" Sh.pp q let pp fs q = Format.fprintf fs "@[{ %a@ }@]" Sh.pp q
let report_fmt_thunk = Fn.flip pp let report_fmt_thunk = Fn.flip pp

@ -7,7 +7,7 @@
(** "Unit" abstract domain *) (** "Unit" abstract domain *)
type t = unit [@@deriving equal, sexp_of] type t = unit [@@deriving equal, sexp]
let pp fs () = Format.pp_print_string fs "()" let pp fs () = Format.pp_print_string fs "()"
let report_fmt_thunk () fs = pp fs () let report_fmt_thunk () fs = pp fs ()

@ -7,7 +7,7 @@
(** Used-globals abstract domain *) (** Used-globals abstract domain *)
type t = Reg.Set.t [@@deriving equal, sexp_of] type t = Reg.Set.t [@@deriving equal, sexp]
let pp = Set.pp Reg.pp let pp = Set.pp Reg.pp
let report_fmt_thunk = Fn.flip pp let report_fmt_thunk = Fn.flip pp
@ -60,7 +60,7 @@ let exec_intrinsic ~skip_throw:_ st _ intrinsic actuals =
|> fun res -> Some (Some res) |> fun res -> Some (Some res)
else None else None
type from_call = t [@@deriving sexp_of] type from_call = t [@@deriving sexp]
(* Set abstract state to bottom (i.e. empty set) at function entry *) (* Set abstract state to bottom (i.e. empty set) at function entry *)
let call ~summaries:_ ~globals:_ ~actuals ~areturn:_ ~formals:_ ~freturn:_ let call ~summaries:_ ~globals:_ ~actuals ~areturn:_ ~formals:_ ~freturn:_
@ -92,7 +92,7 @@ let by_function : r -> Reg.t -> t =
( match s with ( match s with
| Declared set -> set | Declared set -> set
| Per_function map -> ( | Per_function map -> (
match Map.find map fn with match Reg.Map.find map fn with
| Some gs -> gs | Some gs -> gs
| None -> | None ->
fail fail

@ -56,25 +56,26 @@ module Subst : sig
val to_alist : t -> (Term.t * Term.t) list val to_alist : t -> (Term.t * Term.t) list
val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t
end = struct end = struct
type t = Term.t Term.Map.t [@@deriving compare, equal, sexp] type t = Term.t Term.Map.t [@@deriving compare, equal, sexp_of]
let pp = Map.pp Term.pp Term.pp let t_of_sexp = Term.Map.t_of_sexp Term.t_of_sexp Term.t_of_sexp
let pp = Term.Map.pp Term.pp Term.pp
let pp_diff = let pp_diff =
Map.pp_diff ~data_equal:Term.equal Term.pp Term.pp Term.pp_diff Term.Map.pp_diff ~data_equal:Term.equal Term.pp Term.pp Term.pp_diff
let empty = Term.Map.empty let empty = Term.Map.empty
let is_empty = Map.is_empty let is_empty = Term.Map.is_empty
let length = Map.length let length = Term.Map.length
let mem = Map.mem let mem = Term.Map.mem
let find = Map.find let find = Term.Map.find
let fold = Map.fold let fold = Term.Map.fold
let iteri = Map.iteri let iteri = Term.Map.iteri
let for_alli = Map.for_alli let for_alli = Term.Map.for_alli
let to_alist = Map.to_alist ~key_order:`Increasing let to_alist = Term.Map.to_alist
(** look up a term in a substitution *) (** look up a term in a substitution *)
let apply s a = Map.find s a |> Option.value ~default:a let apply s a = Term.Map.find s a |> Option.value ~default:a
let rec subst s a = apply s (Term.map ~f:(subst s) a) let rec subst s a = apply s (Term.map ~f:(subst s) a)
@ -87,21 +88,21 @@ end = struct
(** compose two substitutions *) (** compose two substitutions *)
let compose r s = let compose r s =
let r' = Map.map ~f:(norm s) r in let r' = Term.Map.map ~f:(norm s) r in
Map.merge_skewed r' s ~combine:(fun ~key v1 v2 -> Term.Map.merge_skewed r' s ~combine:(fun ~key v1 v2 ->
if Term.equal v1 v2 then v1 if Term.equal v1 v2 then v1
else fail "domains intersect: %a" Term.pp key () ) else fail "domains intersect: %a" Term.pp key () )
(** compose a substitution with a mapping *) (** compose a substitution with a mapping *)
let compose1 ~key ~data s = let compose1 ~key ~data s =
if Term.equal key data then s if Term.equal key data then s
else compose s (Map.set Term.Map.empty ~key ~data) else compose s (Term.Map.set Term.Map.empty ~key ~data)
(** add an identity entry if the term is not already present *) (** add an identity entry if the term is not already present *)
let extend e s = let extend e s =
let exception Found in let exception Found in
match match
Map.update s e ~f:(function Term.Map.update s e ~f:(function
| Some _ -> Exn.raise_without_backtrace Found | Some _ -> Exn.raise_without_backtrace Found
| None -> e ) | None -> e )
with with
@ -112,12 +113,14 @@ end = struct
[f] is injective and for any set of terms [E], [f\[E\]] is disjoint [f] is injective and for any set of terms [E], [f\[E\]] is disjoint
from [E] *) from [E] *)
let map_entries ~f s = let map_entries ~f s =
Map.fold s ~init:s ~f:(fun ~key ~data s -> Term.Map.fold s ~init:s ~f:(fun ~key ~data s ->
let key' = f key in let key' = f key in
let data' = f data in let data' = f data in
if Term.equal key' key then if Term.equal key' key then
if Term.equal data' data then s else Map.set s ~key ~data:data' if Term.equal data' data then s
else Map.remove s key |> Map.add_exn ~key:key' ~data:data' ) else Term.Map.set s ~key ~data:data'
else Term.Map.remove s key |> Term.Map.add_exn ~key:key' ~data:data'
)
(** Holds only if [true ⊢ ∃xs. e=f]. Clients assume (** Holds only if [true ⊢ ∃xs. e=f]. Clients assume
[not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for [not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for
@ -141,12 +144,12 @@ end = struct
valid, so loop until no change. *) valid, so loop until no change. *)
let rec partition_valid_ t ks s = let rec partition_valid_ t ks s =
let t', ks', s' = let t', ks', s' =
Map.fold s ~init:(t, ks, s) ~f:(fun ~key ~data (t, ks, s) -> Term.Map.fold s ~init:(t, ks, s) ~f:(fun ~key ~data (t, ks, s) ->
if is_valid_eq ks key data then (t, ks, s) if is_valid_eq ks key data then (t, ks, s)
else else
let t = Map.set ~key ~data t let t = Term.Map.set ~key ~data t
and ks = Set.diff ks (Set.union (Term.fv key) (Term.fv data)) and ks = Set.diff ks (Set.union (Term.fv key) (Term.fv data))
and s = Map.remove s key in and s = Term.Map.remove s key in
(t, ks, s) ) (t, ks, s) )
in in
if s' != s then partition_valid_ t' ks' s' else (t', ks', s') if s' != s then partition_valid_ t' ks' s' else (t', ks', s')
@ -327,7 +330,7 @@ type t =
let classes r = let classes r =
let add key data cls = let add key data cls =
if Term.equal key data then cls if Term.equal key data then cls
else Map.add_multi cls ~key:data ~data:key else Term.Map.add_multi cls ~key:data ~data:key
in in
Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls -> Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls ->
match classify key with match classify key with
@ -337,7 +340,7 @@ let classes r =
let cls_of r e = let cls_of r e =
let e' = Subst.apply r.rep e in let e' = Subst.apply r.rep e in
Map.find (classes r) e' |> Option.value ~default:[e'] Term.Map.find (classes r) e' |> Option.value ~default:[e']
(** Pretty-printing *) (** Pretty-printing *)
@ -373,12 +376,13 @@ let ppx_clss x fs cs =
(fun fs (key, data) -> (fun fs (key, data) ->
Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) key (ppx_cls x) Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) key (ppx_cls x)
(List.sort ~compare:Term.compare data) ) (List.sort ~compare:Term.compare data) )
fs (Map.to_alist cs) fs (Term.Map.to_alist cs)
let pp_clss fs cs = ppx_clss (fun _ -> None) fs cs let pp_clss fs cs = ppx_clss (fun _ -> None) fs cs
let pp_diff_clss = let pp_diff_clss =
Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls pp_diff_cls Term.Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls
pp_diff_cls
(** Invariant *) (** Invariant *)
@ -525,7 +529,7 @@ let normalize = canon
let class_of r e = let class_of r e =
let e' = normalize r e in let e' = normalize r e in
e' :: Map.find_multi (classes r) e' e' :: Term.Map.find_multi (classes r) e'
let fold_uses_of r t ~init ~f = let fold_uses_of r t ~init ~f =
let rec fold_ e ~init:s ~f = let rec fold_ e ~init:s ~f =
@ -558,7 +562,7 @@ let difference r a b =
let apply_subst us s r = let apply_subst us s r =
[%Trace.call fun {pf} -> pf "%a@ %a" Subst.pp s pp r] [%Trace.call fun {pf} -> pf "%a@ %a" Subst.pp s pp r]
; ;
Map.fold (classes r) ~init:true_ ~f:(fun ~key:rep ~data:cls r -> Term.Map.fold (classes r) ~init:true_ ~f:(fun ~key:rep ~data:cls r ->
let rep' = Subst.subst s rep in let rep' = Subst.subst s rep in
List.fold cls ~init:r ~f:(fun r trm -> List.fold cls ~init:r ~f:(fun r trm ->
let trm' = Subst.subst s trm in let trm' = Subst.subst s trm in
@ -585,7 +589,7 @@ let or_ us r s =
else if not r.sat then s else if not r.sat then s
else else
let merge_mems rs r s = let merge_mems rs r s =
Map.fold (classes s) ~init:rs ~f:(fun ~key:rep ~data:cls rs -> Term.Map.fold (classes s) ~init:rs ~f:(fun ~key:rep ~data:cls rs ->
List.fold cls List.fold cls
~init:([rep], rs) ~init:([rep], rs)
~f:(fun (reps, rs) exp -> ~f:(fun (reps, rs) exp ->
@ -651,7 +655,7 @@ let ppx_classes x fs r = ppx_clss x fs (classes r)
let ppx_classes_diff x fs (r, s) = let ppx_classes_diff x fs (r, s) =
let clss = classes s in let clss = classes s in
let clss = let clss =
Map.filter_mapi clss ~f:(fun ~key:rep ~data:cls -> Term.Map.filter_mapi clss ~f:(fun ~key:rep ~data:cls ->
match match
List.filter cls ~f:(fun exp -> not (entails_eq r rep exp)) List.filter cls ~f:(fun exp -> not (entails_eq r rep exp))
with with
@ -663,7 +667,7 @@ let ppx_classes_diff x fs (r, s) =
Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) rep Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) rep
(List.pp "@ = " (Term.ppx x)) (List.pp "@ = " (Term.ppx x))
(List.dedup_and_sort ~compare:Term.compare cls) ) (List.dedup_and_sort ~compare:Term.compare cls) )
fs (Map.to_alist clss) fs (Term.Map.to_alist clss)
(** Existential Witnessing and Elimination *) (** Existential Witnessing and Elimination *)
@ -876,8 +880,8 @@ let solve_class us us_xs ~key:rep ~data:cls (classes, subst) =
|> Option.value ~default:cls |> Option.value ~default:cls
in in
let classes = let classes =
if List.is_empty cls then Map.remove classes rep if List.is_empty cls then Term.Map.remove classes rep
else Map.set classes ~key:rep ~data:cls else Term.Map.set classes ~key:rep ~data:cls
in in
(classes, subst) (classes, subst)
|> |>
@ -954,7 +958,8 @@ let solve_classes r (classes, subst, us) xs =
; ;
let rec solve_classes_ (classes0, subst0, us_xs) = let rec solve_classes_ (classes0, subst0, us_xs) =
let classes, subst = let classes, subst =
Map.fold ~f:(solve_class us us_xs) classes0 ~init:(classes0, subst0) Term.Map.fold ~f:(solve_class us us_xs) classes0
~init:(classes0, subst0)
in in
if subst != subst0 then solve_classes_ (classes, subst, us_xs) if subst != subst0 then solve_classes_ (classes, subst, us_xs)
else (classes, subst, us_xs) else (classes, subst, us_xs)

@ -84,6 +84,7 @@ module T = struct
end end
include T include T
module Map = Map.Make (T)
let term e = e.term let term e = e.term
@ -328,16 +329,7 @@ module Reg = struct
let vars = Set.fold ~init:Var.Set.empty ~f:(fun s r -> add s (var r)) let vars = Set.fold ~init:Var.Set.empty ~f:(fun s r -> add s (var r))
end end
module Map = struct module Map = Map
include (
Map :
module type of Map
with type ('key, 'value, 'cmp) t := ('key, 'value, 'cmp) Map.t )
type 'v t = 'v Map.M(T).t [@@deriving compare, equal, sexp]
let empty = Map.empty (module T)
end
let demangle = ref (fun _ -> None) let demangle = ref (fun _ -> None)

@ -116,14 +116,7 @@ module Reg : sig
val vars : t -> Var.Set.t val vars : t -> Var.Set.t
end end
module Map : sig module Map : Map.S with type key := t
type reg := t
type 'a t = (reg, 'a, comparator_witness) Map.t
[@@deriving compare, equal, sexp]
val empty : 'a t
end
val demangle : (string -> string option) ref val demangle : (string -> string option) ref
val pp : t pp val pp : t pp

@ -253,52 +253,201 @@ module List = struct
pp sep pp_diff_elt fs (symmetric_diff ~compare xs ys) pp sep pp_diff_elt fs (symmetric_diff ~compare xs ys)
end end
module type OrderedType = sig
type t
val compare : t -> t -> int
val sexp_of_t : t -> Sexp.t
end
exception Duplicate
module Map = struct module Map = struct
include Base.Map module type S = sig
type key
let pp pp_k pp_v fs m = type +'a t
Format.fprintf fs "@[<1>[%a]@]"
(List.pp ",@ " (fun fs (k, v) -> val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int
Format.fprintf fs "@[%a @<2>↦ %a@]" pp_k k pp_v v )) val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
(to_alist m) val sexp_of_t : ('a -> Sexp.t) -> 'a t -> Sexp.t
val t_of_sexp : (Sexp.t -> key) -> (Sexp.t -> 'a) -> Sexp.t -> 'a t
let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) = val pp : key pp -> 'a pp -> 'a t pp
let pp_diff_elt fs = function
| k, `Left v -> val pp_diff :
Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v data_equal:('a -> 'a -> bool)
| k, `Right v -> -> key pp
Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v -> 'a pp
| k, `Unequal vv -> -> ('a * 'a) pp
Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_diff_val vv -> ('a t * 'a t) pp
in
let sd = Sequence.to_list (symmetric_diff ~data_equal x y) in (* initial constructors *)
if not (List.is_empty sd) then val empty : 'a t
Format.fprintf fs "[@[<hv>%a@]];@ " (List.pp ";@ " pp_diff_elt) sd
(* constructors *)
let equal_m__t (module Elt : Compare_m) equal_v = equal equal_v val set : 'a t -> key:key -> data:'a -> 'a t
val add_exn : 'a t -> key:key -> data:'a -> 'a t
let find_and_remove m k = val add_multi : 'a list t -> key:key -> data:'a -> 'a list t
let found = ref None in val remove : 'a t -> key -> 'a t
let m = val update : 'a t -> key -> f:('a option -> 'a) -> 'a t
change m k ~f:(fun v ->
found := v ; val merge :
None ) 'a t
in -> 'b t
let+ v = !found in -> f:
(v, m) ( key:key
-> [`Both of 'a * 'b | `Left of 'a | `Right of 'b]
let find_or_add (type data) map key ~(default : data) ~if_found ~if_added -> 'c option)
= -> 'c t
let exception Found of data in
match val merge_skewed :
update map key ~f:(function 'a t -> 'a t -> combine:(key:key -> 'a -> 'a -> 'a) -> 'a t
| Some old_data -> Exn.raise_without_backtrace (Found old_data)
| None -> default ) val map : 'a t -> f:('a -> 'b) -> 'b t
with val filter_keys : 'a t -> f:(key -> bool) -> 'a t
| exception Found old_data -> if_found old_data val filter_mapi : 'a t -> f:(key:key -> data:'a -> 'b option) -> 'b t
| map -> if_added map
(* queries *)
val is_empty : 'b t -> bool
val length : 'b t -> int
val mem : 'a t -> key -> bool
val find : 'a t -> key -> 'a option
val find_and_remove : 'a t -> key -> ('a * 'a t) option
val find_multi : 'a list t -> key -> 'a list
val data : 'a t -> 'a list
val to_alist : 'a t -> (key * 'a) list
(* traversals *)
val iter : 'a t -> f:('a -> unit) -> unit
val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit
val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool
val fold : 'a t -> init:'s -> f:(key:key -> data:'a -> 's -> 's) -> 's
end
let map_preserving_phys_equal t ~f = map_preserving_phys_equal map t ~f module Make (Key : OrderedType) : S with type key = Key.t = struct
module M = Caml.Map.Make (Key)
type key = Key.t
type 'a t = 'a M.t
let compare = M.compare
let equal = M.equal
let sexp_of_t sexp_of_val m =
List.sexp_of_t
(Sexplib.Conv.sexp_of_pair Key.sexp_of_t sexp_of_val)
(M.bindings m)
let t_of_sexp key_of_sexp val_of_sexp sexp =
Caml.List.fold_left
(fun m (k, v) -> M.add k v m)
M.empty
(List.t_of_sexp
(Sexplib.Conv.pair_of_sexp key_of_sexp val_of_sexp)
sexp)
let pp pp_k pp_v fs m =
Format.fprintf fs "@[<1>[%a]@]"
(List.pp ",@ " (fun fs (k, v) ->
Format.fprintf fs "@[%a @<2>↦ %a@]" pp_k k pp_v v ))
(M.bindings m)
let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) =
let pp_diff_val fs = function
| k, `Left v ->
Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v
| k, `Right v ->
Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v
| k, `Unequal vv ->
Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_diff_val vv
in
let sd =
M.merge
(fun _ v1o v2o ->
match (v1o, v2o) with
| Some v1, Some v2 when not (data_equal v1 v2) ->
Some (`Unequal (v1, v2))
| Some v1, None -> Some (`Left v1)
| None, Some v2 -> Some (`Right v2)
| _ -> None )
x y
in
if not (M.is_empty sd) then
Format.fprintf fs "[@[<hv>%a@]];@ "
(List.pp ";@ " pp_diff_val)
(M.bindings sd)
exception Duplicate
let empty = M.empty
let set m ~key ~data = M.add key data m
let add_exn m ~key ~data =
M.update key
(function None -> Some data | Some _ -> raise Duplicate)
m
let add_multi m ~key ~data =
M.update key
(function None -> Some [data] | Some vs -> Some (data :: vs))
m
let remove m k = M.remove k m
let update m k ~f = M.update k (fun vo -> Some (f vo)) m
let merge m n ~f =
M.merge
(fun k v1o v2o ->
match (v1o, v2o) with
| Some v1, Some v2 -> f ~key:k (`Both (v1, v2))
| Some v1, None -> f ~key:k (`Left v1)
| None, Some v2 -> f ~key:k (`Right v2)
| None, None -> None )
m n
let merge_skewed m n ~combine =
M.merge
(fun k v1o v2o ->
match (v1o, v2o) with
| Some v1, Some v2 -> Some (combine ~key:k v1 v2)
| Some _, None -> v1o
| None, Some _ -> v2o
| None, None -> None )
m n
let map m ~f = M.map f m
let filter_keys m ~f = M.filter (fun k _ -> f k) m
let filter_mapi m ~f =
M.fold
(fun k v m ->
match f ~key:k ~data:v with Some v' -> M.add k v' m | None -> m )
m M.empty
let is_empty = M.is_empty
let length = M.cardinal
let mem m k = M.mem k m
let find m k = M.find_opt k m
let find_and_remove m k =
let found = ref None in
let m =
M.update k
(fun v ->
found := v ;
None )
m
in
let+ v = !found in
(v, m)
let find_multi m k = try M.find k m with Not_found -> []
let data m = M.fold (fun _ v s -> v :: s) m []
let to_alist = M.bindings
let iter m ~f = M.iter (fun _ v -> f v) m
let iteri m ~f = M.iter (fun k v -> f ~key:k ~data:v) m
let for_alli m ~f = M.for_all (fun key data -> f ~key ~data) m
let fold m ~init ~f = M.fold (fun key data s -> f ~key ~data s) m init
end
end end
module Result = struct module Result = struct
@ -358,6 +507,15 @@ module Array = struct
let pp sep pp_elt fs a = List.pp sep pp_elt fs (to_list a) let pp sep pp_elt fs a = List.pp sep pp_elt fs (to_list a)
end end
module String = struct
include String
let t_of_sexp = Sexplib.Conv.string_of_sexp
let sexp_of_t = Sexplib.Conv.sexp_of_string
module Map = Map.Make (String)
end
module Q = struct module Q = struct
let pp = Q.pp_print let pp = Q.pp_print
let hash = Hashtbl.hash let hash = Hashtbl.hash

@ -192,39 +192,77 @@ module List : sig
compare:('a -> 'a -> int) -> 'a t -> 'a t -> ('a, 'a) Either.t t compare:('a -> 'a -> int) -> 'a t -> 'a t -> ('a, 'a) Either.t t
end end
module Map : sig module type OrderedType = sig
include module type of Base.Map type t
val pp : 'k pp -> 'v pp -> ('k, 'v, 'c) t pp val compare : t -> t -> int
val sexp_of_t : t -> Sexp.t
end
val pp_diff : exception Duplicate
data_equal:('v -> 'v -> bool)
-> 'k pp
-> 'v pp
-> ('v * 'v) pp
-> (('k, 'v, 'c) t * ('k, 'v, 'c) t) pp
val equal_m__t : module Map : sig
(module Compare_m) module type S = sig
-> ('v -> 'v -> bool) type key
-> ('k, 'v, 'c) t type +'a t
-> ('k, 'v, 'c) t
-> bool val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int
val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
val find_and_remove : ('k, 'v, 'c) t -> 'k -> ('v * ('k, 'v, 'c) t) option val sexp_of_t : ('a -> Sexp.t) -> 'a t -> Sexp.t
val t_of_sexp : (Sexp.t -> key) -> (Sexp.t -> 'a) -> Sexp.t -> 'a t
val find_or_add : val pp : key pp -> 'a pp -> 'a t pp
('k, 'v, 'c) t
-> 'k val pp_diff :
-> default:'v data_equal:('a -> 'a -> bool)
-> if_found:('v -> 'a) -> key pp
-> if_added:(('k, 'v, 'c) t -> 'a) -> 'a pp
-> 'a -> ('a * 'a) pp
-> ('a t * 'a t) pp
val map_preserving_phys_equal :
('k, 'v, 'c) t -> f:('v -> 'v) -> ('k, 'v, 'c) t (* initial constructors *)
(** Like map, but preserves [phys_equal] if [f] preserves [phys_equal] of val empty : 'a t
every element. *)
(* constructors *)
val set : 'a t -> key:key -> data:'a -> 'a t
val add_exn : 'a t -> key:key -> data:'a -> 'a t
val add_multi : 'a list t -> key:key -> data:'a -> 'a list t
val remove : 'a t -> key -> 'a t
val update : 'a t -> key -> f:('a option -> 'a) -> 'a t
val merge :
'a t
-> 'b t
-> f:
( key:key
-> [`Both of 'a * 'b | `Left of 'a | `Right of 'b]
-> 'c option)
-> 'c t
val merge_skewed :
'a t -> 'a t -> combine:(key:key -> 'a -> 'a -> 'a) -> 'a t
val map : 'a t -> f:('a -> 'b) -> 'b t
val filter_keys : 'a t -> f:(key -> bool) -> 'a t
val filter_mapi : 'a t -> f:(key:key -> data:'a -> 'b option) -> 'b t
(* queries *)
val is_empty : 'b t -> bool
val length : 'b t -> int
val mem : 'a t -> key -> bool
val find : 'a t -> key -> 'a option
val find_and_remove : 'a t -> key -> ('a * 'a t) option
val find_multi : 'a list t -> key -> 'a list
val data : 'a t -> 'a list
val to_alist : 'a t -> (key * 'a) list
(* traversals *)
val iter : 'a t -> f:('a -> unit) -> unit
val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit
val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool
val fold : 'a t -> init:'s -> f:(key:key -> data:'a -> 's -> 's) -> 's
end
module Make (Key : OrderedType) : S with type key = Key.t
end end
module Result : sig module Result : sig
@ -277,6 +315,15 @@ module Array : sig
val pp : (unit, unit) fmt -> 'a pp -> 'a array pp val pp : (unit, unit) fmt -> 'a pp -> 'a array pp
end end
module String : sig
include module type of String
val t_of_sexp : Sexp.t -> t
val sexp_of_t : t -> Sexp.t
module Map : Map.S with type key = string
end
module Q : sig module Q : sig
include module type of struct include Q end include module type of struct include Q end

@ -114,7 +114,7 @@ let sexp_of_func {name; formals; freturn; fthrow; locals; entry} =
let compare_block x y = Int.compare x.sort_index y.sort_index let compare_block x y = Int.compare x.sort_index y.sort_index
let equal_block x y = Int.equal x.sort_index y.sort_index let equal_block x y = Int.equal x.sort_index y.sort_index
type functions = func Map.M(String).t [@@deriving sexp_of] type functions = func String.Map.t [@@deriving sexp_of]
type t = {globals: Global.t vector; functions: functions} type t = {globals: Global.t vector; functions: functions}
[@@deriving sexp_of] [@@deriving sexp_of]
@ -358,7 +358,7 @@ end
module Block = struct module Block = struct
module T = struct type t = block [@@deriving compare, equal, sexp_of] end module T = struct type t = block [@@deriving compare, equal, sexp_of] end
include T include T
include Comparator.Make (T) module Map = Map.Make (T)
let pp = pp_block let pp = pp_block
@ -471,7 +471,7 @@ module Func = struct
iter_term func ~f:(fun term -> Term.invariant ~parent:func term) iter_term func ~f:(fun term -> Term.invariant ~parent:func term)
| _ -> assert false | _ -> assert false
let find functions name = Map.find functions name let find functions name = String.Map.find functions name
let mk ~(name : Global.t) ~formals ~freturn ~fthrow ~entry ~cfg = let mk ~(name : Global.t) ~formals ~freturn ~fthrow ~entry ~cfg =
let locals = let locals =
@ -518,9 +518,9 @@ end
let set_derived_metadata functions = let set_derived_metadata functions =
let compute_roots functions = let compute_roots functions =
let roots = FuncQ.create () in let roots = FuncQ.create () in
Map.iter functions ~f:(fun func -> String.Map.iter functions ~f:(fun func ->
FuncQ.enqueue_back_exn roots func.name.reg func ) ; FuncQ.enqueue_back_exn roots func.name.reg func ) ;
Map.iter functions ~f:(fun func -> String.Map.iter functions ~f:(fun func ->
Func.fold_term func ~init:() ~f:(fun () -> function Func.fold_term func ~init:() ~f:(fun () -> function
| Call {callee; _} -> ( | Call {callee; _} -> (
match Reg.of_exp callee with match Reg.of_exp callee with
@ -571,10 +571,8 @@ let set_derived_metadata functions =
index := !index - 1 ) index := !index - 1 )
in in
let functions = let functions =
List.fold functions List.fold functions ~init:String.Map.empty ~f:(fun m func ->
~init:(Map.empty (module String)) String.Map.add_exn m ~key:(Reg.name func.name.reg) ~data:func )
~f:(fun m func ->
Map.add_exn m ~key:(Reg.name func.name.reg) ~data:func )
in in
let roots = compute_roots functions in let roots = compute_roots functions in
let tips_to_roots = topsort functions roots in let tips_to_roots = topsort functions roots in
@ -599,5 +597,5 @@ let pp fs {globals; functions} =
(Vector.pp "@\n@\n" Global.pp_defn) (Vector.pp "@\n@\n" Global.pp_defn)
globals globals
(List.pp "@\n@\n" Func.pp) (List.pp "@\n@\n" Func.pp)
( Map.data functions ( String.Map.data functions
|> List.sort ~compare:(fun x y -> compare_block x.entry y.entry) ) |> List.sort ~compare:(fun x y -> compare_block x.entry y.entry) )

@ -166,10 +166,10 @@ end
module Block : sig module Block : sig
type t = block [@@deriving compare, equal, sexp_of] type t = block [@@deriving compare, equal, sexp_of]
include Comparator.S with type t := t
val pp : t pp val pp : t pp
val mk : lbl:label -> cmnd:cmnd -> term:term -> block val mk : lbl:label -> cmnd:cmnd -> term:term -> block
module Map : Map.S with type key := t
end end
module Func : sig module Func : sig

@ -98,22 +98,22 @@ let fold_vars ?ignore_cong fold_vars q ~init ~f =
let rec var_strength_ xs m q = let rec var_strength_ xs m q =
let add m v = let add m v =
match Map.find m v with match Var.Map.find m v with
| None -> Map.set m ~key:v ~data:`Anonymous | None -> Var.Map.set m ~key:v ~data:`Anonymous
| Some `Anonymous -> Map.set m ~key:v ~data:`Existential | Some `Anonymous -> Var.Map.set m ~key:v ~data:`Existential
| Some _ -> m | Some _ -> m
in in
let xs = Set.union xs q.xs in let xs = Set.union xs q.xs in
let m_stem = let m_stem =
fold_vars_stem ~ignore_cong:() q ~init:m ~f:(fun m var -> fold_vars_stem ~ignore_cong:() q ~init:m ~f:(fun m var ->
if not (Set.mem xs var) then Map.set m ~key:var ~data:`Universal if not (Set.mem xs var) then Var.Map.set m ~key:var ~data:`Universal
else add m var ) else add m var )
in in
let m = let m =
List.fold ~init:m_stem q.djns ~f:(fun m djn -> List.fold ~init:m_stem q.djns ~f:(fun m djn ->
let ms = List.map ~f:(fun dj -> snd (var_strength_ xs m dj)) djn in let ms = List.map ~f:(fun dj -> snd (var_strength_ xs m dj)) djn in
List.reduce_balanced ms ~f:(fun m1 m2 -> List.reduce_balanced ms ~f:(fun m1 m2 ->
Map.merge_skewed m1 m2 ~combine:(fun ~key:_ s1 s2 -> Var.Map.merge_skewed m1 m2 ~combine:(fun ~key:_ s1 s2 ->
match (s1, s2) with match (s1, s2) with
| `Anonymous, `Anonymous -> `Anonymous | `Anonymous, `Anonymous -> `Anonymous
| `Universal, _ | _, `Universal -> `Universal | `Universal, _ | _, `Universal -> `Universal
@ -125,7 +125,7 @@ let rec var_strength_ xs m q =
let var_strength_full ?(xs = Var.Set.empty) q = let var_strength_full ?(xs = Var.Set.empty) q =
let m = let m =
Set.fold xs ~init:Var.Map.empty ~f:(fun m x -> Set.fold xs ~init:Var.Map.empty ~f:(fun m x ->
Map.set m ~key:x ~data:`Existential ) Var.Map.set m ~key:x ~data:`Existential )
in in
var_strength_ xs m q var_strength_ xs m q
@ -212,7 +212,7 @@ let pp_us ?(pre = ("" : _ fmt)) ?vs () fs us =
let rec pp_ ?var_strength vs parent_xs parent_cong fs let rec pp_ ?var_strength vs parent_xs parent_cong fs
{us; xs; cong; pure; heap; djns} = {us; xs; cong; pure; heap; djns} =
Format.pp_open_hvbox fs 0 ; Format.pp_open_hvbox fs 0 ;
let x v = Option.bind ~f:(fun (_, m) -> Map.find m v) var_strength in let x v = Option.bind ~f:(fun (_, m) -> Var.Map.find m v) var_strength in
pp_us ~vs () fs us ; pp_us ~vs () fs us ;
let xs_d_vs, xs_i_vs = let xs_d_vs, xs_i_vs =
Set.diff_inter Set.diff_inter

@ -23,7 +23,7 @@ type starjunction = private
and disjunction = starjunction list and disjunction = starjunction list
type t = starjunction [@@deriving equal, compare, sexp] type t = starjunction [@@deriving compare, equal, sexp]
val pp_seg_norm : Equality.t -> seg pp val pp_seg_norm : Equality.t -> seg pp
val pp_us : ?pre:('a, 'a) fmt -> ?vs:Var.Set.t -> unit -> Var.Set.t pp val pp_us : ?pre:('a, 'a) fmt -> ?vs:Var.Set.t -> unit -> Var.Set.t pp

@ -116,17 +116,7 @@ end
type _t = T0.t type _t = T0.t
include T include T
module Map = Map.Make (T)
module Map = struct
include (
Map :
module type of Map
with type ('key, 'value, 'cmp) t := ('key, 'value, 'cmp) Map.t )
type 'v t = 'v Map.M(T).t [@@deriving compare, equal, sexp]
let empty = empty (module T)
end
let empty_qset = Qset.empty (module T) let empty_qset = Qset.empty (module T)
@ -370,7 +360,9 @@ module Var = struct
(** Variable renaming substitutions *) (** Variable renaming substitutions *)
module Subst = struct module Subst = struct
type t = T.t Map.M(T).t [@@deriving compare, equal, sexp] type t = T.t Map.t [@@deriving compare, equal, sexp_of]
let t_of_sexp = Map.t_of_sexp T.t_of_sexp T.t_of_sexp
let invariant s = let invariant s =
Invariant.invariant [%here] s [%sexp_of: t] Invariant.invariant [%here] s [%sexp_of: t]

@ -110,14 +110,7 @@ module Var : sig
val union_list : t list -> t val union_list : t list -> t
end end
module Map : sig module Map : Map.S with type key := t
type var := t
type 'a t = (var, 'a, comparator_witness) Map.t
[@@deriving compare, equal, sexp]
val empty : 'a t
end
val pp : t pp val pp : t pp
@ -147,16 +140,8 @@ module Var : sig
end end
end end
module Map : sig module Map : Map.S with type key := t
type term := t
type 'a t = (term, 'a, comparator_witness) Map.t
[@@deriving compare, equal, sexp]
val empty : 'a t
end
val comparator : (t, comparator_witness) Comparator.t
val ppx : Var.strength -> t pp val ppx : Var.strength -> t pp
val pp : t pp val pp : t pp
val pp_diff : (t * t) pp val pp_diff : (t * t) pp

Loading…
Cancel
Save