diff --git a/sledge/cli/smtlib.ml b/sledge/cli/smtlib.ml index e4ea4eb9c..df9320019 100644 --- a/sledge/cli/smtlib.ml +++ b/sledge/cli/smtlib.ml @@ -53,7 +53,7 @@ and x_trm : var_env -> Smt.Ast.term -> Term.t = fun n term -> match term with | Const s -> ( - try VarEnv.find_exn n s + try VarEnv.find_exn s n with _ -> ( try Term.rational (Q.of_string s) with _ -> ( diff --git a/sledge/nonstdlib/map.ml b/sledge/nonstdlib/map.ml index 45d0b1437..4fd91df2c 100644 --- a/sledge/nonstdlib/map.ml +++ b/sledge/nonstdlib/map.ml @@ -34,22 +34,26 @@ end) : S with type key = Key.t = struct let empty = M.empty let singleton = M.singleton - let add_exn m ~key ~data = M.add key data m - let set m ~key ~data = M.add key data m - let add_multi m ~key ~data = + let add_exn ~key ~data m = + assert (not (M.mem key m)) ; + M.add key data m + + let add ~key ~data m = M.add key data m + + let add_multi ~key ~data m = M.update key (function Some vs -> Some (data :: vs) | None -> Some [data]) m - let remove m key = M.remove key m - let merge l r ~f = M.merge_safe l r ~f:(fun key -> f ~key) + let remove key m = M.remove key m + let merge l r ~f = M.merge_safe l r ~f let merge_endo t u ~f = let change = ref false in let t' = - merge t u ~f:(fun ~key side -> - let f_side = f ~key side in + merge t u ~f:(fun key side -> + let f_side = f key side in ( match (side, f_side) with | (`Both (data, _) | `Left data), Some data' when data' == data -> () @@ -58,9 +62,6 @@ end) : S with type key = Key.t = struct in if !change then t' else t - let merge_skewed x y ~combine = - M.union (fun key v1 v2 -> Some (combine ~key v1 v2)) x y - let union x y ~f = M.union f x y let partition m ~f = M.partition f m let is_empty = M.is_empty @@ -103,13 +104,6 @@ end) : S with type key = Key.t = struct | None -> false let length = M.cardinal - let choose_key = root_key - let choose = root_binding - let choose_exn m = Option.get_exn (choose m) - let min_binding = M.min_binding_opt - let mem m k = M.mem k m - let find_exn m k = M.find k m - let find m k = M.find_opt k m let only_binding m = match root_key m with @@ -127,10 +121,18 @@ end) : S with type key = Key.t = struct | l, Some v, r when is_empty l && is_empty r -> `One (k, v) | _ -> `Many ) - let find_multi m k = + let choose_key = root_key + let choose = root_binding + let choose_exn m = Option.get_exn (choose m) + let min_binding = M.min_binding_opt + let mem k m = M.mem k m + let find_exn k m = M.find k m + let find k m = M.find_opt k m + + let find_multi k m = match M.find_opt k m with None -> [] | Some vs -> vs - let find_and_remove m k = + let find_and_remove k m = let found = ref None in let m = M.update k @@ -141,13 +143,13 @@ end) : S with type key = Key.t = struct in Option.map ~f:(fun v -> (v, m)) !found - let pop m = choose m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k)) + let pop m = choose m |> Option.map ~f:(fun (k, v) -> (k, v, remove k m)) let pop_min_binding m = - min_binding m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k)) + min_binding m |> Option.map ~f:(fun (k, v) -> (k, v, remove k m)) - let change m key ~f = M.update key f m - let update m k ~f = M.update k (fun v -> Some (f v)) m + let change k m ~f = M.update k f m + let update k m ~f = M.update k (fun v -> Some (f v)) m let map m ~f = M.map f m let mapi m ~f = M.mapi (fun key data -> f ~key ~data) m let map_endo t ~f = map_endo map t ~f @@ -157,9 +159,9 @@ end) : S with type key = Key.t = struct let existsi m ~f = M.exists (fun key data -> f ~key ~data) m let for_alli m ~f = M.for_all (fun key data -> f ~key ~data) m let fold m ~init ~f = M.fold (fun key data acc -> f ~key ~data acc) m init - let to_alist ?key_order:_ = M.to_list - let data m = Iter.to_list (M.values m) - let to_iter = M.to_iter + let keys = M.keys + let values = M.values + let to_iter m = Iter.rev (M.to_iter m) let to_iter2 l r = let seq = ref Iter.empty in @@ -169,10 +171,10 @@ end) : S with type key = Key.t = struct |> ignore ; !seq - let symmetric_diff ~data_equal l r = + let symmetric_diff l r ~eq = Iter.filter_map (to_iter2 l r) ~f:(fun (k, vv) -> match vv with - | `Both (lv, rv) when data_equal lv rv -> None + | `Both (lv, rv) when eq lv rv -> None | `Both vv -> Some (k, `Unequal vv) | `Left lv -> Some (k, `Left lv) | `Right rv -> Some (k, `Right rv) ) @@ -181,9 +183,9 @@ end) : S with type key = Key.t = struct Format.fprintf fs "@[<1>[%a]@]" (List.pp ",@ " (fun fs (k, v) -> Format.fprintf fs "@[%a@ @<2>↦ %a@]" pp_k k pp_v v )) - (to_alist m) + (Iter.to_list (to_iter m)) - let pp_diff ~data_equal pp_key pp_val pp_diff_val fs (x, y) = + let pp_diff pp_key pp_val pp_diff_val ~eq fs (x, y) = let pp_diff_elt fs = function | k, `Left v -> Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v @@ -192,7 +194,7 @@ end) : S with type key = Key.t = struct | k, `Unequal vv -> Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_diff_val vv in - let sd = Iter.to_list (symmetric_diff ~data_equal x y) in + let sd = Iter.to_list (symmetric_diff ~eq x y) in if not (List.is_empty sd) then Format.fprintf fs "[@[%a@]];@ " (List.pp ";@ " pp_diff_elt) sd end diff --git a/sledge/nonstdlib/map_intf.ml b/sledge/nonstdlib/map_intf.ml index e9ec394cb..035324103 100644 --- a/sledge/nonstdlib/map_intf.ml +++ b/sledge/nonstdlib/map_intf.ml @@ -21,16 +21,16 @@ module type S = sig val empty : 'a t val singleton : key -> 'a -> 'a t - val add_exn : 'a t -> key:key -> data:'a -> 'a t - val set : 'a t -> key:key -> data:'a -> 'a t - val add_multi : 'a list t -> key:key -> data:'a -> 'a list t - val remove : 'a t -> key -> 'a t + val add_exn : key:key -> data:'a -> 'a t -> 'a t + val add : key:key -> data:'a -> 'a t -> 'a t + val add_multi : key:key -> data:'a -> 'a list t -> 'a list t + val remove : key -> 'a t -> 'a t val merge : 'a t -> 'b t -> f: - ( key:key + ( key -> [`Left of 'a | `Both of 'a * 'b | `Right of 'b] -> 'c option) -> 'c t @@ -39,7 +39,7 @@ module type S = sig 'a t -> 'b t -> f: - ( key:key + ( key -> [`Left of 'a | `Both of 'a * 'b | `Right of 'b] -> 'a option) -> 'a t @@ -47,9 +47,6 @@ module type S = sig left argument, which enables preserving [==] if [f] preserves [==] of every value. *) - val merge_skewed : - 'a t -> 'a t -> combine:(key:key -> 'a -> 'a -> 'a) -> 'a t - val union : 'a t -> 'a t -> f:(key -> 'a -> 'a -> 'a option) -> 'a t val partition : 'a t -> f:(key -> 'a -> bool) -> 'a t * 'a t @@ -74,12 +71,12 @@ module type S = sig equivalent maps. [O(1)]. *) val min_binding : 'a t -> (key * 'a) option - val mem : 'a t -> key -> bool - val find : 'a t -> key -> 'a option - val find_exn : 'a t -> key -> 'a - val find_multi : 'a list t -> key -> 'a list + val mem : key -> 'a t -> bool + val find : key -> 'a t -> 'a option + val find_exn : key -> 'a t -> 'a + val find_multi : key -> 'a list t -> 'a list - val find_and_remove : 'a t -> key -> ('a * 'a t) option + val find_and_remove : key -> 'a t -> ('a * 'a t) option (** Find and remove the binding for a key. *) val pop : 'a t -> (key * 'a * 'a t) option @@ -91,8 +88,8 @@ module type S = sig (** {1 Transform} *) - val change : 'a t -> key -> f:('a option -> 'a option) -> 'a t - val update : 'a t -> key -> f:('a option -> 'a) -> 'a t + val change : key -> 'a t -> f:('a option -> 'a option) -> 'a t + val update : key -> 'a t -> f:('a option -> 'a) -> 'a t val map : 'a t -> f:('a -> 'b) -> 'b t val mapi : 'a t -> f:(key:key -> data:'a -> 'b) -> 'b t @@ -108,14 +105,12 @@ module type S = sig val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit val existsi : 'a t -> f:(key:key -> data:'a -> bool) -> bool val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool - val fold : 'a t -> init:'b -> f:(key:key -> data:'a -> 'b -> 'b) -> 'b + val fold : 'a t -> init:'s -> f:(key:key -> data:'a -> 's -> 's) -> 's (** {1 Convert} *) - val to_alist : - ?key_order:[`Increasing | `Decreasing] -> 'a t -> (key * 'a) list - - val data : 'a t -> 'a list + val keys : 'a t -> key iter + val values : 'a t -> 'a iter val to_iter : 'a t -> (key * 'a) iter val to_iter2 : @@ -124,9 +119,9 @@ module type S = sig -> (key * [`Left of 'a | `Both of 'a * 'b | `Right of 'b]) iter val symmetric_diff : - data_equal:('a -> 'b -> bool) - -> 'a t + 'a t -> 'b t + -> eq:('a -> 'b -> bool) -> (key * [> `Left of 'a | `Unequal of 'a * 'b | `Right of 'b]) iter (** {1 Pretty-print} *) @@ -134,9 +129,9 @@ module type S = sig val pp : key pp -> 'a pp -> 'a t pp val pp_diff : - data_equal:('a -> 'a -> bool) - -> key pp + key pp -> 'a pp -> ('a * 'a) pp + -> eq:('a -> 'a -> bool) -> ('a t * 'a t) pp end diff --git a/sledge/nonstdlib/multiset.ml b/sledge/nonstdlib/multiset.ml index e44935db2..7f34a9085 100644 --- a/sledge/nonstdlib/multiset.ml +++ b/sledge/nonstdlib/multiset.ml @@ -34,23 +34,25 @@ struct let sexp_of_t s = List.sexp_of_t (Sexplib.Conv.sexp_of_pair Elt.sexp_of_t Mul.sexp_of_t) - (M.to_alist s) + (Iter.to_list (M.to_iter s)) let t_of_sexp elt_of_sexp sexp = List.fold_left - ~f:(fun m (key, data) -> M.add_exn m ~key ~data) + ~f:(fun m (key, data) -> M.add_exn ~key ~data m) ~init:M.empty (List.t_of_sexp (Sexplib.Conv.pair_of_sexp elt_of_sexp Mul.t_of_sexp) sexp) - let pp sep pp_elt fs s = List.pp sep pp_elt fs (M.to_alist s) + let pp sep pp_elt fs s = + List.pp sep pp_elt fs (Iter.to_list (M.to_iter s)) + let empty = M.empty let of_ x i = if Mul.equal Mul.zero i then empty else M.singleton x i let if_nz i = if Mul.equal Mul.zero i then None else Some i - let add m x i = - M.change m x ~f:(function + let add x i m = + M.change x m ~f:(function | Some j -> if_nz (Mul.add i j) | None -> if_nz i ) @@ -59,7 +61,7 @@ struct let union m n = M.union m n ~f:(fun _ i j -> if_nz (Mul.add i j)) let diff m n = - M.merge m n ~f:(fun ~key:_ -> function + M.merge m n ~f:(fun _ -> function | `Both (i, j) -> if_nz (Mul.sub i j) | `Left i -> Some i | `Right j -> Some (Mul.neg j) ) @@ -72,8 +74,8 @@ struct M.fold m ~init:(m, m') ~f:(fun ~key:x ~data:i (m, m') -> let x', i' = f x i in if x' == x then - if Mul.equal i' i then (m, m') else (M.set m ~key:x ~data:i', m') - else (M.remove m x, add m' x' i') ) + if Mul.equal i' i then (m, m') else (M.add ~key:x ~data:i' m, m') + else (M.remove x m, add x' i' m') ) in union m m' @@ -89,16 +91,16 @@ struct | Some (x', i') -> if x' == x then if Mul.equal i' i then (m, m') - else (M.set m ~key:x ~data:i', m') - else (M.remove m x, union m' d) - | None -> (M.remove m x, union m' d) ) + else (M.add ~key:x ~data:i' m, m') + else (M.remove x m, union m' d) + | None -> (M.remove x m, union m' d) ) in union m m' let is_empty = M.is_empty let is_singleton = M.is_singleton let length m = M.length m - let count m x = match M.find m x with Some q -> q | None -> Mul.zero + let count x m = match M.find x m with Some q -> q | None -> Mul.zero let only_elt = M.only_binding let classify = M.classify let choose = M.choose @@ -106,7 +108,7 @@ struct let pop = M.pop let min_elt = M.min_binding let pop_min_elt = M.pop_min_binding - let to_list m = M.to_alist m + let to_iter = M.to_iter let iter m ~f = M.iteri ~f:(fun ~key ~data -> f key data) m let exists m ~f = M.existsi ~f:(fun ~key ~data -> f key data) m let for_all m ~f = M.for_alli ~f:(fun ~key ~data -> f key data) m diff --git a/sledge/nonstdlib/multiset_intf.ml b/sledge/nonstdlib/multiset_intf.ml index 69cfa200b..fd02a86d9 100644 --- a/sledge/nonstdlib/multiset_intf.ml +++ b/sledge/nonstdlib/multiset_intf.ml @@ -37,10 +37,10 @@ module type S = sig val of_ : elt -> mul -> t - val add : t -> elt -> mul -> t + val add : elt -> mul -> t -> t (** Add to multiplicity of single element. [O(log n)] *) - val remove : t -> elt -> t + val remove : elt -> t -> t (** Set the multiplicity of an element to zero. [O(log n)] *) val union : t -> t -> t @@ -74,7 +74,7 @@ module type S = sig val length : t -> int (** Number of elements with non-zero multiplicity. [O(1)]. *) - val count : t -> elt -> mul + val count : elt -> t -> mul (** Multiplicity of an element. [O(log n)]. *) val only_elt : t -> (elt * mul) option @@ -98,11 +98,11 @@ module type S = sig val classify : t -> [`Zero | `One of elt * mul | `Many] (** Classify a set as either empty, singleton, or otherwise. *) - val find_and_remove : t -> elt -> (mul * t) option + val find_and_remove : elt -> t -> (mul * t) option (** Find and remove an element. *) - val to_list : t -> (elt * mul) list - (** Convert to a list of elements in ascending order. *) + val to_iter : t -> (elt * mul) iter + (** Enumerate elements in ascending order. *) (* traversals *) diff --git a/sledge/src/arithmetic.ml b/sledge/src/arithmetic.ml index 9347700dc..bf07124a4 100644 --- a/sledge/src/arithmetic.ml +++ b/sledge/src/arithmetic.ml @@ -235,7 +235,7 @@ module Representation (Trm : INDETERMINATE) = struct (* transform *) let split_const poly = - match Sum.find_and_remove poly Mono.one with + match Sum.find_and_remove Mono.one poly with | Some (c, p_c) -> (p_c, c) | None -> (poly, Q.zero) @@ -249,11 +249,11 @@ module Representation (Trm : INDETERMINATE) = struct let trm' = f trm in if trm == trm' then (m, cm') else - (Prod.remove m trm, CM.mul cm' (CM.of_trm trm' ~power)) ) + (Prod.remove trm m, CM.mul cm' (CM.of_trm trm' ~power)) ) in if CM.equal_one cm' then (p, p') else - ( Sum.remove p mono + ( Sum.remove mono p , Sum.union p' (CM.to_poly (CM.mul (coeff, m) cm')) ) ) in if Sum.is_empty p' then poly else Sum.union p p' |> check invariant diff --git a/sledge/src/control.ml b/sledge/src/control.ml index 4edee6bbb..be926c8b9 100644 --- a/sledge/src/control.ml +++ b/sledge/src/control.ml @@ -187,10 +187,10 @@ module Make (Dom : Domain_intf.Dom) = struct let empty = M.empty let find = M.find - let set = M.set + let add = M.add let join x y = - M.merge x y ~f:(fun ~key:_ -> function + M.merge x y ~f:(fun _ -> function | `Left d | `Right d -> Some d | `Both (d1, d2) -> Some (Int.max d1 d2) ) end @@ -215,7 +215,7 @@ module Make (Dom : Domain_intf.Dom) = struct let add ?prev ~retreating stk state curr depths ((pq, ws, bound) as work) = let edge = {Edge.dst= curr; src= prev; stk} in - let depth = Option.value (Depths.find depths edge) ~default:0 in + let depth = Option.value (Depths.find edge depths) ~default:0 in let depth = if retreating then depth + 1 else depth in if depth > bound then ( [%Trace.info "prune: %i: %a" depth Edge.pp edge] ; @@ -223,9 +223,9 @@ module Make (Dom : Domain_intf.Dom) = struct else let pq = FHeap.add pq (depth, edge) in [%Trace.info "@[<6>enqueue %i: %a@ | %a@]" depth Edge.pp edge pp pq] ; - let depths = Depths.set depths ~key:edge ~data:depth in + let depths = Depths.add ~key:edge ~data:depth depths in let ws = - Llair.Block.Map.add_multi ws ~key:curr ~data:(state, depths) + Llair.Block.Map.add_multi ~key:curr ~data:(state, depths) ws in (pq, ws, bound) @@ -236,7 +236,7 @@ module Make (Dom : Domain_intf.Dom) = struct let rec run ~f (pq0, ws, bnd) = match FHeap.pop pq0 with | Some ((_, ({Edge.dst; stk} as edge)), pq) -> ( - match Llair.Block.Map.find_and_remove ws dst with + match Llair.Block.Map.find_and_remove dst ws with | Some (q :: qs, ws) -> let join (qa, da) (q, d) = (Dom.join q qa, Depths.join d da) in let skipped, (qs, depths) = @@ -245,7 +245,7 @@ module Make (Dom : Domain_intf.Dom) = struct | Some joined, depths -> (skipped, (joined, depths)) | None, _ -> (curr :: skipped, joined) ) in - let ws = Llair.Block.Map.add_exn ws ~key:dst ~data:skipped in + let ws = Llair.Block.Map.add_exn ~key:dst ~data:skipped ws in run ~f (f stk qs dst depths (pq, ws, bnd)) | _ -> [%Trace.info "done: %a" Edge.pp edge] ; @@ -330,7 +330,7 @@ module Make (Dom : Domain_intf.Dom) = struct ~formals: (Llair.Reg.Set.union (Llair.Reg.Set.of_list formals) globals) in - RegTbl.add_multi summary_table ~key:name.reg ~data:function_summary ; + RegTbl.add_multi ~key:name.reg ~data:function_summary summary_table ; pp_st () ; post_state in @@ -432,7 +432,7 @@ module Make (Dom : Domain_intf.Dom) = struct | None -> x ) | Call ({callee; actuals; areturn; return} as call) -> ( let lookup name = - Option.to_list (Llair.Func.find pgm.functions name) + Option.to_list (Llair.Func.find name pgm.functions) in let callees, state = Dom.resolve_callee lookup callee state in match callees with @@ -492,7 +492,9 @@ module Make (Dom : Domain_intf.Dom) = struct let harness : exec_opts -> Llair.program -> (int -> Work.t) option = fun opts pgm -> - List.find_map ~f:(Llair.Func.find pgm.functions) opts.entry_points + List.find_map + ~f:(fun entry_point -> Llair.Func.find entry_point pgm.functions) + opts.entry_points |> function | Some {name= {reg}; formals= []; freturn; locals; entry} -> Some @@ -517,5 +519,5 @@ module Make (Dom : Domain_intf.Dom) = struct exec_pgm opts pgm ; RegTbl.fold summary_table ~init:Llair.Reg.Map.empty ~f:(fun ~key ~data map -> - match data with [] -> map | _ -> Llair.Reg.Map.set map ~key ~data ) + match data with [] -> map | _ -> Llair.Reg.Map.add ~key ~data map ) end diff --git a/sledge/src/domain_used_globals.ml b/sledge/src/domain_used_globals.ml index e833de63f..54edcb86e 100644 --- a/sledge/src/domain_used_globals.ml +++ b/sledge/src/domain_used_globals.ml @@ -111,7 +111,7 @@ let by_function : r -> Llair.Reg.t -> t = ( match s with | Declared set -> set | Per_function map -> ( - match Llair.Reg.Map.find map fn with + match Llair.Reg.Map.find fn map with | Some gs -> gs | None -> fail diff --git a/sledge/src/fol.ml b/sledge/src/fol.ml index 2709058b5..7d070e326 100644 --- a/sledge/src/fol.ml +++ b/sledge/src/fol.ml @@ -1095,7 +1095,7 @@ module Context = struct ~f:(fun ~key:rep ~data:cls clss -> let rep' = of_ses rep in let cls' = List.map ~f:of_ses cls in - Term.Map.set ~key:rep' ~data:cls' clss ) + Term.Map.add ~key:rep' ~data:cls' clss ) let diff_classes r s = Term.Map.filter_mapi (classes r) ~f:(fun ~key:rep ~data:cls -> @@ -1116,7 +1116,8 @@ module Context = struct (fun fs (rep, cls) -> Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) rep (ppx_cls x) (List.sort ~cmp:Term.compare cls) ) - fs (Term.Map.to_alist clss) + fs + (Iter.to_list (Term.Map.to_iter clss)) let pp fs r = ppx_classes (fun _ -> None) fs (classes r) diff --git a/sledge/src/llair/llair.ml b/sledge/src/llair/llair.ml index 38a012ee9..b5af535b2 100644 --- a/sledge/src/llair/llair.ml +++ b/sledge/src/llair/llair.ml @@ -562,8 +562,8 @@ let set_derived_metadata functions = | Iswitch {tbl; _} -> IArray.iter tbl ~f:jump | Call ({callee; return; throw; _} as call) -> ( match - Option.bind ~f:(Func.find functions) - (Option.map ~f:Reg.name (Reg.of_exp callee)) + let* reg = Reg.of_exp callee in + Func.find (Reg.name reg) functions with | Some func -> if Block_label.Set.mem ancestors func.entry then @@ -589,7 +589,7 @@ let set_derived_metadata functions = in let functions = List.fold functions ~init:String.Map.empty ~f:(fun m func -> - String.Map.add_exn m ~key:(Reg.name func.name.reg) ~data:func ) + String.Map.add_exn ~key:(Reg.name func.name.reg) ~data:func m ) in let roots = compute_roots functions in let tips_to_roots = topsort functions roots in @@ -616,6 +616,7 @@ module Program = struct (IArray.pp "@\n@\n" Global.pp_defn) globals (List.pp "@\n@\n" Func.pp) - ( String.Map.data functions + ( String.Map.values functions + |> Iter.to_list |> List.sort ~cmp:(fun x y -> compare_block x.entry y.entry) ) end diff --git a/sledge/src/llair/llair.mli b/sledge/src/llair/llair.mli index 7a3d29e76..9e0b0d26b 100644 --- a/sledge/src/llair/llair.mli +++ b/sledge/src/llair/llair.mli @@ -195,7 +195,7 @@ module Func : sig -> fthrow:Reg.t -> t - val find : functions -> string -> func option + val find : string -> functions -> func option (** Look up a function of the given name in the given functions. *) val is_undefined : func -> bool diff --git a/sledge/src/ses/equality.ml b/sledge/src/ses/equality.ml index 12dd90fe9..673ee3127 100644 --- a/sledge/src/ses/equality.ml +++ b/sledge/src/ses/equality.ml @@ -38,8 +38,8 @@ module Subst : sig val empty : t val is_empty : t -> bool val length : t -> int - val mem : t -> Term.t -> bool - val find : t -> Term.t -> Term.t option + val mem : Term.t -> t -> bool + val find : Term.t -> t -> Term.t option val fold : t -> init:'a -> f:(key:Term.t -> data:Term.t -> 'a -> 'a) -> 'a val iteri : t -> f:(key:Term.t -> data:Term.t -> unit) -> unit val for_alli : t -> f:(key:Term.t -> data:Term.t -> bool) -> bool @@ -52,17 +52,13 @@ module Subst : sig val remove : Var.Set.t -> t -> t val map_entries : f:(Term.t -> Term.t) -> t -> t val to_iter : t -> (Term.t * Term.t) iter - val to_alist : t -> (Term.t * Term.t) list val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t end = struct type t = Term.t Term.Map.t [@@deriving compare, equal, sexp_of] let t_of_sexp = Term.Map.t_of_sexp Term.t_of_sexp let pp = Term.Map.pp Term.pp Term.pp - - let pp_diff = - Term.Map.pp_diff ~data_equal:Term.equal Term.pp Term.pp Term.pp_diff - + let pp_diff = Term.Map.pp_diff ~eq:Term.equal Term.pp Term.pp Term.pp_diff let empty = Term.Map.empty let is_empty = Term.Map.is_empty let length = Term.Map.length @@ -72,10 +68,9 @@ end = struct let iteri = Term.Map.iteri let for_alli = Term.Map.for_alli let to_iter = Term.Map.to_iter - let to_alist = Term.Map.to_alist ~key_order:`Increasing (** look up a term in a substitution *) - let apply s a = Term.Map.find s a |> Option.value ~default:a + let apply s a = Term.Map.find a s |> Option.value ~default:a let rec subst s a = apply s (Term.map ~f:(subst s) a) @@ -88,7 +83,7 @@ end = struct [%Trace.call fun {pf} -> pf "%a@ %a" pp r pp s] ; let r' = Term.Map.map_endo ~f:(norm s) r in - Term.Map.merge_endo r' s ~f:(fun ~key -> function + Term.Map.merge_endo r' s ~f:(fun key -> function | `Both (data_r, data_s) -> assert ( Term.equal data_s data_r @@ -112,7 +107,7 @@ end = struct let extend e s = let exception Found in match - Term.Map.update s e ~f:(function + Term.Map.update e s ~f:(function | Some _ -> raise_notrace Found | None -> e ) with @@ -121,7 +116,7 @@ end = struct (** remove entries for vars *) let remove xs s = - Var.Set.fold ~f:(fun s x -> Term.Map.remove s (Term.var x)) ~init:s xs + Var.Set.fold ~f:(fun s x -> Term.Map.remove (Term.var x) s) ~init:s xs (** map over a subst, applying [f] to both domain and range, requires that [f] is injective and for any set of terms [E], [f\[E\]] is disjoint @@ -132,9 +127,9 @@ end = struct let data' = f data in if Term.equal key' key then if Term.equal data' data then s - else Term.Map.set s ~key ~data:data' + else Term.Map.add ~key ~data:data' s else - let s = Term.Map.remove s key in + let s = Term.Map.remove key s in match (key : Term.t) with | Integer _ | Rational _ -> s | _ -> Term.Map.add_exn ~key:key' ~data:data' s ) @@ -167,10 +162,10 @@ end = struct Term.Map.fold s ~init:(t, ks, s) ~f:(fun ~key ~data (t, ks, s) -> if is_valid_eq ks key data then (t, ks, s) else - let t = Term.Map.set ~key ~data t + let t = Term.Map.add ~key ~data t and ks = Var.Set.diff ks (Var.Set.union (Term.fv key) (Term.fv data)) - and s = Term.Map.remove s key in + and s = Term.Map.remove key s in (t, ks, s) ) in if s' != s then partition_valid_ t' ks' s' else (t', ks', s') @@ -347,7 +342,7 @@ type classes = Term.t list Term.Map.t let classes r = let add key data cls = if Term.equal key data then cls - else Term.Map.add_multi cls ~key:data ~data:key + else Term.Map.add_multi ~key:data ~data:key cls in Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls -> match classify key with @@ -356,7 +351,7 @@ let classes r = let cls_of r e = let e' = Subst.apply r.rep e in - Term.Map.find (classes r) e' |> Option.value ~default:[e'] + Term.Map.find e' (classes r) |> Option.value ~default:[e'] (** Pretty-printing *) @@ -370,7 +365,7 @@ let pp fs {sat; rep} = let pp_term_v fs (k, v) = if not (Term.equal k v) then Term.pp fs v in Format.fprintf fs "@[{@[sat= %b;@ rep= %a@]}@]" sat (pp_alist Term.pp pp_term_v) - (Subst.to_alist rep) + (Iter.to_list (Subst.to_iter rep)) let pp_diff fs (r, s) = let pp_sat fs = @@ -392,18 +387,18 @@ let ppx_classes x fs clss = (fun fs (rep, cls) -> Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) rep (ppx_cls x) (List.sort ~cmp:Term.compare cls) ) - fs (Term.Map.to_alist clss) + fs + (Iter.to_list (Term.Map.to_iter clss)) let pp_classes fs r = ppx_classes (fun _ -> None) fs (classes r) let pp_diff_clss = - Term.Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls - pp_diff_cls + Term.Map.pp_diff ~eq:(List.equal Term.equal) Term.pp pp_cls pp_diff_cls (** Basic queries *) (** test membership in carrier *) -let in_car r e = Subst.mem r.rep e +let in_car r e = Subst.mem e r.rep (** congruent specialized to assume subterms of [a'] are [Subst.norm]alized wrt [r] (or canonized) *) @@ -575,7 +570,7 @@ let normalize = canon let class_of r e = let e' = normalize r e in - e' :: Term.Map.find_multi (classes r) e' + e' :: Term.Map.find_multi e' (classes r) let fold_uses_of r t ~init ~f = let rec fold_ e ~init:s ~f = @@ -707,7 +702,7 @@ let subst_invariant us s0 s = Subst.iteri s ~f:(fun ~key ~data -> (* dom of new entries not ito us *) assert ( - Option.for_all ~f:(Term.equal data) (Subst.find s0 key) + Option.for_all ~f:(Term.equal data) (Subst.find key s0) || not (Var.Set.is_subset (Term.fv key) ~of_:us) ) ; (* rep not ito us implies trm not ito us *) assert ( @@ -912,8 +907,8 @@ let solve_class us us_xs ~key:rep ~data:cls (classes, subst) = let cls = List.rev_append cls_not_ito_us_xs cls in let cls = List.remove ~eq:Term.equal (Subst.norm subst rep) cls in let classes = - if List.is_empty cls then Term.Map.remove classes rep - else Term.Map.set classes ~key:rep ~data:cls + if List.is_empty cls then Term.Map.remove rep classes + else Term.Map.add ~key:rep ~data:cls classes in (classes, subst) |> @@ -980,7 +975,7 @@ let solve_for_xs r us xs (classes, subst, us_xs) = Var.Set.fold xs ~init:(classes, subst, us_xs) ~f:(fun (classes, subst, us_xs) x -> let x = Term.var x in - if Subst.mem subst x then (classes, subst, us_xs) + if Subst.mem x subst then (classes, subst, us_xs) else solve_concat_extracts r us x (classes, subst, us_xs) ) (** move equations from [classes] to [subst] which can be expressed, after diff --git a/sledge/src/ses/term.ml b/sledge/src/ses/term.ml index 7f0b21e0e..4a2b1f57c 100644 --- a/sledge/src/ses/term.ml +++ b/sledge/src/ses/term.ml @@ -368,9 +368,9 @@ module Sum = struct assert (not (Q.equal Q.zero coeff)) ; match term with | Integer {data} when Z.equal Z.zero data -> sum - | Integer {data} -> Qset.add sum one Q.(coeff * of_z data) - | Rational {data} -> Qset.add sum one Q.(coeff * data) - | _ -> Qset.add sum term coeff + | Integer {data} -> Qset.add one Q.(coeff * of_z data) sum + | Rational {data} -> Qset.add one Q.(coeff * data) sum + | _ -> Qset.add term coeff sum let of_ ?(coeff = Q.one) term = add coeff term empty @@ -403,7 +403,7 @@ module Prod = struct let add term prod = assert (match term with Integer _ | Rational _ -> false | _ -> true) ; - Qset.add prod term Q.one + Qset.add term Q.one prod let of_ term = add term empty let union = Qset.union @@ -972,7 +972,7 @@ let d_int = function Integer {data} -> Some data | _ -> None (** Access *) -let const_of = function Add poly -> Some (Qset.count poly one) | _ -> None +let const_of = function Add poly -> Some (Qset.count one poly) | _ -> None (** Transform *) @@ -1198,7 +1198,7 @@ let rec solve_sum rejected_sum sum = let* mono, coeff, sum = Qset.pop_min_elt sum in match solve_for_mono rejected_sum coeff mono sum with | Some _ as soln -> soln - | None -> solve_sum (Qset.add rejected_sum mono coeff) sum + | None -> solve_sum (Qset.add mono coeff rejected_sum) sum (* solve [0 = e] *) let solve_zero_eq ?for_ e = @@ -1209,7 +1209,7 @@ let solve_zero_eq ?for_ e = match for_ with | None -> solve_sum Qset.empty sum | Some mono -> - let* coeff, sum = Qset.find_and_remove sum mono in + let* coeff, sum = Qset.find_and_remove mono sum in solve_for_mono Qset.empty coeff mono sum ) | _ -> None ) |> diff --git a/sledge/src/ses/var0.ml b/sledge/src/ses/var0.ml index 01a1b8c58..7f3df66fa 100644 --- a/sledge/src/ses/var0.ml +++ b/sledge/src/ses/var0.ml @@ -87,7 +87,7 @@ module Make (T : REPR) = struct Set.fold dom ~init:(empty, Set.empty, wrt) ~f:(fun (sub, rng, wrt) x -> let x', wrt = fresh (name x) ~wrt in - let sub = Map.add_exn sub ~key:x ~data:x' in + let sub = Map.add_exn ~key:x ~data:x' sub in let rng = Set.add rng x' in (sub, rng, wrt) ) in @@ -107,7 +107,7 @@ module Make (T : REPR) = struct let invert sub = Map.fold sub ~init:empty ~f:(fun ~key ~data sub' -> - Map.add_exn sub' ~key:data ~data:key ) + Map.add_exn ~key:data ~data:key sub' ) |> check invariant let restrict sub vs = @@ -119,13 +119,13 @@ module Make (T : REPR) = struct assert ( (* all substs are injective, so the current mapping is the only one that can cause [data] to be in [rng] *) - (not (Set.mem (range (Map.remove sub key)) data)) + (not (Set.mem (range (Map.remove key sub)) data)) || violates invariant sub ) ; - {z with sub= Map.remove z.sub key} ) ) + {z with sub= Map.remove key z.sub} ) ) |> check (fun {sub; dom; rng} -> assert (Set.equal dom (domain sub)) ; assert (Set.equal rng (range sub)) ) - let apply sub v = Map.find sub v |> Option.value ~default:v + let apply sub v = Map.find v sub |> Option.value ~default:v end end diff --git a/sledge/src/sh.ml b/sledge/src/sh.ml index 8cac3dbe0..dc6f15cd7 100644 --- a/sledge/src/sh.ml +++ b/sledge/src/sh.ml @@ -90,28 +90,28 @@ let fold_vars ?ignore_ctx ?ignore_pure fold_vars q ~init ~f = (** Pretty-printing *) let rec var_strength_ xs m q = - let add m v = - match Var.Map.find m v with - | None -> Var.Map.set m ~key:v ~data:`Anonymous - | Some `Anonymous -> Var.Map.set m ~key:v ~data:`Existential + let add v m = + match Var.Map.find v m with + | None -> Var.Map.add ~key:v ~data:`Anonymous m + | Some `Anonymous -> Var.Map.add ~key:v ~data:`Existential m | Some _ -> m in let xs = Var.Set.union xs q.xs in let m_stem = fold_vars_stem ~ignore_ctx:() q ~init:m ~f:(fun m var -> if not (Var.Set.mem xs var) then - Var.Map.set m ~key:var ~data:`Universal - else add m var ) + Var.Map.add ~key:var ~data:`Universal m + else add var m ) in let m = List.fold ~init:m_stem q.djns ~f:(fun m djn -> let ms = List.map ~f:(fun dj -> snd (var_strength_ xs m dj)) djn in List.reduce ms ~f:(fun m1 m2 -> - Var.Map.merge_skewed m1 m2 ~combine:(fun ~key:_ s1 s2 -> + Var.Map.union m1 m2 ~f:(fun _ s1 s2 -> match (s1, s2) with - | `Anonymous, `Anonymous -> `Anonymous - | `Universal, _ | _, `Universal -> `Universal - | `Existential, _ | _, `Existential -> `Existential ) ) + | `Anonymous, `Anonymous -> Some `Anonymous + | `Universal, _ | _, `Universal -> Some `Universal + | `Existential, _ | _, `Existential -> Some `Existential ) ) |> Option.value ~default:m ) in (m_stem, m) @@ -119,7 +119,7 @@ let rec var_strength_ xs m q = let var_strength ?(xs = Var.Set.empty) q = let m = Var.Set.fold xs ~init:Var.Map.empty ~f:(fun m x -> - Var.Map.set m ~key:x ~data:`Existential ) + Var.Map.add ~key:x ~data:`Existential m ) in var_strength_ xs m q @@ -198,7 +198,7 @@ let pp_us ?(pre = ("" : _ fmt)) ?vs () fs us = let rec pp_ ?var_strength vs parent_xs parent_ctx fs {us; xs; ctx; pure; heap; djns} = Format.pp_open_hvbox fs 0 ; - let x v = Option.bind ~f:(fun (_, m) -> Var.Map.find m v) var_strength in + let x v = Option.bind ~f:(fun (_, m) -> Var.Map.find v m) var_strength in pp_us ~vs () fs us ; let xs_d_vs, xs_i_vs = Var.Set.diff_inter