Cost: move union-find to its own module

Reviewed By: ddino

Differential Revision: D8348274

fbshipit-source-id: 6bd6ce6
master
Mehdi Bouaziz 7 years ago committed by Facebook Github Bot
parent 21ced6af62
commit 9ae3b42aef

@ -301,8 +301,6 @@ module ControlFlowCost = struct
Sum.compare x y
let equal = [%compare.equal : t]
let make_node node = `Node node
let make_pred_edge succ pred = `Edge (pred, succ)
@ -316,7 +314,7 @@ module ControlFlowCost = struct
let sum : Item.t list -> t = function [] -> assert false | [e] -> (e :> t) | l -> Sum.of_list l
module Set = struct
type elt = t
type elt = t [@@deriving compare]
type t = {mutable size: int; mutable items: Item.t ARList.t; mutable sums: Sum.t ARList.t}
@ -331,13 +329,13 @@ module ControlFlowCost = struct
{size= 1; items; sums}
(* Because we are modifying things on which we are iterating, we want to invalidate them first before removing them from hashtables, to avoid iterator invalidation. *)
let is_valid {size} = size >= 1
let compare_size {size= size1} {size= size2} = Int.compare size1 size2
let size {size} = size
(* Invalidation is just a sanity check, union-find already takes care of it. *)
let is_valid {size} = size >= 1
(* move semantics, should not be called with aliases *)
let merge_invalidate ~from ~to_ =
let merge ~from ~to_ =
assert (not (phys_equal from to_)) ;
assert (is_valid from) ;
assert (is_valid to_) ;
@ -348,15 +346,13 @@ module ControlFlowCost = struct
let pp_equalities fmt t =
if is_valid t then
ARList.append (t.items :> elt ARList.t) (t.sums :> elt ARList.t)
|> IContainer.to_rev_list ~fold:ARList.fold_unordered |> List.sort ~compare
|> Pp.seq ~sep:" = " pp fmt
ARList.append (t.items :> elt ARList.t) (t.sums :> elt ARList.t)
|> IContainer.to_rev_list ~fold:ARList.fold_unordered |> List.sort ~compare
|> Pp.seq ~sep:" = " pp fmt
let normalize_sums : normalizer:(elt -> elt) -> t -> unit =
fun ~normalizer t ->
assert (is_valid t) ;
t.sums
<- t.sums
|> IContainer.rev_map_to_list ~fold:ARList.fold_unordered ~f:(Sum.normalize ~normalizer)
@ -373,214 +369,50 @@ module ControlFlowCost = struct
let infer_equalities_from_sums
: on_infer:(elt -> elt -> unit) -> normalizer:(elt -> elt) -> t -> unit =
fun ~on_infer ~normalizer t ->
if is_valid t then (
normalize_sums ~normalizer t ;
let all_items =
t.sums
|> ARList.fold_unordered ~init:ARList.empty ~f:(fun acc sum ->
sum |> Sum.items |> ARList.of_list |> ARList.append acc )
|> IContainer.to_rev_list ~fold:ARList.fold_unordered
|> List.dedup_and_sort ~compare:Item.compare
in
(* Keep in mind that [on_infer] can modify (and possibly invalidate) [t].
It happens only if we merge a node while infer equalities from it, i.e. in the case an item appears in an equality class both alone and in two sums, i.e. X = A + X = A + B.
normalize_sums ~normalizer t ;
let all_items =
t.sums
|> ARList.fold_unordered ~init:ARList.empty ~f:(fun acc sum ->
sum |> Sum.items |> ARList.of_list |> ARList.append acc )
|> IContainer.to_rev_list ~fold:ARList.fold_unordered
|> List.dedup_and_sort ~compare:Item.compare
in
(* Keep in mind that [on_infer] can modify [t].
It happens only if we merge a node while inferring equalities from it, i.e. in the case an item appears in an equality class both alone and in two sums, i.e. X = A + X = A + B.
This is not a problem here (we could stop if it happens but it is not necessary as existing equalities still remain true after merges) *)
(* Also keep in mind that the current version, in the worst-case scenario, is quadratic-ish in the size of the CFG *)
List.iter all_items ~f:(fun item -> infer_equalities_by_removing_item ~on_infer t item) )
(* Also keep in mind that the current version, in the worst-case scenario, is quadratic-ish in the size of the CFG *)
List.iter all_items ~f:(fun item -> infer_equalities_by_removing_item ~on_infer t item)
end
end
module ImperativeUnionFind (E : sig
type t [@@deriving compare]
val equal : t -> t -> bool
val pp : F.formatter -> t -> unit
module Set : sig
type elt = t
type t
val create : elt -> t
val size : t -> int
val merge_invalidate : from:t -> to_:t -> unit
val pp_equalities : F.formatter -> t -> unit
val normalize_sums : normalizer:(elt -> elt) -> t -> unit
val infer_equalities_from_sums :
on_infer:(elt -> elt -> unit) -> normalizer:(elt -> elt) -> t -> unit
end
end) =
struct
module H = struct
include Caml.Hashtbl
let caml_fold = fold
let fold : (('a, 'b) t, 'a * 'b, 'accum) Container.fold =
fun h ~init ~f ->
let f' k v accum = f accum (k, v) in
caml_fold f' h init
let pp_bindings pp_key pp_value fmt h =
let pp_item fmt (k, v) = F.fprintf fmt "%a --> %a" pp_key k pp_value v in
IContainer.pp_collection ~fold ~pp_item fmt h
let[@warning "-32"] pp_values pp_value fmt h =
let pp_item fmt (_, v) = pp_value fmt v in
IContainer.pp_collection ~fold ~pp_item fmt h
end
module Repr : sig
(* Sort-of abstracting away the fact that a representative is just an element itself.
This ensures that the [Sets] hashtable is accessed with representative and not just elements. *)
type t = private E.t
val equal : t -> t -> bool
val pp : F.formatter -> t -> unit
val of_e : E.t -> t
val is_simpler_than : t -> t -> bool
end = struct
include E
let of_e e = e
let is_simpler_than r1 r2 = compare r1 r2 <= 0
end
module Reprs = struct
type t = (E.t, Repr.t) H.t
let create () = H.create 1
let rec find (t: t) e : Repr.t =
match H.find_opt t e with
| None ->
Repr.of_e e
| Some r ->
let r' = find t (r :> E.t) in
if not (phys_equal r r') then H.replace t e r' ;
r'
let merge (t: t) ~(from: Repr.t) ~(to_: Repr.t) = H.replace t (from :> E.t) to_
module ConstraintSolver = struct
module Equalities = struct
include ImperativeUnionFind.Make (ControlFlowCost.Set)
let normalizer t e = (find t e :> E.t)
end
let normalizer equalities e = (find equalities e :> ControlFlowCost.t)
module Set = E.Set
let pp_repr fmt (repr: Repr.t) = ControlFlowCost.pp fmt (repr :> ControlFlowCost.t)
module Sets = struct
type t = (Repr.t, Set.t) H.t
let pp_equalities fmt equalities =
let pp_item fmt (repr, set) =
F.fprintf fmt "%a --> %a" pp_repr repr ControlFlowCost.Set.pp_equalities set
in
IContainer.pp_collection ~fold:fold_sets ~pp_item fmt equalities
let create () = H.create 1
let find t (r: Repr.t) =
match H.find_opt t r with
| Some set ->
set
let log_union equalities e1 e2 =
match union equalities e1 e2 with
| None ->
let set = Set.create (r :> E.t) in
H.replace t r set ; set
let merge_no_remove _t ~from:(from_r, from_set) ~to_:(_, to_set) =
Set.merge_invalidate ~from:from_set ~to_:to_set ;
from_r
let merge t ~from:(from_r, from_set) ~to_:(_, to_set) =
H.remove t from_r ;
Set.merge_invalidate ~from:from_set ~to_:to_set
let pp_equalities fmt t = H.pp_bindings Repr.pp Set.pp_equalities fmt t
let fold_equalities t ~init ~f = H.fold t ~init ~f
let normalize_sums ~normalizer t =
H.iter (fun _repr set -> Set.normalize_sums ~normalizer set) t
let infer_equalities_from_sums ~on_infer ~normalizer t =
H.iter (fun _repr set -> Set.infer_equalities_from_sums ~on_infer ~normalizer set) t
let remove_list t rs = List.iter rs ~f:(H.remove t)
end
(**
Data-structure for disjoint sets.
[reprs] is the mapping element -> representative
[sets] is the mapping representative -> set
It implements path-compression and union by size, hence find and union are amortized O(1)-ish.
*)
type t = {reprs: Reprs.t; sets: Sets.t}
let create () = {reprs= Reprs.create (); sets= Sets.create ()}
let do_merge t ~from ~to_ ~merge_sets =
let to_r, _ = to_ in
let from_r, _ = from in
Reprs.merge t.reprs ~from:from_r ~to_:to_r ;
merge_sets ~from ~to_
let union_with_merge t e1 e2 ~merge_sets =
let repr1 = Reprs.find t.reprs e1 in
let repr2 = Reprs.find t.reprs e2 in
if Repr.equal repr1 repr2 then None
else
let set1 = Sets.find t.sets repr1 in
let set2 = Sets.find t.sets repr2 in
let size1 = Set.size set1 in
let size2 = Set.size set2 in
if size1 < size2 || (Int.equal size1 size2 && Repr.is_simpler_than repr2 repr1) then (
(* A desired side-effect of using [is_simpler_than] is that the representative for a set will always be a [`Node]. For now. *)
do_merge t ~from:(repr1, set1) ~to_:(repr2, set2) ~merge_sets ;
Some (e1, e2) )
else (
do_merge t ~from:(repr2, set2) ~to_:(repr1, set1) ~merge_sets ;
Some (e2, e1) )
let union_log t e1 e2 ~merge_sets =
match union_with_merge t e1 e2 ~merge_sets with
| None ->
L.(debug Analysis Verbose) "[UF] Preexisting %a = %a@\n" E.pp e1 E.pp e2
| Some (e1, e2) ->
L.(debug Analysis Verbose) "[UF] Union %a into %a@\n" E.pp e1 E.pp e2
let union t e1 e2 = union_log t e1 e2 ~merge_sets:(Sets.merge t.sets)
let union_defer_remove t e1 e2 =
let res = ref None in
let merge_sets ~from ~to_ = res := Some (Sets.merge_no_remove t.sets ~from ~to_) in
union_log t e1 e2 ~merge_sets ; !res
L.(debug Analysis Verbose)
"[UF] Preexisting %a = %a@\n" ControlFlowCost.pp e1 ControlFlowCost.pp e2 ;
false
| Some (e1, e2) ->
L.(debug Analysis Verbose)
"[UF] Union %a into %a@\n" ControlFlowCost.pp e1 ControlFlowCost.pp e2 ;
true
let pp_equalities fmt t = Sets.pp_equalities fmt t.sets
let fold_equalities t ~init ~f = Sets.fold_equalities t.sets ~init ~f
let normalizer t = Reprs.normalizer t.reprs
let normalize_sums t = Sets.normalize_sums ~normalizer:(normalizer t) t.sets
(**
(**
Infer equalities from sums, like this:
(1) A + sum1 = A + sum2 => sum1 = sum2
@ -592,29 +424,34 @@ struct
Its complexity is unknown but I think it is bounded by nbNodes x nbEdges x max.
*)
let infer_equalities_from_sums t ~max =
let normalizer = normalizer t in
let on_infer sets_to_remove e1 e2 =
(* need to defer removes to avoid iterator invalidation *)
union_defer_remove t e1 e2
|> Option.iter ~f:(fun set_to_remove -> sets_to_remove := set_to_remove :: !sets_to_remove)
in
let rec loop max =
let sets_to_remove = ref [] in
let on_infer = on_infer sets_to_remove in
Sets.infer_equalities_from_sums ~on_infer ~normalizer t.sets ;
if not (List.is_empty !sets_to_remove) then (
Sets.remove_list t.sets !sets_to_remove ;
L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" pp_equalities t ;
if max > 0 then loop (max - 1)
else
L.(debug Analysis Verbose) "[ConstraintSolver] Maximum number of iterations reached@\n" )
in
loop max
end
let infer_equalities_from_sums equalities ~max =
let normalizer = normalizer equalities in
let f did_infer (_repr, set) =
let did_infer = ref did_infer in
let on_infer e1 e2 = if log_union equalities e1 e2 then did_infer := true in
ControlFlowCost.Set.infer_equalities_from_sums ~on_infer ~normalizer set ;
!did_infer
in
let rec loop max =
if fold_sets equalities ~init:false ~f then (
L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" pp_equalities equalities ;
if max > 0 then loop (max - 1)
else
L.(debug Analysis Verbose) "[ConstraintSolver] Maximum number of iterations reached@\n" )
in
loop max
module ConstraintSolver = struct
module Equalities = ImperativeUnionFind (ControlFlowCost)
let normalize_sums equalities =
let normalizer = normalizer equalities in
Container.iter ~fold:fold_sets equalities ~f:(fun (_repr, set) ->
ControlFlowCost.Set.normalize_sums ~normalizer set )
let union equalities e1 e2 =
let _ : bool = log_union equalities e1 e2 in
()
end
let add_constraints equalities node get_nodes make =
match get_nodes node with
@ -870,15 +707,15 @@ module MinTree = struct
let start_node = Node.id (NodeCFG.start_node node_cfg) in
let start_node_item = ControlFlowCost.Item.of_node start_node in
let eqs = ConstraintSolver.collect_constraints node_cfg in
let start_node_reprs = ConstraintSolver.Equalities.Reprs.find eqs.reprs start_node_item in
let start_node_reprs = ConstraintSolver.Equalities.find eqs start_node_item in
L.(debug Analysis Verbose) "@\n =========== Computed Equalities ==========@\n" ;
L.(debug Analysis Verbose) "[Equalities] %a@\n" ConstraintSolver.Equalities.pp_equalities eqs ;
let minimum_propagation =
with_cache (minimum_propagation bound_map constraints) |> Staged.unstage
in
let min_trees, representative_map =
ConstraintSolver.Equalities.fold_equalities eqs ~init:(Node.IdMap.empty, Node.IdMap.empty)
~f:(fun (acc_trees, acc_representative) (rep, eq_cl) ->
ConstraintSolver.Equalities.fold_sets eqs ~init:(Node.IdMap.empty, Node.IdMap.empty) ~f:
(fun (acc_trees, acc_representative) (rep, eq_cl) ->
let rep_id =
match ControlFlowCost.is_node (rep :> ControlFlowCost.t) with
| Some nid ->

@ -58,3 +58,7 @@ let pp_collection ~fold ~pp_item fmt c =
in
let pp_aux fmt c = fold c ~init:None ~f |> Option.iter ~f:(F.fprintf fmt "@[<h>%a@] " pp_item) in
F.fprintf fmt "@[<hv 2>{ %a}@]" pp_aux c
(* [filter ~fold ~filter] turns [fold] into a fold that skips the items rejected by [filter]:
   a rejected item leaves the accumulator untouched, an accepted one is handed to [f]. *)
let filter ~fold ~filter t ~init ~f =
  let step acc item = if filter item then f acc item else acc in
  fold t ~init ~f:step

@ -32,3 +32,6 @@ val iter_consecutive :
val pp_collection :
fold:('t, 'a, 'a option) Container.fold -> pp_item:(F.formatter -> 'a -> unit) -> F.formatter
-> 't -> unit
val filter :
fold:('t, 'a, 'accum) Container.fold -> filter:('a -> bool) -> ('t, 'a, 'accum) Container.fold
(** [filter ~fold ~filter] is a fold over the same container as [fold] that only passes to the folding function the items for which [filter] returns [true]; the other items leave the accumulator unchanged. *)

@ -0,0 +1,149 @@
(*
* Copyright (c) 2018-present, Facebook, Inc.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*)
open! IStd
(** What [Make] requires from the sets (equivalence classes) it manages. *)
module type Set = sig
(** Elements stored in the sets. [Make] also uses them as hashtable keys and as set representatives, and uses the derived [compare] to test and order representatives. *)
type elt [@@deriving compare]
(** A set of elements, i.e. one equivalence class. *)
type t
(** Called by [Make] to build the set attached to a representative that has no set yet.
    NOTE(review): presumably the singleton containing that element -- confirm against implementations. *)
val create : elt -> t
(** Ordering on sets used by [union] to choose the merge direction: when the result is negative, the first set is the one merged into the second. *)
val compare_size : t -> t -> int
(** [merge ~from ~to_] is called once per effective [union]; afterwards the union-find drops its binding for [from] and keeps [to_].
    NOTE(review): implementations are expected to move the contents of [from] into [to_] -- confirm. *)
val merge : from:t -> to_:t -> unit
end
(** Builds an imperative union-find (disjoint sets) structure over [Set.elt], where every
    equivalence class carries a [Set.t] payload. *)
module Make (Set : Set) = struct
  (* [Caml.Hashtbl] whose [fold] is re-shaped to the [Container.fold] convention:
     labelled [~init] / [~f], bindings handed over as [(key, value)] pairs. *)
  module H = struct
    include Caml.Hashtbl

    let fold : (('a, 'b) t, 'a * 'b, 'accum) Container.fold =
     fun table ~init ~f -> fold (fun key data acc -> f acc (key, data)) table init
  end

  module Repr : sig
    (* A representative is nothing more than an element, but keeping the type private
       guarantees that the [Sets] table is only ever indexed by representatives. *)
    type t = private Set.elt

    val equal : t -> t -> bool

    val of_elt : Set.elt -> t

    val is_simpler_than : t -> t -> bool
  end = struct
    type t = Set.elt [@@deriving compare]

    let equal = [%compare.equal : t]

    let of_elt elt = elt

    (* Non-strict: a representative is "simpler" than itself. *)
    let is_simpler_than lhs rhs = compare lhs rhs <= 0
  end

  (* Mapping element -> representative. Elements without a binding are their own representative. *)
  module Reprs = struct
    type t = (Set.elt, Repr.t) H.t

    let create () = H.create 1

    let is_a_repr (t: t) elt = not (H.mem t elt)

    (* Follows the bindings up to the root and re-points [elt] directly at it (path compression). *)
    let rec find (t: t) elt : Repr.t =
      match H.find_opt t elt with
      | Some parent ->
          let root = find t (parent :> Set.elt) in
          if not (phys_equal parent root) then H.replace t elt root ;
          root
      | None ->
          Repr.of_elt elt

    let merge (t: t) ~(from: Repr.t) ~(to_: Repr.t) = H.replace t (from :> Set.elt) to_
  end

  (* Mapping representative -> set, populated lazily. *)
  module Sets = struct
    type t = (Repr.t, Set.t) H.t

    let create () = H.create 1

    (* Returns the set bound to [repr], creating and registering it first if there is none. *)
    let find_create t (repr: Repr.t) =
      match H.find_opt t repr with
      | None ->
          let fresh = Set.create (repr :> Set.elt) in
          H.replace t repr fresh ; fresh
      | Some existing ->
          existing

    let fold = H.fold

    let remove_now t repr = H.remove t repr
  end

  (**
    Disjoint-sets data-structure.
    [reprs] maps an element to its representative, [sets] maps a representative to its set.
    Finding compresses paths and merging goes by size, hence find and union are amortized O(1)-ish.
    [nb_iterators] counts the folds currently running and [to_remove] holds the representatives whose
    removal from [sets] is postponed until no fold is running, to avoid iterator invalidation.
  *)
  type t =
    { reprs: Reprs.t
    ; sets: Sets.t
    ; mutable nb_iterators: int
    ; mutable to_remove: Repr.t list }

  let create () =
    let reprs = Reprs.create () in
    let sets = Sets.create () in
    {reprs; sets; nb_iterators= 0; to_remove= []}

  let find t elt = Reprs.find t.reprs elt

  (* Redirects [from_r] to [to_r], merges the payloads, then drops the stale [sets] binding,
     immediately when no fold is running, otherwise once the outermost fold is over. *)
  let do_merge t ~from_r ~from_set ~to_r ~to_set =
    Reprs.merge t.reprs ~from:from_r ~to_:to_r ;
    Set.merge ~from:from_set ~to_:to_set ;
    if t.nb_iterators > 0 then t.to_remove <- from_r :: t.to_remove
    else Sets.remove_now t.sets from_r

  let union t e1 e2 =
    let repr1 = find t e1 in
    let repr2 = find t e2 in
    if Repr.equal repr1 repr2 then None
    else
      let set1 = Sets.find_create t.sets repr1 in
      let set2 = Sets.find_create t.sets repr2 in
      let merge_first_into_second =
        match Set.compare_size set1 set2 with
        | 0 ->
            (* Ties are broken with [is_simpler_than]. A desired side-effect is that the representative for a set will always be a [`Node]. For now. *)
            Repr.is_simpler_than repr2 repr1
        | cmp_size ->
            cmp_size < 0
      in
      if merge_first_into_second then (
        do_merge t ~from_r:repr1 ~from_set:set1 ~to_r:repr2 ~to_set:set2 ;
        Some (e1, e2) )
      else (
        do_merge t ~from_r:repr2 ~from_set:set2 ~to_r:repr1 ~to_set:set1 ;
        Some (e2, e1) )

  (* A [sets] binding is stale as soon as its key has been redirected to another representative. *)
  let is_still_a_repr t ((repr: Repr.t), _set) = Reprs.is_a_repr t.reprs (repr :> Set.elt)

  (* Unregisters one fold and, if it was the last one running, performs the postponed removals. *)
  let after_fold t =
    t.nb_iterators <- t.nb_iterators - 1 ;
    let no_fold_running = t.nb_iterators <= 0 in
    if no_fold_running && not (List.is_empty t.to_remove) then (
      List.iter t.to_remove ~f:(Sets.remove_now t.sets) ;
      t.to_remove <- [] )

  let fold_sets t ~init ~f =
    t.nb_iterators <- t.nb_iterators + 1 ;
    match IContainer.filter ~fold:Sets.fold ~filter:(is_still_a_repr t) t.sets ~init ~f with
    | acc ->
        after_fold t ; acc
    | exception exn ->
        (* Keep [nb_iterators] accurate even when [f] raises *)
        IExn.reraise_after ~f:(fun () -> after_fold t) exn
end

@ -0,0 +1,40 @@
(*
* Copyright (c) 2018-present, Facebook, Inc.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*)
open! IStd
(** What [Make] requires from the sets (equivalence classes) it manages. *)
module type Set = sig
(** Elements stored in the sets; the derived [compare] is part of the requirements. *)
type elt [@@deriving compare]
(** A set of elements, i.e. one equivalence class. *)
type t
(** Builds the set attached to an element.
    NOTE(review): presumably the singleton containing that element -- confirm against implementations. *)
val create : elt -> t
(** Ordering on sets by size.
    NOTE(review): presumably used by [union] to choose which set is merged into the other -- confirm against the implementation. *)
val compare_size : t -> t -> int
(** [merge ~from ~to_] merges [from] into [to_].
    NOTE(review): implementations are expected to move the contents of [from] into [to_] -- confirm. *)
val merge : from:t -> to_:t -> unit
end
(** Imperative union-find (disjoint sets) over [Set.elt], where each equivalence class carries a [Set.t]. *)
module Make (Set : Set) : sig
module Repr : sig
type t = private Set.elt
(** The representative of an equivalence class. It is an element, but the type is private: a [Repr.t] can be read as a [Set.elt], not forged from one. *)
val equal : t -> t -> bool
end
type t
(** A mutable collection of disjoint sets of elements. *)
val create : unit -> t
(** Creates a fresh structure with no recorded equality. *)
val find : t -> Set.elt -> Repr.t
(** [find t e] returns the representative of the set [e] belongs to. *)
val union : t -> Set.elt -> Set.elt -> (Set.elt * Set.elt) option
(** [union t e1 e2] returns [None] if [e1] and [e2] were already in the same set, [Some (a, b)] if [a] is merged into [b] (where [(a, b)] is either [(e1, e2)] or [(e2, e1)]). *)
val fold_sets : (t, Repr.t * Set.t, 'accum) Container.fold
(** Folds over the sets together with their representative. It is safe to call [find] or [union] while folding. *)
end
Loading…
Cancel
Save