[sledge] Shift to a more standard Map API

Summary:
The changes in map_intf.ml dictate the rest. The previous API
minimized changes when changing the backing implementation. But that
API is hostile toward composition, partial application, and
state-passing style code.

Reviewed By: jvillard

Differential Revision: D24306050

fbshipit-source-id: 71e286d4e
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 01bf695fa3
commit 4780b92584

@ -53,7 +53,7 @@ and x_trm : var_env -> Smt.Ast.term -> Term.t =
fun n term ->
match term with
| Const s -> (
try VarEnv.find_exn n s
try VarEnv.find_exn s n
with _ -> (
try Term.rational (Q.of_string s)
with _ -> (

@ -34,22 +34,26 @@ end) : S with type key = Key.t = struct
let empty = M.empty
let singleton = M.singleton
let add_exn m ~key ~data = M.add key data m
let set m ~key ~data = M.add key data m
let add_multi m ~key ~data =
let add_exn ~key ~data m =
assert (not (M.mem key m)) ;
M.add key data m
let add ~key ~data m = M.add key data m
let add_multi ~key ~data m =
M.update key
(function Some vs -> Some (data :: vs) | None -> Some [data])
m
let remove m key = M.remove key m
let merge l r ~f = M.merge_safe l r ~f:(fun key -> f ~key)
let remove key m = M.remove key m
let merge l r ~f = M.merge_safe l r ~f
let merge_endo t u ~f =
let change = ref false in
let t' =
merge t u ~f:(fun ~key side ->
let f_side = f ~key side in
merge t u ~f:(fun key side ->
let f_side = f key side in
( match (side, f_side) with
| (`Both (data, _) | `Left data), Some data' when data' == data ->
()
@ -58,9 +62,6 @@ end) : S with type key = Key.t = struct
in
if !change then t' else t
let merge_skewed x y ~combine =
M.union (fun key v1 v2 -> Some (combine ~key v1 v2)) x y
let union x y ~f = M.union f x y
let partition m ~f = M.partition f m
let is_empty = M.is_empty
@ -103,13 +104,6 @@ end) : S with type key = Key.t = struct
| None -> false
let length = M.cardinal
let choose_key = root_key
let choose = root_binding
let choose_exn m = Option.get_exn (choose m)
let min_binding = M.min_binding_opt
let mem m k = M.mem k m
let find_exn m k = M.find k m
let find m k = M.find_opt k m
let only_binding m =
match root_key m with
@ -127,10 +121,18 @@ end) : S with type key = Key.t = struct
| l, Some v, r when is_empty l && is_empty r -> `One (k, v)
| _ -> `Many )
let find_multi m k =
let choose_key = root_key
let choose = root_binding
let choose_exn m = Option.get_exn (choose m)
let min_binding = M.min_binding_opt
let mem k m = M.mem k m
let find_exn k m = M.find k m
let find k m = M.find_opt k m
let find_multi k m =
match M.find_opt k m with None -> [] | Some vs -> vs
let find_and_remove m k =
let find_and_remove k m =
let found = ref None in
let m =
M.update k
@ -141,13 +143,13 @@ end) : S with type key = Key.t = struct
in
Option.map ~f:(fun v -> (v, m)) !found
let pop m = choose m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k))
let pop m = choose m |> Option.map ~f:(fun (k, v) -> (k, v, remove k m))
let pop_min_binding m =
min_binding m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k))
min_binding m |> Option.map ~f:(fun (k, v) -> (k, v, remove k m))
let change m key ~f = M.update key f m
let update m k ~f = M.update k (fun v -> Some (f v)) m
let change k m ~f = M.update k f m
let update k m ~f = M.update k (fun v -> Some (f v)) m
let map m ~f = M.map f m
let mapi m ~f = M.mapi (fun key data -> f ~key ~data) m
let map_endo t ~f = map_endo map t ~f
@ -157,9 +159,9 @@ end) : S with type key = Key.t = struct
let existsi m ~f = M.exists (fun key data -> f ~key ~data) m
let for_alli m ~f = M.for_all (fun key data -> f ~key ~data) m
let fold m ~init ~f = M.fold (fun key data acc -> f ~key ~data acc) m init
let to_alist ?key_order:_ = M.to_list
let data m = Iter.to_list (M.values m)
let to_iter = M.to_iter
let keys = M.keys
let values = M.values
let to_iter m = Iter.rev (M.to_iter m)
let to_iter2 l r =
let seq = ref Iter.empty in
@ -169,10 +171,10 @@ end) : S with type key = Key.t = struct
|> ignore ;
!seq
let symmetric_diff ~data_equal l r =
let symmetric_diff l r ~eq =
Iter.filter_map (to_iter2 l r) ~f:(fun (k, vv) ->
match vv with
| `Both (lv, rv) when data_equal lv rv -> None
| `Both (lv, rv) when eq lv rv -> None
| `Both vv -> Some (k, `Unequal vv)
| `Left lv -> Some (k, `Left lv)
| `Right rv -> Some (k, `Right rv) )
@ -181,9 +183,9 @@ end) : S with type key = Key.t = struct
Format.fprintf fs "@[<1>[%a]@]"
(List.pp ",@ " (fun fs (k, v) ->
Format.fprintf fs "@[%a@ @<2>↦ %a@]" pp_k k pp_v v ))
(to_alist m)
(Iter.to_list (to_iter m))
let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) =
let pp_diff pp_key pp_val pp_diff_val ~eq fs (x, y) =
let pp_diff_elt fs = function
| k, `Left v ->
Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v
@ -192,7 +194,7 @@ end) : S with type key = Key.t = struct
| k, `Unequal vv ->
Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_diff_val vv
in
let sd = Iter.to_list (symmetric_diff ~data_equal x y) in
let sd = Iter.to_list (symmetric_diff ~eq x y) in
if not (List.is_empty sd) then
Format.fprintf fs "[@[<hv>%a@]];@ " (List.pp ";@ " pp_diff_elt) sd
end

