Let's traverse the constraints list only once

Reviewed By: mbouaziz

Differential Revision: D8076304

fbshipit-source-id: 51a8c53
master
Dino Distefano 7 years ago committed by Facebook Github Bot
parent a0e3314b7b
commit 4b296003fa

@ -29,12 +29,7 @@ let expensive_threshold = BasicCost.of_int_exn 200
module InstrCFG = ProcCfg.NormalOneInstrPerNode
module NodeCFG = ProcCfg.Normal
module InstrCFGScheduler = Scheduler.ReversePostorder (InstrCFG)
module Node = struct
include ProcCfg.DefaultNode
let equal_id = [%compare.equal : id]
end
module Node = ProcCfg.DefaultNode
(* Compute a map (node,instruction) -> basic_cost, where basic_cost is the
cost known for a certain operation. For example for basic operation we
@ -218,36 +213,31 @@ end
equal to the sum of the number of times nodes n1,..., nk can be executed.
*)
module StructuralConstraints = struct
type rhs = Single of Node.id | Sum of Node.IdSet.t
type t = {lhs: Node.id; rhs: rhs}
let is_single ~lhs:expected_lhs = function
| {lhs; rhs= Single single} when Node.equal_id lhs expected_lhs ->
Some single
| _ ->
None
let is_sum ~lhs:expected_lhs = function
| {lhs; rhs= Sum sum} when Node.equal_id lhs expected_lhs ->
Some sum
| _ ->
None
type t = {single: Node.id list; sum: Node.IdSet.t list}
(*
Finds subset of constraints of node k.
It returns a pair (single_constraints, sum_constraints) where single constraints are
of the form 'x_k <= x_j' and sum constraints are of the form 'x_k <= x_j1 +...+ x_jn'.
*)
let get_constraints_of_node constraints k =
let c = Node.IdMap.find_opt k constraints in
match c with Some c -> c | _ -> {single= []; sum= []}
let pp_rhs fmt = function
| Single nid ->
Node.pp_id fmt nid
| Sum nidset ->
Pp.seq ~sep:" + " Node.pp_id fmt (Node.IdSet.elements nidset)
let pp fmt {lhs; rhs} = F.fprintf fmt "%a <= %a" Node.pp_id lhs pp_rhs rhs
let print_constraint_list constraints =
L.(debug Analysis Medium) "@\n\n******* Structural Constraints **** @\n" ;
List.iter ~f:(fun c -> L.(debug Analysis Medium) "@\n %a @\n" pp c) constraints ;
let print_constraints_map constraints =
let pp_nidset fmt nidset = Pp.seq ~sep:" + " Node.pp_id fmt (Node.IdSet.elements nidset) in
L.(debug Analysis Medium)
"@\n\n******* Structural Constraints size = %i **** @\n" (Node.IdMap.cardinal constraints) ;
Node.IdMap.iter
(fun n {single; sum} ->
List.iter
~f:(fun s -> L.(debug Analysis Medium) "@\n %a <= %a @\n" Node.pp_id n Node.pp_id s)
single ;
List.iter
~f:(fun s -> L.(debug Analysis Medium) "@\n %a <= %a @\n" Node.pp_id n pp_nidset s)
sum )
constraints ;
L.(debug Analysis Medium) "@\n******* END Structural Constraints **** @\n\n"
@ -258,24 +248,28 @@ module StructuralConstraints = struct
*)
let compute_structural_constraints node_cfg =
let compute_node_constraints acc node =
let constraints_append node get_nodes tail =
let constraints_add node get_nodes =
match get_nodes node with
| [] ->
tail
{single= []; sum= []}
| [single] ->
{lhs= NodeCFG.id node; rhs= Single (NodeCFG.id single)} :: tail
{single= [NodeCFG.id single]; sum= []}
| nodes ->
let sum =
List.fold nodes ~init:Node.IdSet.empty ~f:(fun idset node ->
Node.IdSet.add (NodeCFG.id node) idset )
in
{lhs= NodeCFG.id node; rhs= Sum sum} :: tail
{single= []; sum= [sum]}
in
acc |> constraints_append node Procdesc.Node.get_preds
|> constraints_append node Procdesc.Node.get_succs
let preds = constraints_add node Procdesc.Node.get_preds in
let succs = constraints_add node Procdesc.Node.get_succs in
Node.IdMap.add (NodeCFG.id node)
{single= List.append preds.single succs.single; sum= List.append preds.sum succs.sum} acc
in
let constraints =
List.fold (NodeCFG.nodes node_cfg) ~f:compute_node_constraints ~init:Node.IdMap.empty
in
let constraints = List.fold (NodeCFG.nodes node_cfg) ~f:compute_node_constraints ~init:[] in
print_constraint_list constraints ; constraints
print_constraints_map constraints ; constraints
end
(* MinTree is used to compute:
@ -317,17 +311,6 @@ module MinTree = struct
match node with Plus l -> Plus (child :: l) | Min l -> Min (child :: l) | _ -> assert false
(* finds the subset of constraints of the form x_k <= x_j *)
let get_k_single_constraints constraints k =
List.filter_map constraints ~f:(StructuralConstraints.is_single ~lhs:k)
(* finds the subset of constraints of the form x_k <= x_j1 +...+ x_jn and
return the addends of the sum x_j1+x_j2+..+x_j_n*)
let get_k_sum_constraints constraints k =
List.filter_map constraints ~f:(StructuralConstraints.is_sum ~lhs:k)
let rec evaluate_tree t =
match t with
| Leaf (_, c) ->
@ -361,7 +344,8 @@ return the addends of the sum x_j1+x_j2+..+x_j_n*)
type t = Node.id * Node.IdSet.t [@@deriving compare]
end)
let minimum_propagation (bound_map: BoundMap.t) (constraints: StructuralConstraints.t list) self
let minimum_propagation (bound_map: BoundMap.t)
(constraints: StructuralConstraints.t Node.IdMap.t) self
((q, visited): Node.id * Node.IdSet.t) =
let rec build_min node branch visited_acc worklist =
match worklist with
@ -372,19 +356,18 @@ return the addends of the sum x_j1+x_j2+..+x_j_n*)
else
let visited_acc' = Node.IdSet.add k visited_acc in
let node = add_leaf node k (BoundMap.upperbound bound_map k) in
let k_constraints_upperbound = get_k_single_constraints constraints k in
let k_constraints = StructuralConstraints.get_constraints_of_node constraints k in
let worklist' =
List.fold k_constraints_upperbound ~init:rest ~f:(fun acc ub_id ->
List.fold k_constraints.single ~init:rest ~f:(fun acc ub_id ->
if Node.IdSet.mem ub_id visited_acc' then acc else ub_id :: acc )
in
let k_sum_constraints = get_k_sum_constraints constraints k in
let branch =
List.fold_left
~f:(fun branch set_addend ->
if Node.IdSet.is_empty (Node.IdSet.inter set_addend visited_acc') then
SetOfSetsOfNodes.add set_addend branch
else branch )
~init:branch k_sum_constraints
~init:branch k_constraints.sum
in
build_min node branch visited_acc' worklist'
in

Loading…
Cancel
Save