From aee02e27ef26b0accd059a8fbc5d0219b535d95b Mon Sep 17 00:00:00 2001 From: Mehdi Bouaziz Date: Thu, 7 Jun 2018 08:43:33 -0700 Subject: [PATCH] Cost: control-flow equality classes Reviewed By: ddino Differential Revision: D8186669 fbshipit-source-id: e919c16 --- infer/src/checkers/cost.ml | 436 ++++++++++++++++++++++++++++++ infer/src/istd/ARList.mli | 4 +- infer/src/istd/IContainer.ml | 31 +++ infer/src/istd/IContainer.mli | 15 + infer/src/istd/IList.ml | 10 + infer/src/istd/IList.mli | 2 + infer/src/istd/PrettyPrintable.ml | 13 +- 7 files changed, 496 insertions(+), 15 deletions(-) diff --git a/infer/src/checkers/cost.ml b/infer/src/checkers/cost.ml index 00661d376..ce8f51623 100644 --- a/infer/src/checkers/cost.ml +++ b/infer/src/checkers/cost.ml @@ -204,6 +204,441 @@ module BoundMap = struct BasicCost.top end +module ControlFlowCost = struct + (* A Control-flow cost represents the number of times the flow of control can go through a certain CFG item (a node or an edge), + or a sum of such things *) + + module Item = struct + type t = [`Node of Node.id | `Edge of Node.id * Node.id] + + let compare : t -> t -> int = + fun x y -> + match (x, y) with + | `Node id1, `Node id2 -> + Node.compare_id id1 id2 + | `Node _, `Edge _ -> + -1 + | `Edge _, `Node _ -> + 1 + | `Edge (f1, t1), `Edge (f2, t2) -> + [%compare : Node.id * Node.id] (f1, t1) (f2, t2) + + + let equal = [%compare.equal : t] + + let pp : F.formatter -> t -> unit = + fun fmt -> function + | `Node id -> + F.fprintf fmt "Node(%a)" Node.pp_id id + | `Edge (f, t) -> + F.fprintf fmt "Edge(%a -> %a)" Node.pp_id f Node.pp_id t + + + let normalize ~(normalizer: t -> [> t]) (x: t) : t = + match normalizer x with #t as x -> x | _ -> assert false + end + + module Sum = struct + type 'a set = (* non-empty sorted list *) 'a list + + type t = [`Sum of int * Item.t set] + + let of_list l = + let length = List.length l in + let set = List.sort ~compare:Item.compare l in + `Sum (length, set) + + + let compare : t -> t -> int = + fun (`Sum (l1, s1)) (`Sum (l2, s2)) -> [%compare : int * Item.t list] (l1, s1) (l2, s2) + + + let pp : F.formatter -> t -> unit = + fun fmt (`Sum (_, set)) -> Pp.seq ~sep:" + " Item.pp fmt set + + + let items (`Sum (_, l)) = l + + let normalized_items ~normalizer (`Sum (_, l)) = + let normalizer = (normalizer :> Item.t -> [> Item.t]) in + l |> List.rev_map ~f:(Item.normalize ~normalizer) + + + let normalize ~normalizer sum = sum |> normalized_items ~normalizer |> of_list + + (* Given a sum and an item, remove one occurence of the item in the sum. Returns [None] if the item is not present in the sum. + [remove_one_item ~item:A (A + B)] = B + [remove_one_item ~item:A (A + B + C)] = B + C + [remove_one_item ~item:A (A + A + B)] = A + B + [remove_one_item ~item:A (B + C)] = None + *) + let remove_one_item ~item (`Sum (len, l)) = + match IList.remove_first l ~f:(Item.equal item) with + | None -> + None + | Some [e] -> + Some (e :> [Item.t | t]) + | Some l -> + Some (`Sum (len - 1, l)) + end + + type t = [Item.t | Sum.t] + + let compare : t -> t -> int = + fun x y -> + match (x, y) with + | (#Item.t as x), (#Item.t as y) -> + Item.compare x y + | #Item.t, #Sum.t -> + -1 + | #Sum.t, #Item.t -> + 1 + | (#Sum.t as x), (#Sum.t as y) -> + Sum.compare x y + + + let equal = [%compare.equal : t] + + let make_node node = `Node node + + let make_pred_edge succ pred = `Edge (pred, succ) + + let make_succ_edge pred succ = `Edge (pred, succ) + + let pp : F.formatter -> t -> unit = + fun fmt -> function #Item.t as item -> Item.pp fmt item | #Sum.t as sum -> Sum.pp fmt sum + + + let sum : Item.t list -> t = function [] -> assert false | [e] -> (e :> t) | l -> Sum.of_list l + + module Set = struct + type elt = t + + type t = {mutable size: int; mutable items: Item.t ARList.t; mutable sums: Sum.t ARList.t} + + let create e = + let items, sums = + match e with + | #Item.t as item -> + (ARList.singleton item, ARList.empty) + | #Sum.t as sum -> + (ARList.empty, ARList.singleton sum) + in + {size= 1; items; sums} + + + (* Because we are modifying things on which we are iterating, we want to invalidate them first before removing them from hashtables, to avoid iterator invalidation. *) + let is_valid {size} = size >= 1 + + let size {size} = size + + (* move semantics, should not be called with aliases *) + let merge_invalidate ~from ~to_ = + assert (not (phys_equal from to_)) ; + assert (is_valid from) ; + assert (is_valid to_) ; + to_.size <- to_.size + from.size ; + to_.items <- ARList.append to_.items from.items ; + to_.sums <- ARList.append to_.sums from.sums ; + from.size <- 0 + + + let pp_equalities fmt t = + if is_valid t then + ARList.append (t.items :> elt ARList.t) (t.sums :> elt ARList.t) + |> IContainer.to_rev_list ~fold:ARList.fold_unordered |> List.sort ~compare + |> Pp.seq ~sep:" = " pp fmt + + + let normalize_sums : normalizer:(elt -> elt) -> t -> unit = + fun ~normalizer t -> + assert (is_valid t) ; + t.sums + <- t.sums + |> IContainer.rev_map_to_list ~fold:ARList.fold_unordered ~f:(Sum.normalize ~normalizer) + |> List.dedup_and_sort ~compare:Sum.compare |> ARList.of_list + + + let infer_equalities_by_removing_item ~on_infer t item = + t.sums + |> IContainer.rev_filter_map_to_list ~fold:ARList.fold_unordered + ~f:(Sum.remove_one_item ~item) + |> IContainer.iter_consecutive ~fold:List.fold ~f:on_infer + + + let infer_equalities_from_sums + : on_infer:(elt -> elt -> unit) -> normalizer:(elt -> elt) -> t -> unit = + fun ~on_infer ~normalizer t -> + if is_valid t then ( + normalize_sums ~normalizer t ; + let all_items = + t.sums + |> ARList.fold_unordered ~init:ARList.empty ~f:(fun acc sum -> + sum |> Sum.items |> ARList.of_list |> ARList.append acc ) + |> IContainer.to_rev_list ~fold:ARList.fold_unordered + |> List.dedup_and_sort ~compare:Item.compare + in + (* Keep in mind that [on_infer] can modify (and possibly invalidate) [t]. + It happens only if we merge a node while infer equalities from it, i.e. in the case an item appears in an equality class both alone and in two sums, i.e. X = A + X = A + B. + This is not a problem here (we could stop if it happens but it is not necessary as existing equalities still remain true after merges) *) + (* Also keep in mind that the current version, in the worst-case scenario, is quadratic-ish in the size of the CFG *) + List.iter all_items ~f:(fun item -> infer_equalities_by_removing_item ~on_infer t item) ) + end +end + +module ImperativeUnionFind (E : sig + type t [@@deriving compare] + + val equal : t -> t -> bool + + val pp : F.formatter -> t -> unit + + module Set : sig + type elt = t + + type t + + val create : elt -> t + + val size : t -> int + + val merge_invalidate : from:t -> to_:t -> unit + + val pp_equalities : F.formatter -> t -> unit + + val normalize_sums : normalizer:(elt -> elt) -> t -> unit + + val infer_equalities_from_sums : + on_infer:(elt -> elt -> unit) -> normalizer:(elt -> elt) -> t -> unit + end +end) = +struct + module H = struct + include Caml.Hashtbl + + let caml_fold = fold + + let fold : (('a, 'b) t, 'a * 'b, 'accum) Container.fold = + fun h ~init ~f -> + let f' k v accum = f accum (k, v) in + caml_fold f' h init + + + let pp_bindings pp_key pp_value fmt h = + let pp_item fmt (k, v) = F.fprintf fmt "%a --> %a" pp_key k pp_value v in + IContainer.pp_collection ~fold ~pp_item fmt h + + + let[@warning "-32"] pp_values pp_value fmt h = + let pp_item fmt (_, v) = pp_value fmt v in + IContainer.pp_collection ~fold ~pp_item fmt h + end + + module Repr : sig + (* Sort-of abstracting away the fact that a representative is just an element itself. + This ensures that the [Sets] hashtable is accessed with representative and not just elements. *) + + type t = private E.t + + val equal : t -> t -> bool + + val pp : F.formatter -> t -> unit + + val of_e : E.t -> t + + val is_simpler_than : t -> t -> bool + end = struct + include E + + let of_e e = e + + let is_simpler_than r1 r2 = compare r1 r2 <= 0 + end + + module Reprs = struct + type t = (E.t, Repr.t) H.t + + let create () = H.create 1 + + let rec find (t: t) e : Repr.t = + match H.find_opt t e with + | None -> + Repr.of_e e + | Some r -> + let r' = find t (r :> E.t) in + if not (phys_equal r r') then H.replace t e r' ; + r' + + + let merge (t: t) ~(from: Repr.t) ~(to_: Repr.t) = H.replace t (from :> E.t) to_ + + let normalizer t e = (find t e :> E.t) + end + + module Set = E.Set + + module Sets = struct + type t = (Repr.t, Set.t) H.t + + let create () = H.create 1 + + let find t (r: Repr.t) = + match H.find_opt t r with + | Some set -> + set + | None -> + let set = Set.create (r :> E.t) in + H.replace t r set ; set + + + let merge_no_remove _t ~from:(from_r, from_set) ~to_:(_, to_set) = + Set.merge_invalidate ~from:from_set ~to_:to_set ; + from_r + + + let merge t ~from:(from_r, from_set) ~to_:(_, to_set) = + H.remove t from_r ; + Set.merge_invalidate ~from:from_set ~to_:to_set + + + let pp_equalities fmt t = H.pp_bindings Repr.pp Set.pp_equalities fmt t + + let normalize_sums ~normalizer t = + H.iter (fun _repr set -> Set.normalize_sums ~normalizer set) t + + + let infer_equalities_from_sums ~on_infer ~normalizer t = + H.iter (fun _repr set -> Set.infer_equalities_from_sums ~on_infer ~normalizer set) t + + + let remove_list t rs = List.iter rs ~f:(H.remove t) + end + + (** + Data-structure for disjoint sets. + [reprs] is the mapping element -> representative + [sets] is the mapping representative -> set + + It implements path-compression and union by size, hence find and union are amortized O(1)-ish. + *) + type t = {reprs: Reprs.t; sets: Sets.t} + + let create () = {reprs= Reprs.create (); sets= Sets.create ()} + + let do_merge t ~from ~to_ ~merge_sets = + let to_r, _ = to_ in + let from_r, _ = from in + Reprs.merge t.reprs ~from:from_r ~to_:to_r ; + merge_sets ~from ~to_ + + + let union_with_merge t e1 e2 ~merge_sets = + let repr1 = Reprs.find t.reprs e1 in + let repr2 = Reprs.find t.reprs e2 in + if Repr.equal repr1 repr2 then None + else + let set1 = Sets.find t.sets repr1 in + let set2 = Sets.find t.sets repr2 in + let size1 = Set.size set1 in + let size2 = Set.size set2 in + if size1 < size2 || (Int.equal size1 size2 && Repr.is_simpler_than repr2 repr1) then ( + (* A desired side-effect of using [is_simpler_than] is that the representative for a set will always be a [`Node]. For now. *) + do_merge t ~from:(repr1, set1) ~to_:(repr2, set2) ~merge_sets ; + Some (e1, e2) ) + else ( + do_merge t ~from:(repr2, set2) ~to_:(repr1, set1) ~merge_sets ; + Some (e2, e1) ) + + + let union_log t e1 e2 ~merge_sets = + match union_with_merge t e1 e2 ~merge_sets with + | None -> + L.(debug Analysis Verbose) "[UF] Preexisting %a = %a@\n" E.pp e1 E.pp e2 + | Some (e1, e2) -> + L.(debug Analysis Verbose) "[UF] Union %a into %a@\n" E.pp e1 E.pp e2 + + + let union t e1 e2 = union_log t e1 e2 ~merge_sets:(Sets.merge t.sets) + + let union_defer_remove t e1 e2 = + let res = ref None in + let merge_sets ~from ~to_ = res := Some (Sets.merge_no_remove t.sets ~from ~to_) in + union_log t e1 e2 ~merge_sets ; !res + + + let pp_equalities fmt t = Sets.pp_equalities fmt t.sets + + let normalizer t = Reprs.normalizer t.reprs + + let normalize_sums t = Sets.normalize_sums ~normalizer:(normalizer t) t.sets + + (** + Infer equalities from sums, like this: + (1) A + sum1 = A + sum2 => sum1 = sum2 + + It does not try to saturate + (2) A = B + C /\ B = D + E => A = C + D + E + Nor combine more than 2 equations + (3) A = B + C /\ B = D + E /\ F = C + D + E => A = F + ((3) is implied by (1) /\ (2)) + + Its complexity is unknown but I think it is bounded by nbNodes x nbEdges x max. + *) + let infer_equalities_from_sums t ~max = + let normalizer = normalizer t in + let on_infer sets_to_remove e1 e2 = + (* need to defer removes to avoid iterator invalidation *) + union_defer_remove t e1 e2 + |> Option.iter ~f:(fun set_to_remove -> sets_to_remove := set_to_remove :: !sets_to_remove) + in + let rec loop max = + let sets_to_remove = ref [] in + let on_infer = on_infer sets_to_remove in + Sets.infer_equalities_from_sums ~on_infer ~normalizer t.sets ; + if not (List.is_empty !sets_to_remove) then ( + Sets.remove_list t.sets !sets_to_remove ; + L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" pp_equalities t ; + if max > 0 then loop (max - 1) + else + L.(debug Analysis Verbose) "[ConstraintSolver] Maximum number of iterations reached@\n" ) + in + loop max +end + +module ConstraintSolver = struct + module Equalities = ImperativeUnionFind (ControlFlowCost) + + let add_constraints equalities node get_nodes make = + match get_nodes node with + | [] -> + (* either start/exit node or dead node (broken CFG) *) + () + | nodes -> + let node_id = Node.id node in + let edges = List.rev_map nodes ~f:(fun other -> make node_id (Node.id other)) in + let sum = ControlFlowCost.sum edges in + Equalities.union equalities (ControlFlowCost.make_node node_id) sum + + + let collect_on_node equalities node = + add_constraints equalities node Procdesc.Node.get_preds ControlFlowCost.make_pred_edge ; + add_constraints equalities node Procdesc.Node.get_succs ControlFlowCost.make_succ_edge + + + let collect_constraints node_cfg = + let equalities = Equalities.create () in + Container.iter node_cfg ~fold:NodeCFG.fold_nodes ~f:(collect_on_node equalities) ; + L.(debug Analysis Verbose) + "[ConstraintSolver] Procedure %a @@ %a@\n" Typ.Procname.pp (Procdesc.get_proc_name node_cfg) + Location.pp_file_pos (Procdesc.get_loc node_cfg) ; + L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" Equalities.pp_equalities equalities ; + Equalities.normalize_sums equalities ; + L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" Equalities.pp_equalities equalities ; + Equalities.infer_equalities_from_sums equalities ~max:10 ; + L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" Equalities.pp_equalities equalities ; + equalities +end + (* Structural Constraints are expressions of the kind: n <= n1 +...+ nk @@ -578,6 +1013,7 @@ let checker ({Callbacks.tenv; proc_desc} as callback_args) : Summary.t = BoundMap.compute_upperbound_map node_cfg inferbo_invariant_map data_dep_invariant_map control_dep_invariant_map in + let _ = ConstraintSolver.collect_constraints node_cfg in let constraints = StructuralConstraints.compute_structural_constraints node_cfg in let min_trees = MinTree.compute_trees_from_contraints bound_map node_cfg constraints in let trees_valuation = diff --git a/infer/src/istd/ARList.mli b/infer/src/istd/ARList.mli index 8472c18c5..02252ab8a 100644 --- a/infer/src/istd/ARList.mli +++ b/infer/src/istd/ARList.mli @@ -14,8 +14,6 @@ open! IStd include sig (* ocaml ignores the warning suppression at toplevel, hence the [include struct ... end] trick *) - [@@@warning "-60"] - type +'a t (* O(1) time and O(1) allocation *) @@ -64,4 +62,4 @@ include sig val fold_unordered : ('a t, 'a, 'accum) Container.fold (** Always better than [fold_left] when you do not care about the order. *) -end +end[@@warning "-32"] diff --git a/infer/src/istd/IContainer.ml b/infer/src/istd/IContainer.ml index c8d339e4b..79d5d67e0 100644 --- a/infer/src/istd/IContainer.ml +++ b/infer/src/istd/IContainer.ml @@ -7,6 +7,7 @@ (* Extension of Base.Container, i.e. generic definitions of container operations in terms of fold. *) open! IStd +module F = Format type 'a singleton_or_more = Empty | Singleton of 'a | More @@ -29,6 +30,8 @@ let forto excl ~init ~f = aux excl ~f init 0 +let to_rev_list ~fold t = fold t ~init:[] ~f:(fun tl hd -> hd :: tl) + let rev_filter_to_list ~fold t ~f = fold t ~init:[] ~f:(fun acc item -> if f item then item :: acc else acc) @@ -37,3 +40,31 @@ let rev_map_to_list ~fold t ~f = fold t ~init:[] ~f:(fun acc item -> f item :: a let rev_filter_map_to_list ~fold t ~f = fold t ~init:[] ~f:(fun acc item -> IList.opt_cons (f item) acc) + + +let iter_consecutive ~fold t ~f = + let _ : _ option = + fold t ~init:None ~f:(fun prev_opt curr -> + (match prev_opt with Some prev -> f prev curr | None -> ()) ; + Some curr ) + in + () + + +let pp_collection ~fold ~pp_item fmt c = + let f prev_opt item = + prev_opt |> Option.iter ~f:(F.fprintf fmt "@[%a,@]@ " pp_item) ; + Some item + in + let pp_aux fmt c = fold c ~init:None ~f |> Option.iter ~f:(F.fprintf fmt "@[%a@] " pp_item) in + F.fprintf fmt "@[{ %a}@]" pp_aux c + + +let pp_seq ~fold ~sep pp_item fmt c = + let f first item = + if not first then F.pp_print_string fmt sep ; + pp_item fmt item ; + false + in + let _is_empty : bool = fold c ~init:true ~f in + () diff --git a/infer/src/istd/IContainer.mli b/infer/src/istd/IContainer.mli index a9d8a8aa8..f42bbafe2 100644 --- a/infer/src/istd/IContainer.mli +++ b/infer/src/istd/IContainer.mli @@ -6,6 +6,7 @@ *) open! IStd +module F = Format type 'a singleton_or_more = Empty | Singleton of 'a | More @@ -16,9 +17,23 @@ val mem_nth : fold:('t, _, int) Container.fold -> 't -> int -> bool val forto : (int, int, 'accum) Container.fold +val to_rev_list : fold:('t, 'a, 'a list) Container.fold -> 't -> 'a list + val rev_filter_to_list : fold:('t, 'a, 'a list) Container.fold -> 't -> f:('a -> bool) -> 'a list val rev_map_to_list : fold:('t, 'a, 'b list) Container.fold -> 't -> f:('a -> 'b) -> 'b list val rev_filter_map_to_list : fold:('t, 'a, 'b list) Container.fold -> 't -> f:('a -> 'b option) -> 'b list + +val iter_consecutive : + fold:('t, 'a, 'a option) Container.fold -> 't -> f:('a -> 'a -> unit) -> unit + +val pp_collection : + fold:('t, 'a, 'a option) Container.fold -> pp_item:(F.formatter -> 'a -> unit) -> F.formatter + -> 't -> unit + +val pp_seq : + fold:('t, 'a, bool) Container.fold -> sep:string -> (F.formatter -> 'a -> unit) -> F.formatter + -> 't -> unit + [@@warning "-32"] diff --git a/infer/src/istd/IList.ml b/infer/src/istd/IList.ml index fce9635d8..7eb3145ff 100644 --- a/infer/src/istd/IList.ml +++ b/infer/src/istd/IList.ml @@ -161,3 +161,13 @@ let rec drop list index = let opt_cons opt list = match opt with Some x -> x :: list | None -> list + +let remove_first = + let rec aux list ~f rev_front = + match list with + | [] -> + None + | hd :: tl -> + if f hd then Some (List.rev_append rev_front tl) else aux tl ~f (hd :: rev_front) + in + fun list ~f -> aux list ~f [] diff --git a/infer/src/istd/IList.mli b/infer/src/istd/IList.mli index 6091fbf14..9c4e9be9e 100644 --- a/infer/src/istd/IList.mli +++ b/infer/src/istd/IList.mli @@ -39,3 +39,5 @@ val drop : 'a list -> int -> 'a list val opt_cons : 'a option -> 'a list -> 'a list (** [opt_cons None l] returns [l]. [opt_cons (Some x) l] returns [x :: l]*) + +val remove_first : 'a list -> f:('a -> bool) -> 'a list option diff --git a/infer/src/istd/PrettyPrintable.ml b/infer/src/istd/PrettyPrintable.ml index ab0239832..35225946e 100644 --- a/infer/src/istd/PrettyPrintable.ml +++ b/infer/src/istd/PrettyPrintable.ml @@ -32,18 +32,7 @@ module type PPMap = sig val pp : pp_value:(F.formatter -> 'a -> unit) -> F.formatter -> 'a t -> unit end -let pp_collection ~pp_item fmt c = - let rec pp_list fmt = function - | [] -> - () - | [item] -> - F.fprintf fmt "@[%a@] " pp_item item - | item :: items -> - F.fprintf fmt "@[%a,@]@ " pp_item item ; - pp_list fmt items - in - F.fprintf fmt "@[{ %a}@]" pp_list c - +let pp_collection ~pp_item fmt c = IContainer.pp_collection ~fold:List.fold ~pp_item fmt c module MakePPSet (Ord : PrintableOrderedType) = struct include Caml.Set.Make (Ord)