@ -21,16 +21,16 @@ module type S = sig
val empty : 'a t
val singleton : key -> 'a -> 'a t
val add_exn : 'a t -> key:key -> data:'a -> 'a t
val set : 'a t -> key:key -> data:'a -> 'a t
val add_multi : 'a list t -> key:key -> data:'a -> 'a list t
val remove : 'a t -> key -> 'a t
val add_exn : key:key -> data:'a -> 'a t -> 'a t
val add : key:key -> data:'a -> 'a t -> 'a t
val add_multi : key:key -> data:'a -> 'a list t -> 'a list t
val remove : key -> 'a t -> 'a t
val merge :
'a t
-> 'b t
-> f:
( key:key
( key
-> [`Left of 'a | `Both of 'a * 'b | `Right of 'b]
-> 'c option)
-> 'c t
@ -39,7 +39,7 @@ module type S = sig
'a t
-> 'b t
-> f:
( key:key
( key
-> [`Left of 'a | `Both of 'a * 'b | `Right of 'b]
-> 'a option)
-> 'a t
@ -47,9 +47,6 @@ module type S = sig
left argument, which enables preserving [==] if [f] preserves [==] of
every value. *)
val merge_skewed :
'a t -> 'a t -> combine:(key:key -> 'a -> 'a -> 'a) -> 'a t
val union : 'a t -> 'a t -> f:(key -> 'a -> 'a -> 'a option) -> 'a t
val partition : 'a t -> f:(key -> 'a -> bool) -> 'a t * 'a t
@ -74,12 +71,12 @@ module type S = sig
equivalent maps. [O(1)]. *)
val min_binding : 'a t -> (key * 'a) option
val mem : 'a t -> key -> bool
val find : 'a t -> key -> 'a option
val find_exn : 'a t -> key -> 'a
val find_multi : 'a list t -> key -> 'a list
val mem : key -> 'a t -> bool
val find : key -> 'a t -> 'a option
val find_exn : key -> 'a t -> 'a
val find_multi : key -> 'a list t -> 'a list
val find_and_remove : 'a t -> key -> ('a * 'a t) option
val find_and_remove : key -> 'a t -> ('a * 'a t) option
(** Find and remove the binding for a key. *)
val pop : 'a t -> (key * 'a * 'a t) option
@ -91,8 +88,8 @@ module type S = sig
(** {1 Transform} *)
val change : 'a t -> key -> f:('a option -> 'a option) -> 'a t
val update : 'a t -> key -> f:('a option -> 'a) -> 'a t
val change : key -> 'a t -> f:('a option -> 'a option) -> 'a t
val update : key -> 'a t -> f:('a option -> 'a) -> 'a t
val map : 'a t -> f:('a -> 'b) -> 'b t
val mapi : 'a t -> f:(key:key -> data:'a -> 'b) -> 'b t
@ -108,14 +105,12 @@ module type S = sig
val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit
val existsi : 'a t -> f:(key:key -> data:'a -> bool) -> bool
val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool
val fold : 'a t -> init:'b -> f:(key:key -> data:'a -> 'b -> 'b) -> 'b
val fold : 'a t -> init:'s -> f:(key:key -> data:'a -> 's -> 's) -> 's
(** {1 Convert} *)
val to_alist :
?key_order:[`Increasing | `Decreasing] -> 'a t -> (key * 'a) list
val data : 'a t -> 'a list
val keys : 'a t -> key iter
val values : 'a t -> 'a iter
val to_iter : 'a t -> (key * 'a) iter
val to_iter2 :
@ -124,9 +119,9 @@ module type S = sig
-> (key * [`Left of 'a | `Both of 'a * 'b | `Right of 'b]) iter
val symmetric_diff :
data_equal:('a -> 'b -> bool)
-> 'a t
'a t
-> 'b t
-> eq:('a -> 'b -> bool)
-> (key * [> `Left of 'a | `Unequal of 'a * 'b | `Right of 'b]) iter
(** {1 Pretty-print} *)
@ -134,9 +129,9 @@ module type S = sig
val pp : key pp -> 'a pp -> 'a t pp
val pp_diff :
data_equal:('a -> 'a -> bool)
-> key pp
key pp
-> 'a pp
-> ('a * 'a) pp
-> eq:('a -> 'a -> bool)
-> ('a t * 'a t) pp
end

