Cost: control-flow equality classes

Reviewed By: ddino

Differential Revision: D8186669

fbshipit-source-id: e919c16
master
Mehdi Bouaziz 7 years ago committed by Facebook Github Bot
parent fd93d907e8
commit aee02e27ef

@ -204,6 +204,441 @@ module BoundMap = struct
BasicCost.top BasicCost.top
end end
module ControlFlowCost = struct
(* A Control-flow cost represents the number of times the flow of control can go through a certain CFG item (a node or an edge),
or a sum of such things *)
module Item = struct
type t = [`Node of Node.id | `Edge of Node.id * Node.id]
let compare : t -> t -> int =
fun x y ->
match (x, y) with
| `Node id1, `Node id2 ->
Node.compare_id id1 id2
| `Node _, `Edge _ ->
-1
| `Edge _, `Node _ ->
1
| `Edge (f1, t1), `Edge (f2, t2) ->
[%compare : Node.id * Node.id] (f1, t1) (f2, t2)
let equal = [%compare.equal : t]
let pp : F.formatter -> t -> unit =
fun fmt -> function
| `Node id ->
F.fprintf fmt "Node(%a)" Node.pp_id id
| `Edge (f, t) ->
F.fprintf fmt "Edge(%a -> %a)" Node.pp_id f Node.pp_id t
let normalize ~(normalizer: t -> [> t]) (x: t) : t =
match normalizer x with #t as x -> x | _ -> assert false
end
module Sum = struct
type 'a set = (* non-empty sorted list *) 'a list
type t = [`Sum of int * Item.t set]
let of_list l =
let length = List.length l in
let set = List.sort ~compare:Item.compare l in
`Sum (length, set)
let compare : t -> t -> int =
fun (`Sum (l1, s1)) (`Sum (l2, s2)) -> [%compare : int * Item.t list] (l1, s1) (l2, s2)
let pp : F.formatter -> t -> unit =
fun fmt (`Sum (_, set)) -> Pp.seq ~sep:" + " Item.pp fmt set
let items (`Sum (_, l)) = l
let normalized_items ~normalizer (`Sum (_, l)) =
let normalizer = (normalizer :> Item.t -> [> Item.t]) in
l |> List.rev_map ~f:(Item.normalize ~normalizer)
let normalize ~normalizer sum = sum |> normalized_items ~normalizer |> of_list
(* Given a sum and an item, remove one occurence of the item in the sum. Returns [None] if the item is not present in the sum.
[remove_one_item ~item:A (A + B)] = B
[remove_one_item ~item:A (A + B + C)] = B + C
[remove_one_item ~item:A (A + A + B)] = A + B
[remove_one_item ~item:A (B + C)] = None
*)
let remove_one_item ~item (`Sum (len, l)) =
match IList.remove_first l ~f:(Item.equal item) with
| None ->
None
| Some [e] ->
Some (e :> [Item.t | t])
| Some l ->
Some (`Sum (len - 1, l))
end
type t = [Item.t | Sum.t]
let compare : t -> t -> int =
fun x y ->
match (x, y) with
| (#Item.t as x), (#Item.t as y) ->
Item.compare x y
| #Item.t, #Sum.t ->
-1
| #Sum.t, #Item.t ->
1
| (#Sum.t as x), (#Sum.t as y) ->
Sum.compare x y
let equal = [%compare.equal : t]
let make_node node = `Node node
let make_pred_edge succ pred = `Edge (pred, succ)
let make_succ_edge pred succ = `Edge (pred, succ)
let pp : F.formatter -> t -> unit =
fun fmt -> function #Item.t as item -> Item.pp fmt item | #Sum.t as sum -> Sum.pp fmt sum
let sum : Item.t list -> t = function [] -> assert false | [e] -> (e :> t) | l -> Sum.of_list l
module Set = struct
type elt = t
type t = {mutable size: int; mutable items: Item.t ARList.t; mutable sums: Sum.t ARList.t}
let create e =
let items, sums =
match e with
| #Item.t as item ->
(ARList.singleton item, ARList.empty)
| #Sum.t as sum ->
(ARList.empty, ARList.singleton sum)
in
{size= 1; items; sums}
(* Because we are modifying things on which we are iterating, we want to invalidate them first before removing them from hashtables, to avoid iterator invalidation. *)
let is_valid {size} = size >= 1
let size {size} = size
(* move semantics, should not be called with aliases *)
let merge_invalidate ~from ~to_ =
assert (not (phys_equal from to_)) ;
assert (is_valid from) ;
assert (is_valid to_) ;
to_.size <- to_.size + from.size ;
to_.items <- ARList.append to_.items from.items ;
to_.sums <- ARList.append to_.sums from.sums ;
from.size <- 0
let pp_equalities fmt t =
if is_valid t then
ARList.append (t.items :> elt ARList.t) (t.sums :> elt ARList.t)
|> IContainer.to_rev_list ~fold:ARList.fold_unordered |> List.sort ~compare
|> Pp.seq ~sep:" = " pp fmt
let normalize_sums : normalizer:(elt -> elt) -> t -> unit =
fun ~normalizer t ->
assert (is_valid t) ;
t.sums
<- t.sums
|> IContainer.rev_map_to_list ~fold:ARList.fold_unordered ~f:(Sum.normalize ~normalizer)
|> List.dedup_and_sort ~compare:Sum.compare |> ARList.of_list
let infer_equalities_by_removing_item ~on_infer t item =
t.sums
|> IContainer.rev_filter_map_to_list ~fold:ARList.fold_unordered
~f:(Sum.remove_one_item ~item)
|> IContainer.iter_consecutive ~fold:List.fold ~f:on_infer
let infer_equalities_from_sums
: on_infer:(elt -> elt -> unit) -> normalizer:(elt -> elt) -> t -> unit =
fun ~on_infer ~normalizer t ->
if is_valid t then (
normalize_sums ~normalizer t ;
let all_items =
t.sums
|> ARList.fold_unordered ~init:ARList.empty ~f:(fun acc sum ->
sum |> Sum.items |> ARList.of_list |> ARList.append acc )
|> IContainer.to_rev_list ~fold:ARList.fold_unordered
|> List.dedup_and_sort ~compare:Item.compare
in
(* Keep in mind that [on_infer] can modify (and possibly invalidate) [t].
It happens only if we merge a node while infer equalities from it, i.e. in the case an item appears in an equality class both alone and in two sums, i.e. X = A + X = A + B.
This is not a problem here (we could stop if it happens but it is not necessary as existing equalities still remain true after merges) *)
(* Also keep in mind that the current version, in the worst-case scenario, is quadratic-ish in the size of the CFG *)
List.iter all_items ~f:(fun item -> infer_equalities_by_removing_item ~on_infer t item) )
end
end
module ImperativeUnionFind (E : sig
type t [@@deriving compare]
val equal : t -> t -> bool
val pp : F.formatter -> t -> unit
module Set : sig
type elt = t
type t
val create : elt -> t
val size : t -> int
val merge_invalidate : from:t -> to_:t -> unit
val pp_equalities : F.formatter -> t -> unit
val normalize_sums : normalizer:(elt -> elt) -> t -> unit
val infer_equalities_from_sums :
on_infer:(elt -> elt -> unit) -> normalizer:(elt -> elt) -> t -> unit
end
end) =
struct
module H = struct
include Caml.Hashtbl
let caml_fold = fold
let fold : (('a, 'b) t, 'a * 'b, 'accum) Container.fold =
fun h ~init ~f ->
let f' k v accum = f accum (k, v) in
caml_fold f' h init
let pp_bindings pp_key pp_value fmt h =
let pp_item fmt (k, v) = F.fprintf fmt "%a --> %a" pp_key k pp_value v in
IContainer.pp_collection ~fold ~pp_item fmt h
let[@warning "-32"] pp_values pp_value fmt h =
let pp_item fmt (_, v) = pp_value fmt v in
IContainer.pp_collection ~fold ~pp_item fmt h
end
module Repr : sig
(* Sort-of abstracting away the fact that a representative is just an element itself.
This ensures that the [Sets] hashtable is accessed with representative and not just elements. *)
type t = private E.t
val equal : t -> t -> bool
val pp : F.formatter -> t -> unit
val of_e : E.t -> t
val is_simpler_than : t -> t -> bool
end = struct
include E
let of_e e = e
let is_simpler_than r1 r2 = compare r1 r2 <= 0
end
module Reprs = struct
type t = (E.t, Repr.t) H.t
let create () = H.create 1
let rec find (t: t) e : Repr.t =
match H.find_opt t e with
| None ->
Repr.of_e e
| Some r ->
let r' = find t (r :> E.t) in
if not (phys_equal r r') then H.replace t e r' ;
r'
let merge (t: t) ~(from: Repr.t) ~(to_: Repr.t) = H.replace t (from :> E.t) to_
let normalizer t e = (find t e :> E.t)
end
module Set = E.Set
module Sets = struct
type t = (Repr.t, Set.t) H.t
let create () = H.create 1
let find t (r: Repr.t) =
match H.find_opt t r with
| Some set ->
set
| None ->
let set = Set.create (r :> E.t) in
H.replace t r set ; set
let merge_no_remove _t ~from:(from_r, from_set) ~to_:(_, to_set) =
Set.merge_invalidate ~from:from_set ~to_:to_set ;
from_r
let merge t ~from:(from_r, from_set) ~to_:(_, to_set) =
H.remove t from_r ;
Set.merge_invalidate ~from:from_set ~to_:to_set
let pp_equalities fmt t = H.pp_bindings Repr.pp Set.pp_equalities fmt t
let normalize_sums ~normalizer t =
H.iter (fun _repr set -> Set.normalize_sums ~normalizer set) t
let infer_equalities_from_sums ~on_infer ~normalizer t =
H.iter (fun _repr set -> Set.infer_equalities_from_sums ~on_infer ~normalizer set) t
let remove_list t rs = List.iter rs ~f:(H.remove t)
end
(**
Data-structure for disjoint sets.
[reprs] is the mapping element -> representative
[sets] is the mapping representative -> set
It implements path-compression and union by size, hence find and union are amortized O(1)-ish.
*)
type t = {reprs: Reprs.t; sets: Sets.t}
let create () = {reprs= Reprs.create (); sets= Sets.create ()}
let do_merge t ~from ~to_ ~merge_sets =
let to_r, _ = to_ in
let from_r, _ = from in
Reprs.merge t.reprs ~from:from_r ~to_:to_r ;
merge_sets ~from ~to_
let union_with_merge t e1 e2 ~merge_sets =
let repr1 = Reprs.find t.reprs e1 in
let repr2 = Reprs.find t.reprs e2 in
if Repr.equal repr1 repr2 then None
else
let set1 = Sets.find t.sets repr1 in
let set2 = Sets.find t.sets repr2 in
let size1 = Set.size set1 in
let size2 = Set.size set2 in
if size1 < size2 || (Int.equal size1 size2 && Repr.is_simpler_than repr2 repr1) then (
(* A desired side-effect of using [is_simpler_than] is that the representative for a set will always be a [`Node]. For now. *)
do_merge t ~from:(repr1, set1) ~to_:(repr2, set2) ~merge_sets ;
Some (e1, e2) )
else (
do_merge t ~from:(repr2, set2) ~to_:(repr1, set1) ~merge_sets ;
Some (e2, e1) )
let union_log t e1 e2 ~merge_sets =
match union_with_merge t e1 e2 ~merge_sets with
| None ->
L.(debug Analysis Verbose) "[UF] Preexisting %a = %a@\n" E.pp e1 E.pp e2
| Some (e1, e2) ->
L.(debug Analysis Verbose) "[UF] Union %a into %a@\n" E.pp e1 E.pp e2
let union t e1 e2 = union_log t e1 e2 ~merge_sets:(Sets.merge t.sets)
let union_defer_remove t e1 e2 =
let res = ref None in
let merge_sets ~from ~to_ = res := Some (Sets.merge_no_remove t.sets ~from ~to_) in
union_log t e1 e2 ~merge_sets ; !res
let pp_equalities fmt t = Sets.pp_equalities fmt t.sets
let normalizer t = Reprs.normalizer t.reprs
let normalize_sums t = Sets.normalize_sums ~normalizer:(normalizer t) t.sets
(**
Infer equalities from sums, like this:
(1) A + sum1 = A + sum2 => sum1 = sum2
It does not try to saturate
(2) A = B + C /\ B = D + E => A = C + D + E
Nor combine more than 2 equations
(3) A = B + C /\ B = D + E /\ F = C + D + E => A = F
((3) is implied by (1) /\ (2))
Its complexity is unknown but I think it is bounded by nbNodes x nbEdges x max.
*)
let infer_equalities_from_sums t ~max =
let normalizer = normalizer t in
let on_infer sets_to_remove e1 e2 =
(* need to defer removes to avoid iterator invalidation *)
union_defer_remove t e1 e2
|> Option.iter ~f:(fun set_to_remove -> sets_to_remove := set_to_remove :: !sets_to_remove)
in
let rec loop max =
let sets_to_remove = ref [] in
let on_infer = on_infer sets_to_remove in
Sets.infer_equalities_from_sums ~on_infer ~normalizer t.sets ;
if not (List.is_empty !sets_to_remove) then (
Sets.remove_list t.sets !sets_to_remove ;
L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" pp_equalities t ;
if max > 0 then loop (max - 1)
else
L.(debug Analysis Verbose) "[ConstraintSolver] Maximum number of iterations reached@\n" )
in
loop max
end
module ConstraintSolver = struct
module Equalities = ImperativeUnionFind (ControlFlowCost)
let add_constraints equalities node get_nodes make =
match get_nodes node with
| [] ->
(* either start/exit node or dead node (broken CFG) *)
()
| nodes ->
let node_id = Node.id node in
let edges = List.rev_map nodes ~f:(fun other -> make node_id (Node.id other)) in
let sum = ControlFlowCost.sum edges in
Equalities.union equalities (ControlFlowCost.make_node node_id) sum
let collect_on_node equalities node =
add_constraints equalities node Procdesc.Node.get_preds ControlFlowCost.make_pred_edge ;
add_constraints equalities node Procdesc.Node.get_succs ControlFlowCost.make_succ_edge
let collect_constraints node_cfg =
let equalities = Equalities.create () in
Container.iter node_cfg ~fold:NodeCFG.fold_nodes ~f:(collect_on_node equalities) ;
L.(debug Analysis Verbose)
"[ConstraintSolver] Procedure %a @@ %a@\n" Typ.Procname.pp (Procdesc.get_proc_name node_cfg)
Location.pp_file_pos (Procdesc.get_loc node_cfg) ;
L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" Equalities.pp_equalities equalities ;
Equalities.normalize_sums equalities ;
L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" Equalities.pp_equalities equalities ;
Equalities.infer_equalities_from_sums equalities ~max:10 ;
L.(debug Analysis Verbose) "[ConstraintSolver] %a@\n" Equalities.pp_equalities equalities ;
equalities
end
(* Structural Constraints are expressions of the kind: (* Structural Constraints are expressions of the kind:
n <= n1 +...+ nk n <= n1 +...+ nk
@ -578,6 +1013,7 @@ let checker ({Callbacks.tenv; proc_desc} as callback_args) : Summary.t =
BoundMap.compute_upperbound_map node_cfg inferbo_invariant_map data_dep_invariant_map BoundMap.compute_upperbound_map node_cfg inferbo_invariant_map data_dep_invariant_map
control_dep_invariant_map control_dep_invariant_map
in in
let _ = ConstraintSolver.collect_constraints node_cfg in
let constraints = StructuralConstraints.compute_structural_constraints node_cfg in let constraints = StructuralConstraints.compute_structural_constraints node_cfg in
let min_trees = MinTree.compute_trees_from_contraints bound_map node_cfg constraints in let min_trees = MinTree.compute_trees_from_contraints bound_map node_cfg constraints in
let trees_valuation = let trees_valuation =

@ -14,8 +14,6 @@ open! IStd
include sig include sig
(* ocaml ignores the warning suppression at toplevel, hence the [include struct ... end] trick *) (* ocaml ignores the warning suppression at toplevel, hence the [include struct ... end] trick *)
[@@@warning "-60"]
type +'a t type +'a t
(* O(1) time and O(1) allocation *) (* O(1) time and O(1) allocation *)
@ -64,4 +62,4 @@ include sig
val fold_unordered : ('a t, 'a, 'accum) Container.fold val fold_unordered : ('a t, 'a, 'accum) Container.fold
(** Always better than [fold_left] when you do not care about the order. *) (** Always better than [fold_left] when you do not care about the order. *)
end end[@@warning "-32"]

@ -7,6 +7,7 @@
(* Extension of Base.Container, i.e. generic definitions of container operations in terms of fold. *) (* Extension of Base.Container, i.e. generic definitions of container operations in terms of fold. *)
open! IStd open! IStd
module F = Format
type 'a singleton_or_more = Empty | Singleton of 'a | More type 'a singleton_or_more = Empty | Singleton of 'a | More
@ -29,6 +30,8 @@ let forto excl ~init ~f =
aux excl ~f init 0 aux excl ~f init 0
let to_rev_list ~fold t = fold t ~init:[] ~f:(fun tl hd -> hd :: tl)
let rev_filter_to_list ~fold t ~f = let rev_filter_to_list ~fold t ~f =
fold t ~init:[] ~f:(fun acc item -> if f item then item :: acc else acc) fold t ~init:[] ~f:(fun acc item -> if f item then item :: acc else acc)
@ -37,3 +40,31 @@ let rev_map_to_list ~fold t ~f = fold t ~init:[] ~f:(fun acc item -> f item :: a
let rev_filter_map_to_list ~fold t ~f = let rev_filter_map_to_list ~fold t ~f =
fold t ~init:[] ~f:(fun acc item -> IList.opt_cons (f item) acc) fold t ~init:[] ~f:(fun acc item -> IList.opt_cons (f item) acc)
let iter_consecutive ~fold t ~f =
let _ : _ option =
fold t ~init:None ~f:(fun prev_opt curr ->
(match prev_opt with Some prev -> f prev curr | None -> ()) ;
Some curr )
in
()
let pp_collection ~fold ~pp_item fmt c =
let f prev_opt item =
prev_opt |> Option.iter ~f:(F.fprintf fmt "@[<h>%a,@]@ " pp_item) ;
Some item
in
let pp_aux fmt c = fold c ~init:None ~f |> Option.iter ~f:(F.fprintf fmt "@[<h>%a@] " pp_item) in
F.fprintf fmt "@[<hv 2>{ %a}@]" pp_aux c
let pp_seq ~fold ~sep pp_item fmt c =
let f first item =
if not first then F.pp_print_string fmt sep ;
pp_item fmt item ;
false
in
let _is_empty : bool = fold c ~init:true ~f in
()

@ -6,6 +6,7 @@
*) *)
open! IStd open! IStd
module F = Format
type 'a singleton_or_more = Empty | Singleton of 'a | More type 'a singleton_or_more = Empty | Singleton of 'a | More
@ -16,9 +17,23 @@ val mem_nth : fold:('t, _, int) Container.fold -> 't -> int -> bool
val forto : (int, int, 'accum) Container.fold val forto : (int, int, 'accum) Container.fold
val to_rev_list : fold:('t, 'a, 'a list) Container.fold -> 't -> 'a list
val rev_filter_to_list : fold:('t, 'a, 'a list) Container.fold -> 't -> f:('a -> bool) -> 'a list val rev_filter_to_list : fold:('t, 'a, 'a list) Container.fold -> 't -> f:('a -> bool) -> 'a list
val rev_map_to_list : fold:('t, 'a, 'b list) Container.fold -> 't -> f:('a -> 'b) -> 'b list val rev_map_to_list : fold:('t, 'a, 'b list) Container.fold -> 't -> f:('a -> 'b) -> 'b list
val rev_filter_map_to_list : val rev_filter_map_to_list :
fold:('t, 'a, 'b list) Container.fold -> 't -> f:('a -> 'b option) -> 'b list fold:('t, 'a, 'b list) Container.fold -> 't -> f:('a -> 'b option) -> 'b list
val iter_consecutive :
fold:('t, 'a, 'a option) Container.fold -> 't -> f:('a -> 'a -> unit) -> unit
val pp_collection :
fold:('t, 'a, 'a option) Container.fold -> pp_item:(F.formatter -> 'a -> unit) -> F.formatter
-> 't -> unit
val pp_seq :
fold:('t, 'a, bool) Container.fold -> sep:string -> (F.formatter -> 'a -> unit) -> F.formatter
-> 't -> unit
[@@warning "-32"]

@ -161,3 +161,13 @@ let rec drop list index =
let opt_cons opt list = match opt with Some x -> x :: list | None -> list let opt_cons opt list = match opt with Some x -> x :: list | None -> list
let remove_first =
let rec aux list ~f rev_front =
match list with
| [] ->
None
| hd :: tl ->
if f hd then Some (List.rev_append rev_front tl) else aux tl ~f (hd :: rev_front)
in
fun list ~f -> aux list ~f []

@ -39,3 +39,5 @@ val drop : 'a list -> int -> 'a list
val opt_cons : 'a option -> 'a list -> 'a list val opt_cons : 'a option -> 'a list -> 'a list
(** [opt_cons None l] returns [l]. [opt_cons (Some x) l] returns [x :: l]*) (** [opt_cons None l] returns [l]. [opt_cons (Some x) l] returns [x :: l]*)
val remove_first : 'a list -> f:('a -> bool) -> 'a list option

@ -32,18 +32,7 @@ module type PPMap = sig
val pp : pp_value:(F.formatter -> 'a -> unit) -> F.formatter -> 'a t -> unit val pp : pp_value:(F.formatter -> 'a -> unit) -> F.formatter -> 'a t -> unit
end end
let pp_collection ~pp_item fmt c = let pp_collection ~pp_item fmt c = IContainer.pp_collection ~fold:List.fold ~pp_item fmt c
let rec pp_list fmt = function
| [] ->
()
| [item] ->
F.fprintf fmt "@[<h>%a@] " pp_item item
| item :: items ->
F.fprintf fmt "@[<h>%a,@]@ " pp_item item ;
pp_list fmt items
in
F.fprintf fmt "@[<hv 2>{ %a}@]" pp_list c
module MakePPSet (Ord : PrintableOrderedType) = struct module MakePPSet (Ord : PrintableOrderedType) = struct
include Caml.Set.Make (Ord) include Caml.Set.Make (Ord)

Loading…
Cancel
Save