diff --git a/infer/src/checkers/cost.ml b/infer/src/checkers/cost.ml index e03029dae..b3f21744f 100644 --- a/infer/src/checkers/cost.ml +++ b/infer/src/checkers/cost.ml @@ -301,8 +301,6 @@ module ControlFlowCost = struct Sum.compare x y - let equal = [%compare.equal : t] - let make_node node = `Node node let make_pred_edge succ pred = `Edge (pred, succ) @@ -316,7 +314,7 @@ module ControlFlowCost = struct let sum : Item.t list -> t = function [] -> assert false | [e] -> (e :> t) | l -> Sum.of_list l module Set = struct - type elt = t + type elt = t [@@deriving compare] type t = {mutable size: int; mutable items: Item.t ARList.t; mutable sums: Sum.t ARList.t} @@ -331,13 +329,13 @@ module ControlFlowCost = struct {size= 1; items; sums} - (* Because we are modifying things on which we are iterating, we want to invalidate them first before removing them from hashtables, to avoid iterator invalidation. *) - let is_valid {size} = size >= 1 + let compare_size {size= size1} {size= size2} = Int.compare size1 size2 - let size {size} = size + (* Invalidation is just a sanity check, union-find already takes care of it. 
*) + let is_valid {size} = size >= 1 (* move semantics, should not be called with aliases *) - let merge_invalidate ~from ~to_ = + let merge ~from ~to_ = assert (not (phys_equal from to_)) ; assert (is_valid from) ; assert (is_valid to_) ; @@ -348,15 +346,13 @@ module ControlFlowCost = struct let pp_equalities fmt t = - if is_valid t then - ARList.append (t.items :> elt ARList.t) (t.sums :> elt ARList.t) - |> IContainer.to_rev_list ~fold:ARList.fold_unordered |> List.sort ~compare - |> Pp.seq ~sep:" = " pp fmt + ARList.append (t.items :> elt ARList.t) (t.sums :> elt ARList.t) + |> IContainer.to_rev_list ~fold:ARList.fold_unordered |> List.sort ~compare + |> Pp.seq ~sep:" = " pp fmt let normalize_sums : normalizer:(elt -> elt) -> t -> unit = fun ~normalizer t -> - assert (is_valid t) ; t.sums <- t.sums |> IContainer.rev_map_to_list ~fold:ARList.fold_unordered ~f:(Sum.normalize ~normalizer) @@ -373,214 +369,50 @@ module ControlFlowCost = struct let infer_equalities_from_sums : on_infer:(elt -> elt -> unit) -> normalizer:(elt -> elt) -> t -> unit = fun ~on_infer ~normalizer t -> - if is_valid t then ( - normalize_sums ~normalizer t ; - let all_items = - t.sums - |> ARList.fold_unordered ~init:ARList.empty ~f:(fun acc sum -> - sum |> Sum.items |> ARList.of_list |> ARList.append acc ) - |> IContainer.to_rev_list ~fold:ARList.fold_unordered - |> List.dedup_and_sort ~compare:Item.compare - in - (* Keep in mind that [on_infer] can modify (and possibly invalidate) [t]. - It happens only if we merge a node while infer equalities from it, i.e. in the case an item appears in an equality class both alone and in two sums, i.e. X = A + X = A + B. 
+ normalize_sums ~normalizer t ; + let all_items = + t.sums + |> ARList.fold_unordered ~init:ARList.empty ~f:(fun acc sum -> + sum |> Sum.items |> ARList.of_list |> ARList.append acc ) + |> IContainer.to_rev_list ~fold:ARList.fold_unordered + |> List.dedup_and_sort ~compare:Item.compare + in + (* Keep in mind that [on_infer] can modify [t]. + It happens only if we merge a node while inferring equalities from it, i.e. in the case an item appears in an equality class both alone and in two sums, i.e. X = A + X = A + B. This is not a problem here (we could stop if it happens but it is not necessary as existing equalities still remain true after merges) *) - (* Also keep in mind that the current version, in the worst-case scenario, is quadratic-ish in the size of the CFG *) - List.iter all_items ~f:(fun item -> infer_equalities_by_removing_item ~on_infer t item) ) + (* Also keep in mind that the current version, in the worst-case scenario, is quadratic-ish in the size of the CFG *) + List.iter all_items ~f:(fun item -> infer_equalities_by_removing_item ~on_infer t item) end end -module ImperativeUnionFind (E : sig - type t [@@deriving compare] - - val equal : t -> t -> bool - - val pp : F.formatter -> t -> unit - - module Set : sig - type elt = t - - type t - - val create : elt -> t - - val size : t -> int - - val merge_invalidate : from:t -> to_:t -> unit - - val pp_equalities : F.formatter -> t -> unit - - val normalize_sums : normalizer:(elt -> elt) -> t -> unit - - val infer_equalities_from_sums : - on_infer:(elt -> elt -> unit) -> normalizer:(elt -> elt) -> t -> unit - end -end) = -struct - module H = struct - include Caml.Hashtbl - - let caml_fold = fold - - let fold : (('a, 'b) t, 'a * 'b, 'accum) Container.fold = - fun h ~init ~f -> - let f' k v accum = f accum (k, v) in - caml_fold f' h init - - - let pp_bindings pp_key pp_value fmt h = - let pp_item fmt (k, v) = F.fprintf fmt "%a --> %a" pp_key k pp_value v in - IContainer.pp_collection ~fold ~pp_item fmt h - - 
- let[@warning "-32"] pp_values pp_value fmt h = - let pp_item fmt (_, v) = pp_value fmt v in - IContainer.pp_collection ~fold ~pp_item fmt h - end - - module Repr : sig - (* Sort-of abstracting away the fact that a representative is just an element itself. - This ensures that the [Sets] hashtable is accessed with representative and not just elements. *) - - type t = private E.t - - val equal : t -> t -> bool - - val pp : F.formatter -> t -> unit - - val of_e : E.t -> t - - val is_simpler_than : t -> t -> bool - end = struct - include E - - let of_e e = e - - let is_simpler_than r1 r2 = compare r1 r2 <= 0 - end - - module Reprs = struct - type t = (E.t, Repr.t) H.t - - let create () = H.create 1 - - let rec find (t: t) e : Repr.t = - match H.find_opt t e with - | None -> - Repr.of_e e - | Some r -> - let r' = find t (r :> E.t) in - if not (phys_equal r r') then H.replace t e r' ; - r' - - - let merge (t: t) ~(from: Repr.t) ~(to_: Repr.t) = H.replace t (from :> E.t) to_ +module ConstraintSolver = struct + module Equalities = struct + include ImperativeUnionFind.Make (ControlFlowCost.Set) - let normalizer t e = (find t e :> E.t) - end + let normalizer equalities e = (find equalities e :> ControlFlowCost.t) - module Set = E.Set + let pp_repr fmt (repr: Repr.t) = ControlFlowCost.pp fmt (repr :> ControlFlowCost.t) - module Sets = struct - type t = (Repr.t, Set.t) H.t + let pp_equalities fmt equalities = + let pp_item fmt (repr, set) = + F.fprintf fmt "%a --> %a" pp_repr repr ControlFlowCost.Set.pp_equalities set + in + IContainer.pp_collection ~fold:fold_sets ~pp_item fmt equalities - let create () = H.create 1 - let find t (r: Repr.t) = - match H.find_opt t r with - | Some set -> - set + let log_union equalities e1 e2 = + match union equalities e1 e2 with | None -> - let set = Set.create (r :> E.t) in - H.replace t r set ; set - - - let merge_no_remove _t ~from:(from_r, from_set) ~to_:(_, to_set) = - Set.merge_invalidate ~from:from_set ~to_:to_set ; - from_r - - - let 
merge t ~from:(from_r, from_set) ~to_:(_, to_set) = - H.remove t from_r ; - Set.merge_invalidate ~from:from_set ~to_:to_set - - - let pp_equalities fmt t = H.pp_bindings Repr.pp Set.pp_equalities fmt t - - let fold_equalities t ~init ~f = H.fold t ~init ~f - - let normalize_sums ~normalizer t = - H.iter (fun _repr set -> Set.normalize_sums ~normalizer set) t - - - let infer_equalities_from_sums ~on_infer ~normalizer t = - H.iter (fun _repr set -> Set.infer_equalities_from_sums ~on_infer ~normalizer set) t - - - let remove_list t rs = List.iter rs ~f:(H.remove t) - end - - (** - Data-structure for disjoint sets. - [reprs] is the mapping element -> representative - [sets] is the mapping representative -> set - - It implements path-compression and union by size, hence find and union are amortized O(1)-ish. - *) - type t = {reprs: Reprs.t; sets: Sets.t} - - let create () = {reprs= Reprs.create (); sets= Sets.create ()} - - let do_merge t ~from ~to_ ~merge_sets = - let to_r, _ = to_ in - let from_r, _ = from in - Reprs.merge t.reprs ~from:from_r ~to_:to_r ; - merge_sets ~from ~to_ - - - let union_with_merge t e1 e2 ~merge_sets = - let repr1 = Reprs.find t.reprs e1 in - let repr2 = Reprs.find t.reprs e2 in - if Repr.equal repr1 repr2 then None - else - let set1 = Sets.find t.sets repr1 in - let set2 = Sets.find t.sets repr2 in - let size1 = Set.size set1 in - let size2 = Set.size set2 in - if size1 < size2 || (Int.equal size1 size2 && Repr.is_simpler_than repr2 repr1) then ( - (* A desired side-effect of using [is_simpler_than] is that the representative for a set will always be a [`Node]. For now. 
*) - do_merge t ~from:(repr1, set1) ~to_:(repr2, set2) ~merge_sets ; - Some (e1, e2) ) - else ( - do_merge t ~from:(repr2, set2) ~to_:(repr1, set1) ~merge_sets ; - Some (e2, e1) ) - - - let union_log t e1 e2 ~merge_sets = - match union_with_merge t e1 e2 ~merge_sets with - | None -> - L.(debug Analysis Verbose) "[UF] Preexisting %a = %a@\n" E.pp e1 E.pp e2 - | Some (e1, e2) -> - L.(debug Analysis Verbose) "[UF] Union %a into %a@\n" E.pp e1 E.pp e2 - - - let union t e1 e2 = union_log t e1 e2 ~merge_sets:(Sets.merge t.sets) - - let union_defer_remove t e1 e2 = - let res = ref None in - let merge_sets ~from ~to_ = res := Some (Sets.merge_no_remove t.sets ~from ~to_) in - union_log t e1 e2 ~merge_sets ; !res - + L.(debug Analysis Verbose) + "[UF] Preexisting %a = %a@\n" ControlFlowCost.pp e1 ControlFlowCost.pp e2 ; + false + | Some (e1, e2) -> + L.(debug Analysis Verbose) + "[UF] Union %a into %a@\n" ControlFlowCost.pp e1 ControlFlowCost.pp e2 ; + true - let pp_equalities fmt t = Sets.pp_equalities fmt t.sets - let fold_equalities t ~init ~f = Sets.fold_equalities t.sets ~init ~f - - let normalizer t = Reprs.normalizer t.reprs - - let normalize_sums t = Sets.normalize_sums ~normalizer:(normalizer t) t.sets - - (** + (** Infer equalities from sums, like this: (1) A + sum1 = A + sum2 => sum1 = sum2 @@ -592,29 +424,34 @@ struct Its complexity is unknown but I think it is bounded by nbNodes x nbEdges x max. 
*) - let infer_equalities_from_sums t ~max = - let normalizer = normalizer t in - let on_infer sets_to_remove e1 e2 = - (* need to defer removes to avoid iterator invalidation *) - union_defer_remove t e1 e2 - |> Option.iter ~f:(fun set_to_remove -> sets_to_remove := set_to_remove :: !sets_to_remove) - in - let rec loop max = - let sets_to_remove = ref [] in - let on_infer = on_infer sets_to_remove in - Sets.infer_equalities_from_sums ~on_infer ~normalizer t.sets ; - if not (List.is_empty !sets_to_remove) then ( - Sets.remove_list t.sets !sets_to_remove ; - L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" pp_equalities t ; - if max > 0 then loop (max - 1) - else - L.(debug Analysis Verbose) "[ConstraintSolver] Maximum number of iterations reached@\n" ) - in - loop max -end + let infer_equalities_from_sums equalities ~max = + let normalizer = normalizer equalities in + let f did_infer (_repr, set) = + let did_infer = ref did_infer in + let on_infer e1 e2 = if log_union equalities e1 e2 then did_infer := true in + ControlFlowCost.Set.infer_equalities_from_sums ~on_infer ~normalizer set ; + !did_infer + in + let rec loop max = + if fold_sets equalities ~init:false ~f then ( + L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" pp_equalities equalities ; + if max > 0 then loop (max - 1) + else + L.(debug Analysis Verbose) "[ConstraintSolver] Maximum number of iterations reached@\n" ) + in + loop max -module ConstraintSolver = struct - module Equalities = ImperativeUnionFind (ControlFlowCost) + + let normalize_sums equalities = + let normalizer = normalizer equalities in + Container.iter ~fold:fold_sets equalities ~f:(fun (_repr, set) -> + ControlFlowCost.Set.normalize_sums ~normalizer set ) + + + let union equalities e1 e2 = + let _ : bool = log_union equalities e1 e2 in + () + end let add_constraints equalities node get_nodes make = match get_nodes node with @@ -870,15 +707,15 @@ module MinTree = struct let start_node = Node.id (NodeCFG.start_node node_cfg) in 
let start_node_item = ControlFlowCost.Item.of_node start_node in let eqs = ConstraintSolver.collect_constraints node_cfg in - let start_node_reprs = ConstraintSolver.Equalities.Reprs.find eqs.reprs start_node_item in + let start_node_reprs = ConstraintSolver.Equalities.find eqs start_node_item in L.(debug Analysis Verbose) "@\n =========== Computed Equalities ==========@\n" ; L.(debug Analysis Verbose) "[Equalities] %a@\n" ConstraintSolver.Equalities.pp_equalities eqs ; let minimum_propagation = with_cache (minimum_propagation bound_map constraints) |> Staged.unstage in let min_trees, representative_map = - ConstraintSolver.Equalities.fold_equalities eqs ~init:(Node.IdMap.empty, Node.IdMap.empty) - ~f:(fun (acc_trees, acc_representative) (rep, eq_cl) -> + ConstraintSolver.Equalities.fold_sets eqs ~init:(Node.IdMap.empty, Node.IdMap.empty) ~f: + (fun (acc_trees, acc_representative) (rep, eq_cl) -> let rep_id = match ControlFlowCost.is_node (rep :> ControlFlowCost.t) with | Some nid -> diff --git a/infer/src/istd/IContainer.ml b/infer/src/istd/IContainer.ml index b207f13b7..e4bea2c3d 100644 --- a/infer/src/istd/IContainer.ml +++ b/infer/src/istd/IContainer.ml @@ -58,3 +58,7 @@ let pp_collection ~fold ~pp_item fmt c = in let pp_aux fmt c = fold c ~init:None ~f |> Option.iter ~f:(F.fprintf fmt "@[%a@] " pp_item) in F.fprintf fmt "@[{ %a}@]" pp_aux c + + +let filter ~fold ~filter t ~init ~f = + fold t ~init ~f:(fun acc item -> if filter item then f acc item else acc) diff --git a/infer/src/istd/IContainer.mli b/infer/src/istd/IContainer.mli index 25f39876c..de286c021 100644 --- a/infer/src/istd/IContainer.mli +++ b/infer/src/istd/IContainer.mli @@ -32,3 +32,6 @@ val iter_consecutive : val pp_collection : fold:('t, 'a, 'a option) Container.fold -> pp_item:(F.formatter -> 'a -> unit) -> F.formatter -> 't -> unit + +val filter : + fold:('t, 'a, 'accum) Container.fold -> filter:('a -> bool) -> ('t, 'a, 'accum) Container.fold diff --git 
a/infer/src/istd/ImperativeUnionFind.ml b/infer/src/istd/ImperativeUnionFind.ml new file mode 100644 index 000000000..1847c583e --- /dev/null +++ b/infer/src/istd/ImperativeUnionFind.ml @@ -0,0 +1,149 @@ +(* + * Copyright (c) 2018-present, Facebook, Inc. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + *) + +open! IStd + +module type Set = sig + type elt [@@deriving compare] + + type t + + val create : elt -> t + + val compare_size : t -> t -> int + + val merge : from:t -> to_:t -> unit +end + +module Make (Set : Set) = struct + module H = struct + include Caml.Hashtbl + + let fold : (('a, 'b) t, 'a * 'b, 'accum) Container.fold = + fun h ~init ~f -> + let f' k v accum = f accum (k, v) in + fold f' h init + end + + module Repr : sig + (* Sort-of abstracting away the fact that a representative is just an element itself. + This ensures that the [Sets] hashtable is accessed with representative only. *) + + type t = private Set.elt + + val equal : t -> t -> bool + + val of_elt : Set.elt -> t + + val is_simpler_than : t -> t -> bool + end = struct + type t = Set.elt [@@deriving compare] + + let equal = [%compare.equal : t] + + let of_elt e = e + + let is_simpler_than r1 r2 = compare r1 r2 <= 0 + end + + module Reprs = struct + type t = (Set.elt, Repr.t) H.t + + let create () = H.create 1 + + let is_a_repr (t: t) e = not (H.mem t e) + + let rec find (t: t) e : Repr.t = + match H.find_opt t e with + | None -> + Repr.of_elt e + | Some r -> + let r' = find t (r :> Set.elt) in + if not (phys_equal r r') then H.replace t e r' ; + r' + + + let merge (t: t) ~(from: Repr.t) ~(to_: Repr.t) = H.replace t (from :> Set.elt) to_ + end + + module Sets = struct + type t = (Repr.t, Set.t) H.t + + let create () = H.create 1 + + let find_create t (r: Repr.t) = + match H.find_opt t r with + | Some set -> + set + | None -> + let set = Set.create (r :> Set.elt) in + H.replace t r set ; set + + + let fold = 
H.fold + + let remove_now t r = H.remove t r + end + + (** + Data-structure for disjoint sets. + [reprs] is the mapping element -> representative + [sets] is the mapping representative -> set + + It implements path-compression and union by size, hence find and union are amortized O(1)-ish. + + [nb_iterators] and [to_remove] are used to defer removing elements to avoid iterator invalidation during fold. + *) + type t = {reprs: Reprs.t; sets: Sets.t; mutable nb_iterators: int; mutable to_remove: Repr.t list} + + let create () = {reprs= Reprs.create (); sets= Sets.create (); nb_iterators= 0; to_remove= []} + + let find t e = Reprs.find t.reprs e + + let do_merge t ~from_r ~from_set ~to_r ~to_set = + Reprs.merge t.reprs ~from:from_r ~to_:to_r ; + Set.merge ~from:from_set ~to_:to_set ; + if t.nb_iterators <= 0 then Sets.remove_now t.sets from_r + else t.to_remove <- from_r :: t.to_remove + + + let union t e1 e2 = + let repr1 = find t e1 in + let repr2 = find t e2 in + if Repr.equal repr1 repr2 then None + else + let set1 = Sets.find_create t.sets repr1 in + let set2 = Sets.find_create t.sets repr2 in + let cmp_size = Set.compare_size set1 set2 in + if cmp_size < 0 || (Int.equal cmp_size 0 && Repr.is_simpler_than repr2 repr1) then ( + (* A desired side-effect of using [is_simpler_than] is that the representative for a set will always be a [`Node]. For now. 
+ do_merge t ~from_r:repr1 ~from_set:set1 ~to_r:repr2 ~to_set:set2 ; + Some (e1, e2) ) + else ( + do_merge t ~from_r:repr2 ~from_set:set2 ~to_r:repr1 ~to_set:set1 ; + Some (e2, e1) ) + + + let is_still_a_repr t ((repr: Repr.t), _) = Reprs.is_a_repr t.reprs (repr :> Set.elt) + + let after_fold t = + let new_nb_iterators = t.nb_iterators - 1 in + t.nb_iterators <- new_nb_iterators ; + if new_nb_iterators <= 0 && not (List.is_empty t.to_remove) then ( + List.iter t.to_remove ~f:(Sets.remove_now t.sets) ; + t.to_remove <- [] ) + + + let fold_sets t ~init ~f = + t.nb_iterators <- t.nb_iterators + 1 ; + match IContainer.filter ~fold:Sets.fold ~filter:(is_still_a_repr t) t.sets ~init ~f with + | result -> + after_fold t ; result + | exception e -> + (* Ensures [nb_iterators] is correct *) + IExn.reraise_after ~f:(fun () -> after_fold t) e +end diff --git a/infer/src/istd/ImperativeUnionFind.mli b/infer/src/istd/ImperativeUnionFind.mli new file mode 100644 index 000000000..9397c9459 --- /dev/null +++ b/infer/src/istd/ImperativeUnionFind.mli @@ -0,0 +1,40 @@ +(* + * Copyright (c) 2018-present, Facebook, Inc. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + *) + +open! IStd + +module type Set = sig + type elt [@@deriving compare] + + type t + + val create : elt -> t + + val compare_size : t -> t -> int + + val merge : from:t -> to_:t -> unit +end + +module Make (Set : Set) : sig + module Repr : sig + type t = private Set.elt + + val equal : t -> t -> bool + end + + type t + + val create : unit -> t + + val find : t -> Set.elt -> Repr.t + + val union : t -> Set.elt -> Set.elt -> (Set.elt * Set.elt) option + (** [union t e1 e2] returns [None] if [e1] and [e2] were already in the same set, [Some (a, b)] if [a] is merged into [b] (where [(a, b)] is either [(e1, e2)] or [(e2, e1)]). 
*) + + val fold_sets : (t, Repr.t * Set.t, 'accum) Container.fold + (** It is safe to call [find] or [union] while folding. *) +end