@ -34,23 +34,25 @@ struct
let sexp_of_t s =
List.sexp_of_t
(Sexplib.Conv.sexp_of_pair Elt.sexp_of_t Mul.sexp_of_t)
(M.to_alist s)
(Iter.to_list (M.to_iter s))
let t_of_sexp elt_of_sexp sexp =
List.fold_left
~f:(fun m (key, data) -> M.add_exn m ~key ~data)
~f:(fun m (key, data) -> M.add_exn ~key ~data m)
~init:M.empty
(List.t_of_sexp
(Sexplib.Conv.pair_of_sexp elt_of_sexp Mul.t_of_sexp)
sexp)
let pp sep pp_elt fs s = List.pp sep pp_elt fs (M.to_alist s)
let pp sep pp_elt fs s =
List.pp sep pp_elt fs (Iter.to_list (M.to_iter s))
let empty = M.empty
let of_ x i = if Mul.equal Mul.zero i then empty else M.singleton x i
let if_nz i = if Mul.equal Mul.zero i then None else Some i
let add m x i =
M.change m x ~f:(function
let add x i m =
M.change x m ~f:(function
| Some j -> if_nz (Mul.add i j)
| None -> if_nz i )
@ -59,7 +61,7 @@ struct
let union m n = M.union m n ~f:(fun _ i j -> if_nz (Mul.add i j))
let diff m n =
M.merge m n ~f:(fun ~key:_ -> function
M.merge m n ~f:(fun _ -> function
| `Both (i, j) -> if_nz (Mul.sub i j)
| `Left i -> Some i
| `Right j -> Some (Mul.neg j) )
@ -72,8 +74,8 @@ struct
M.fold m ~init:(m, m') ~f:(fun ~key:x ~data:i (m, m') ->
let x', i' = f x i in
if x' == x then
if Mul.equal i' i then (m, m') else (M.set m ~key:x ~data:i', m')
else (M.remove m x, add m' x' i') )
if Mul.equal i' i then (m, m') else (M.add ~key:x ~data:i' m, m')
else (M.remove x m, add x' i' m') )
in
union m m'
@ -89,16 +91,16 @@ struct
| Some (x', i') ->
if x' == x then
if Mul.equal i' i then (m, m')
else (M.set m ~key:x ~data:i', m')
else (M.remove m x, union m' d)
| None -> (M.remove m x, union m' d) )
else (M.add ~key:x ~data:i' m, m')
else (M.remove x m, union m' d)
| None -> (M.remove x m, union m' d) )
in
union m m'
let is_empty = M.is_empty
let is_singleton = M.is_singleton
let length m = M.length m
let count m x = match M.find m x with Some q -> q | None -> Mul.zero
let count x m = match M.find x m with Some q -> q | None -> Mul.zero
let only_elt = M.only_binding
let classify = M.classify
let choose = M.choose
@ -106,7 +108,7 @@ struct
let pop = M.pop
let min_elt = M.min_binding
let pop_min_elt = M.pop_min_binding
let to_list m = M.to_alist m
let to_iter = M.to_iter
let iter m ~f = M.iteri ~f:(fun ~key ~data -> f key data) m
let exists m ~f = M.existsi ~f:(fun ~key ~data -> f key data) m
let for_all m ~f = M.for_alli ~f:(fun ~key ~data -> f key data) m

@ -37,10 +37,10 @@ module type S = sig
val of_ : elt -> mul -> t
val add : t -> elt -> mul -> t
val add : elt -> mul -> t -> t
(** Add to multiplicity of single element. [O(log n)] *)
val remove : t -> elt -> t
val remove : elt -> t -> t
(** Set the multiplicity of an element to zero. [O(log n)] *)
val union : t -> t -> t
@ -74,7 +74,7 @@ module type S = sig
val length : t -> int
(** Number of elements with non-zero multiplicity. [O(1)]. *)
val count : t -> elt -> mul
val count : elt -> t -> mul
(** Multiplicity of an element. [O(log n)]. *)
val only_elt : t -> (elt * mul) option
@ -98,11 +98,11 @@ module type S = sig
val classify : t -> [`Zero | `One of elt * mul | `Many]
(** Classify a set as either empty, singleton, or otherwise. *)
val find_and_remove : t -> elt -> (mul * t) option
val find_and_remove : elt -> t -> (mul * t) option
(** Find and remove an element. *)
val to_list : t -> (elt * mul) list
(** Convert to a list of elements in ascending order. *)
val to_iter : t -> (elt * mul) iter
(** Enumerate elements in ascending order. *)
(* traversals *)

@ -235,7 +235,7 @@ module Representation (Trm : INDETERMINATE) = struct
(* transform *)
let split_const poly =
match Sum.find_and_remove poly Mono.one with
match Sum.find_and_remove Mono.one poly with
| Some (c, p_c) -> (p_c, c)
| None -> (poly, Q.zero)
@ -249,11 +249,11 @@ module Representation (Trm : INDETERMINATE) = struct
let trm' = f trm in
if trm == trm' then (m, cm')
else
(Prod.remove m trm, CM.mul cm' (CM.of_trm trm' ~power)) )
(Prod.remove trm m, CM.mul cm' (CM.of_trm trm' ~power)) )
in
if CM.equal_one cm' then (p, p')
else
( Sum.remove p mono
( Sum.remove mono p
, Sum.union p' (CM.to_poly (CM.mul (coeff, m) cm')) ) )
in
if Sum.is_empty p' then poly else Sum.union p p' |> check invariant

@ -187,10 +187,10 @@ module Make (Dom : Domain_intf.Dom) = struct
let empty = M.empty
let find = M.find
let set = M.set
let add = M.add
let join x y =
M.merge x y ~f:(fun ~key:_ -> function
M.merge x y ~f:(fun _ -> function
| `Left d | `Right d -> Some d
| `Both (d1, d2) -> Some (Int.max d1 d2) )
end
@ -215,7 +215,7 @@ module Make (Dom : Domain_intf.Dom) = struct
let add ?prev ~retreating stk state curr depths ((pq, ws, bound) as work)
=
let edge = {Edge.dst= curr; src= prev; stk} in
let depth = Option.value (Depths.find depths edge) ~default:0 in
let depth = Option.value (Depths.find edge depths) ~default:0 in
let depth = if retreating then depth + 1 else depth in
if depth > bound then (
[%Trace.info "prune: %i: %a" depth Edge.pp edge] ;
@ -223,9 +223,9 @@ module Make (Dom : Domain_intf.Dom) = struct
else
let pq = FHeap.add pq (depth, edge) in
[%Trace.info "@[<6>enqueue %i: %a@ | %a@]" depth Edge.pp edge pp pq] ;
let depths = Depths.set depths ~key:edge ~data:depth in
let depths = Depths.add ~key:edge ~data:depth depths in
let ws =
Llair.Block.Map.add_multi ws ~key:curr ~data:(state, depths)
Llair.Block.Map.add_multi ~key:curr ~data:(state, depths) ws
in
(pq, ws, bound)
@ -236,7 +236,7 @@ module Make (Dom : Domain_intf.Dom) = struct
let rec run ~f (pq0, ws, bnd) =
match FHeap.pop pq0 with
| Some ((_, ({Edge.dst; stk} as edge)), pq) -> (
match Llair.Block.Map.find_and_remove ws dst with
match Llair.Block.Map.find_and_remove dst ws with
| Some (q :: qs, ws) ->
let join (qa, da) (q, d) = (Dom.join q qa, Depths.join d da) in
let skipped, (qs, depths) =
@ -245,7 +245,7 @@ module Make (Dom : Domain_intf.Dom) = struct
| Some joined, depths -> (skipped, (joined, depths))
| None, _ -> (curr :: skipped, joined) )
in
let ws = Llair.Block.Map.add_exn ws ~key:dst ~data:skipped in
let ws = Llair.Block.Map.add_exn ~key:dst ~data:skipped ws in
run ~f (f stk qs dst depths (pq, ws, bnd))
| _ ->
[%Trace.info "done: %a" Edge.pp edge] ;
@ -330,7 +330,7 @@ module Make (Dom : Domain_intf.Dom) = struct
~formals:
(Llair.Reg.Set.union (Llair.Reg.Set.of_list formals) globals)
in
RegTbl.add_multi summary_table ~key:name.reg ~data:function_summary ;
RegTbl.add_multi ~key:name.reg ~data:function_summary summary_table ;
pp_st () ;
post_state
in
@ -432,7 +432,7 @@ module Make (Dom : Domain_intf.Dom) = struct
| None -> x )
| Call ({callee; actuals; areturn; return} as call) -> (
let lookup name =
Option.to_list (Llair.Func.find pgm.functions name)
Option.to_list (Llair.Func.find name pgm.functions)
in
let callees, state = Dom.resolve_callee lookup callee state in
match callees with
@ -492,7 +492,9 @@ module Make (Dom : Domain_intf.Dom) = struct
let harness : exec_opts -> Llair.program -> (int -> Work.t) option =
fun opts pgm ->
List.find_map ~f:(Llair.Func.find pgm.functions) opts.entry_points
List.find_map
~f:(fun entry_point -> Llair.Func.find entry_point pgm.functions)
opts.entry_points
|> function
| Some {name= {reg}; formals= []; freturn; locals; entry} ->
Some
@ -517,5 +519,5 @@ module Make (Dom : Domain_intf.Dom) = struct
exec_pgm opts pgm ;
RegTbl.fold summary_table ~init:Llair.Reg.Map.empty
~f:(fun ~key ~data map ->
match data with [] -> map | _ -> Llair.Reg.Map.set map ~key ~data )
match data with [] -> map | _ -> Llair.Reg.Map.add ~key ~data map )
end

@ -111,7 +111,7 @@ let by_function : r -> Llair.Reg.t -> t =
( match s with
| Declared set -> set
| Per_function map -> (
match Llair.Reg.Map.find map fn with
match Llair.Reg.Map.find fn map with
| Some gs -> gs
| None ->
fail

@ -1095,7 +1095,7 @@ module Context = struct
~f:(fun ~key:rep ~data:cls clss ->
let rep' = of_ses rep in
let cls' = List.map ~f:of_ses cls in
Term.Map.set ~key:rep' ~data:cls' clss )
Term.Map.add ~key:rep' ~data:cls' clss )
let diff_classes r s =
Term.Map.filter_mapi (classes r) ~f:(fun ~key:rep ~data:cls ->
@ -1116,7 +1116,8 @@ module Context = struct
(fun fs (rep, cls) ->
Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) rep (ppx_cls x)
(List.sort ~cmp:Term.compare cls) )
fs (Term.Map.to_alist clss)
fs
(Iter.to_list (Term.Map.to_iter clss))
let pp fs r = ppx_classes (fun _ -> None) fs (classes r)

@ -562,8 +562,8 @@ let set_derived_metadata functions =
| Iswitch {tbl; _} -> IArray.iter tbl ~f:jump
| Call ({callee; return; throw; _} as call) ->
( match
Option.bind ~f:(Func.find functions)
(Option.map ~f:Reg.name (Reg.of_exp callee))
let* reg = Reg.of_exp callee in
Func.find (Reg.name reg) functions
with
| Some func ->
if Block_label.Set.mem ancestors func.entry then
@ -589,7 +589,7 @@ let set_derived_metadata functions =
in
let functions =
List.fold functions ~init:String.Map.empty ~f:(fun m func ->
String.Map.add_exn m ~key:(Reg.name func.name.reg) ~data:func )
String.Map.add_exn ~key:(Reg.name func.name.reg) ~data:func m )
in
let roots = compute_roots functions in
let tips_to_roots = topsort functions roots in
@ -616,6 +616,7 @@ module Program = struct
(IArray.pp "@\n@\n" Global.pp_defn)
globals
(List.pp "@\n@\n" Func.pp)
( String.Map.data functions
( String.Map.values functions
|> Iter.to_list
|> List.sort ~cmp:(fun x y -> compare_block x.entry y.entry) )
end

@ -195,7 +195,7 @@ module Func : sig
-> fthrow:Reg.t
-> t
val find : functions -> string -> func option
val find : string -> functions -> func option
(** Look up a function of the given name in the given functions. *)
val is_undefined : func -> bool

@ -38,8 +38,8 @@ module Subst : sig
val empty : t
val is_empty : t -> bool
val length : t -> int
val mem : t -> Term.t -> bool
val find : t -> Term.t -> Term.t option
val mem : Term.t -> t -> bool
val find : Term.t -> t -> Term.t option
val fold : t -> init:'a -> f:(key:Term.t -> data:Term.t -> 'a -> 'a) -> 'a
val iteri : t -> f:(key:Term.t -> data:Term.t -> unit) -> unit
val for_alli : t -> f:(key:Term.t -> data:Term.t -> bool) -> bool
@ -52,17 +52,13 @@ module Subst : sig
val remove : Var.Set.t -> t -> t
val map_entries : f:(Term.t -> Term.t) -> t -> t
val to_iter : t -> (Term.t * Term.t) iter
val to_alist : t -> (Term.t * Term.t) list
val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t
end = struct
type t = Term.t Term.Map.t [@@deriving compare, equal, sexp_of]
let t_of_sexp = Term.Map.t_of_sexp Term.t_of_sexp
let pp = Term.Map.pp Term.pp Term.pp
let pp_diff =
Term.Map.pp_diff ~data_equal:Term.equal Term.pp Term.pp Term.pp_diff
let pp_diff = Term.Map.pp_diff ~eq:Term.equal Term.pp Term.pp Term.pp_diff
let empty = Term.Map.empty
let is_empty = Term.Map.is_empty
let length = Term.Map.length
@ -72,10 +68,9 @@ end = struct
let iteri = Term.Map.iteri
let for_alli = Term.Map.for_alli
let to_iter = Term.Map.to_iter
let to_alist = Term.Map.to_alist ~key_order:`Increasing
(** look up a term in a substitution *)
let apply s a = Term.Map.find s a |> Option.value ~default:a
let apply s a = Term.Map.find a s |> Option.value ~default:a
let rec subst s a = apply s (Term.map ~f:(subst s) a)
@ -88,7 +83,7 @@ end = struct
[%Trace.call fun {pf} -> pf "%a@ %a" pp r pp s]
;
let r' = Term.Map.map_endo ~f:(norm s) r in
Term.Map.merge_endo r' s ~f:(fun ~key -> function
Term.Map.merge_endo r' s ~f:(fun key -> function
| `Both (data_r, data_s) ->
assert (
Term.equal data_s data_r
@ -112,7 +107,7 @@ end = struct
let extend e s =
let exception Found in
match
Term.Map.update s e ~f:(function
Term.Map.update e s ~f:(function
| Some _ -> raise_notrace Found
| None -> e )
with
@ -121,7 +116,7 @@ end = struct
(** remove entries for vars *)
let remove xs s =
Var.Set.fold ~f:(fun s x -> Term.Map.remove s (Term.var x)) ~init:s xs
Var.Set.fold ~f:(fun s x -> Term.Map.remove (Term.var x) s) ~init:s xs
(** map over a subst, applying [f] to both domain and range, requires that
[f] is injective and for any set of terms [E], [f\[E\]] is disjoint
@ -132,9 +127,9 @@ end = struct
let data' = f data in
if Term.equal key' key then
if Term.equal data' data then s
else Term.Map.set s ~key ~data:data'
else Term.Map.add ~key ~data:data' s
else
let s = Term.Map.remove s key in
let s = Term.Map.remove key s in
match (key : Term.t) with
| Integer _ | Rational _ -> s
| _ -> Term.Map.add_exn ~key:key' ~data:data' s )
@ -167,10 +162,10 @@ end = struct
Term.Map.fold s ~init:(t, ks, s) ~f:(fun ~key ~data (t, ks, s) ->
if is_valid_eq ks key data then (t, ks, s)
else
let t = Term.Map.set ~key ~data t
let t = Term.Map.add ~key ~data t
and ks =
Var.Set.diff ks (Var.Set.union (Term.fv key) (Term.fv data))
and s = Term.Map.remove s key in
and s = Term.Map.remove key s in
(t, ks, s) )
in
if s' != s then partition_valid_ t' ks' s' else (t', ks', s')
@ -347,7 +342,7 @@ type classes = Term.t list Term.Map.t
let classes r =
let add key data cls =
if Term.equal key data then cls
else Term.Map.add_multi cls ~key:data ~data:key
else Term.Map.add_multi ~key:data ~data:key cls
in
Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls ->
match classify key with
@ -356,7 +351,7 @@ let classes r =
let cls_of r e =
let e' = Subst.apply r.rep e in
Term.Map.find (classes r) e' |> Option.value ~default:[e']
Term.Map.find e' (classes r) |> Option.value ~default:[e']
(** Pretty-printing *)
@ -370,7 +365,7 @@ let pp fs {sat; rep} =
let pp_term_v fs (k, v) = if not (Term.equal k v) then Term.pp fs v in
Format.fprintf fs "@[{@[<hv>sat= %b;@ rep= %a@]}@]" sat
(pp_alist Term.pp pp_term_v)
(Subst.to_alist rep)
(Iter.to_list (Subst.to_iter rep))
let pp_diff fs (r, s) =
let pp_sat fs =
@ -392,18 +387,18 @@ let ppx_classes x fs clss =
(fun fs (rep, cls) ->
Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) rep (ppx_cls x)
(List.sort ~cmp:Term.compare cls) )
fs (Term.Map.to_alist clss)
fs
(Iter.to_list (Term.Map.to_iter clss))
let pp_classes fs r = ppx_classes (fun _ -> None) fs (classes r)
let pp_diff_clss =
Term.Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls
pp_diff_cls
Term.Map.pp_diff ~eq:(List.equal Term.equal) Term.pp pp_cls pp_diff_cls
(** Basic queries *)
(** test membership in carrier *)
let in_car r e = Subst.mem r.rep e
let in_car r e = Subst.mem e r.rep
(** congruent specialized to assume subterms of [a'] are [Subst.norm]alized
wrt [r] (or canonized) *)
@ -575,7 +570,7 @@ let normalize = canon
let class_of r e =
let e' = normalize r e in
e' :: Term.Map.find_multi (classes r) e'
e' :: Term.Map.find_multi e' (classes r)
let fold_uses_of r t ~init ~f =
let rec fold_ e ~init:s ~f =
@ -707,7 +702,7 @@ let subst_invariant us s0 s =
Subst.iteri s ~f:(fun ~key ~data ->
(* dom of new entries not ito us *)
assert (
Option.for_all ~f:(Term.equal data) (Subst.find s0 key)
Option.for_all ~f:(Term.equal data) (Subst.find key s0)
|| not (Var.Set.is_subset (Term.fv key) ~of_:us) ) ;
(* rep not ito us implies trm not ito us *)
assert (
@ -912,8 +907,8 @@ let solve_class us us_xs ~key:rep ~data:cls (classes, subst) =
let cls = List.rev_append cls_not_ito_us_xs cls in
let cls = List.remove ~eq:Term.equal (Subst.norm subst rep) cls in
let classes =
if List.is_empty cls then Term.Map.remove classes rep
else Term.Map.set classes ~key:rep ~data:cls
if List.is_empty cls then Term.Map.remove rep classes
else Term.Map.add ~key:rep ~data:cls classes
in
(classes, subst)
|>
@ -980,7 +975,7 @@ let solve_for_xs r us xs (classes, subst, us_xs) =
Var.Set.fold xs ~init:(classes, subst, us_xs)
~f:(fun (classes, subst, us_xs) x ->
let x = Term.var x in
if Subst.mem subst x then (classes, subst, us_xs)
if Subst.mem x subst then (classes, subst, us_xs)
else solve_concat_extracts r us x (classes, subst, us_xs) )
(** move equations from [classes] to [subst] which can be expressed, after

@ -368,9 +368,9 @@ module Sum = struct
assert (not (Q.equal Q.zero coeff)) ;
match term with
| Integer {data} when Z.equal Z.zero data -> sum
| Integer {data} -> Qset.add sum one Q.(coeff * of_z data)
| Rational {data} -> Qset.add sum one Q.(coeff * data)
| _ -> Qset.add sum term coeff
| Integer {data} -> Qset.add one Q.(coeff * of_z data) sum
| Rational {data} -> Qset.add one Q.(coeff * data) sum
| _ -> Qset.add term coeff sum
let of_ ?(coeff = Q.one) term = add coeff term empty
@ -403,7 +403,7 @@ module Prod = struct
let add term prod =
assert (match term with Integer _ | Rational _ -> false | _ -> true) ;
Qset.add prod term Q.one
Qset.add term Q.one prod
let of_ term = add term empty
let union = Qset.union
@ -972,7 +972,7 @@ let d_int = function Integer {data} -> Some data | _ -> None
(** Access *)
let const_of = function Add poly -> Some (Qset.count poly one) | _ -> None
let const_of = function Add poly -> Some (Qset.count one poly) | _ -> None
(** Transform *)
@ -1198,7 +1198,7 @@ let rec solve_sum rejected_sum sum =
let* mono, coeff, sum = Qset.pop_min_elt sum in
match solve_for_mono rejected_sum coeff mono sum with
| Some _ as soln -> soln
| None -> solve_sum (Qset.add rejected_sum mono coeff) sum
| None -> solve_sum (Qset.add mono coeff rejected_sum) sum
(* solve [0 = e] *)
let solve_zero_eq ?for_ e =
@ -1209,7 +1209,7 @@ let solve_zero_eq ?for_ e =
match for_ with
| None -> solve_sum Qset.empty sum
| Some mono ->
let* coeff, sum = Qset.find_and_remove sum mono in
let* coeff, sum = Qset.find_and_remove mono sum in
solve_for_mono Qset.empty coeff mono sum )
| _ -> None )
|>

@ -87,7 +87,7 @@ module Make (T : REPR) = struct
Set.fold dom ~init:(empty, Set.empty, wrt)
~f:(fun (sub, rng, wrt) x ->
let x', wrt = fresh (name x) ~wrt in
let sub = Map.add_exn sub ~key:x ~data:x' in
let sub = Map.add_exn ~key:x ~data:x' sub in
let rng = Set.add rng x' in
(sub, rng, wrt) )
in
@ -107,7 +107,7 @@ module Make (T : REPR) = struct
let invert sub =
Map.fold sub ~init:empty ~f:(fun ~key ~data sub' ->
Map.add_exn sub' ~key:data ~data:key )
Map.add_exn ~key:data ~data:key sub' )
|> check invariant
let restrict sub vs =
@ -119,13 +119,13 @@ module Make (T : REPR) = struct
assert (
(* all substs are injective, so the current mapping is the
only one that can cause [data] to be in [rng] *)
(not (Set.mem (range (Map.remove sub key)) data))
(not (Set.mem (range (Map.remove key sub)) data))
|| violates invariant sub ) ;
{z with sub= Map.remove z.sub key} ) )
{z with sub= Map.remove key z.sub} ) )
|> check (fun {sub; dom; rng} ->
assert (Set.equal dom (domain sub)) ;
assert (Set.equal rng (range sub)) )
let apply sub v = Map.find sub v |> Option.value ~default:v
let apply sub v = Map.find v sub |> Option.value ~default:v
end
end

@ -90,28 +90,28 @@ let fold_vars ?ignore_ctx ?ignore_pure fold_vars q ~init ~f =
(** Pretty-printing *)
let rec var_strength_ xs m q =
let add m v =
match Var.Map.find m v with
| None -> Var.Map.set m ~key:v ~data:`Anonymous
| Some `Anonymous -> Var.Map.set m ~key:v ~data:`Existential
let add v m =
match Var.Map.find v m with
| None -> Var.Map.add ~key:v ~data:`Anonymous m
| Some `Anonymous -> Var.Map.add ~key:v ~data:`Existential m
| Some _ -> m
in
let xs = Var.Set.union xs q.xs in
let m_stem =
fold_vars_stem ~ignore_ctx:() q ~init:m ~f:(fun m var ->
if not (Var.Set.mem xs var) then
Var.Map.set m ~key:var ~data:`Universal
else add m var )
Var.Map.add ~key:var ~data:`Universal m
else add var m )
in
let m =
List.fold ~init:m_stem q.djns ~f:(fun m djn ->
let ms = List.map ~f:(fun dj -> snd (var_strength_ xs m dj)) djn in
List.reduce ms ~f:(fun m1 m2 ->
Var.Map.merge_skewed m1 m2 ~combine:(fun ~key:_ s1 s2 ->
Var.Map.union m1 m2 ~f:(fun _ s1 s2 ->
match (s1, s2) with
| `Anonymous, `Anonymous -> `Anonymous
| `Universal, _ | _, `Universal -> `Universal
| `Existential, _ | _, `Existential -> `Existential ) )
| `Anonymous, `Anonymous -> Some `Anonymous
| `Universal, _ | _, `Universal -> Some `Universal
| `Existential, _ | _, `Existential -> Some `Existential ) )
|> Option.value ~default:m )
in
(m_stem, m)
@ -119,7 +119,7 @@ let rec var_strength_ xs m q =
let var_strength ?(xs = Var.Set.empty) q =
let m =
Var.Set.fold xs ~init:Var.Map.empty ~f:(fun m x ->
Var.Map.set m ~key:x ~data:`Existential )
Var.Map.add ~key:x ~data:`Existential m )
in
var_strength_ xs m q
@ -198,7 +198,7 @@ let pp_us ?(pre = ("" : _ fmt)) ?vs () fs us =
let rec pp_ ?var_strength vs parent_xs parent_ctx fs
{us; xs; ctx; pure; heap; djns} =
Format.pp_open_hvbox fs 0 ;
let x v = Option.bind ~f:(fun (_, m) -> Var.Map.find m v) var_strength in
let x v = Option.bind ~f:(fun (_, m) -> Var.Map.find v m) var_strength in
pp_us ~vs () fs us ;
let xs_d_vs, xs_i_vs =
Var.Set.diff_inter

Loading…
Cancel
Save