[hoisting] Invalidate args of impure function calls

Reviewed By: mbouaziz

Differential Revision: D10236724

fbshipit-source-id: f39d4574d
master
Ezgi Çiçek 6 years ago committed by Facebook Github Bot
parent 4954d3da4b
commit affe3d1d60

@ -8,9 +8,12 @@ open! IStd
module L = Logging
module InvariantVars = AbstractDomain.FiniteSet (Var)
module VarsInLoop = AbstractDomain.FiniteSet (Var)
module InvalidatedVars = AbstractDomain.FiniteSet (Var)
module LoopNodes = AbstractDomain.FiniteSet (Procdesc.Node)
module Models = InvariantModels
let debug fmt = L.(debug Analysis Medium) fmt
(** Map loop header node -> all nodes in the loop *)
module LoopHeadToLoopNodes = Procdesc.NodeMap
@ -20,10 +23,8 @@ let is_defined_outside loop_nodes reaching_defs var =
|> Option.value ~default:true
let is_fun_call_invariant tenv ~is_exp_invariant ~is_inv_by_default callee_pname params =
List.for_all ~f:(fun (exp, _) -> is_exp_invariant exp) params
&&
(* Take into account invariance behavior of modeled functions *)
let is_fun_pure tenv ~is_inv_by_default callee_pname params =
(* Take into account purity behavior of modeled functions *)
match Models.Call.dispatch tenv callee_pname params with
| Some inv ->
InvariantModels.is_invariant inv
@ -55,17 +56,19 @@ let is_def_unique_and_satisfy tenv var (loop_nodes : LoopNodes.t) ~is_inv_by_def
when Exp.equal exp_lhs (Var.to_exp var) && is_exp_invariant exp_rhs ->
true
| Sil.Call ((id, _), Const (Cfun callee_pname), params, _, _) when equals_var id ->
is_fun_call_invariant tenv ~is_exp_invariant ~is_inv_by_default callee_pname
params
is_fun_pure tenv ~is_inv_by_default callee_pname params
&& (* check if all params are invariant *)
List.for_all ~f:(fun (exp, _) -> is_exp_invariant exp) params
| _ ->
false ) )
loop_nodes
let is_exp_invariant inv_vars loop_nodes reaching_defs exp =
let is_exp_invariant inv_vars invalidated_vars loop_nodes reaching_defs exp =
Var.get_all_vars_in_exp exp
|> Sequence.for_all ~f:(fun var ->
InvariantVars.mem var inv_vars || is_defined_outside loop_nodes reaching_defs var )
(not (InvalidatedVars.mem var invalidated_vars))
&& (InvariantVars.mem var inv_vars || is_defined_outside loop_nodes reaching_defs var) )
let get_vars_in_loop loop_nodes =
@ -97,14 +100,55 @@ let get_vars_in_loop loop_nodes =
loop_nodes VarsInLoop.empty
let get_loaded_object var node invalidated_vars =
Procdesc.Node.get_instrs node
|> Instrs.fold
~f:(fun acc instr ->
match instr with
| Sil.Load (id, Lvar pvar, typ, _) when Var.equal var (Var.of_id id) && Typ.is_pointer typ
->
InvalidatedVars.add (Var.of_pvar pvar) acc
| _ ->
acc )
~init:invalidated_vars
let get_vars_to_invalidate node params invalidated_vars : InvalidatedVars.t =
List.fold ~init:invalidated_vars
~f:(fun acc (arg_exp, _) ->
Var.get_all_vars_in_exp arg_exp
|> Sequence.fold ~init:acc ~f:(fun acc var ->
get_loaded_object var node (InvalidatedVars.add var acc) ) )
params
(* If there is a call to an impure function in the loop, invalidate
all its non-primitive arguments. Once invalidated, it should be
never added again. *)
let get_invalidated_vars_in_loop tenv ~is_inv_by_default loop_nodes =
LoopNodes.fold
(fun node acc ->
Procdesc.Node.get_instrs node
|> Instrs.fold ~init:acc ~f:(fun acc instr ->
match instr with
| Sil.Call ((id, _), Const (Cfun callee_pname), params, _, _)
when not (is_fun_pure tenv ~is_inv_by_default callee_pname params) ->
get_vars_to_invalidate node params (InvalidatedVars.add (Var.of_id id) acc)
| _ ->
acc ) )
loop_nodes InvalidatedVars.empty
(* A variable is invariant if
- its reaching definition is outside of the loop
- o.w. its definition is constant or invariant itself *)
let get_inv_vars_in_loop tenv reaching_defs_invariant_map ~is_inv_by_default loop_head loop_nodes =
let process_var_once var inv_vars =
let process_var_once var inv_vars invalidated_vars =
(* if a variable is marked invariant once, it can't be invalidated
(i.e. invariance is monotonic) *)
if InvariantVars.mem var inv_vars || Var.is_none var then (inv_vars, false)
if
InvariantVars.mem var inv_vars || Var.is_none var || InvalidatedVars.mem var invalidated_vars
then (inv_vars, false)
else
let loop_head_id = Procdesc.Node.get_id loop_head in
ReachingDefs.Analyzer.extract_post loop_head_id reaching_defs_invariant_map
@ -117,7 +161,7 @@ let get_inv_vars_in_loop tenv reaching_defs_invariant_map ~is_inv_by_default loo
else if
(* its definition is unique and invariant *)
is_def_unique_and_satisfy tenv var def_nodes ~is_inv_by_default
(is_exp_invariant inv_vars loop_nodes reaching_defs)
(is_exp_invariant inv_vars invalidated_vars loop_nodes reaching_defs)
then (InvariantVars.add var inv_vars, true)
else (inv_vars, false) )
|> Option.value (* if a var is not declared, it must be invariant *)
@ -127,16 +171,18 @@ let get_inv_vars_in_loop tenv reaching_defs_invariant_map ~is_inv_by_default loo
let vars_in_loop = get_vars_in_loop loop_nodes in
(* until there are no changes to inv_vars, keep repeatedly
processing all the variables that occur in the loop nodes *)
let invalidated_vars = get_invalidated_vars_in_loop tenv ~is_inv_by_default loop_nodes in
let rec find_fixpoint inv_vars =
let inv_vars', modified =
InvariantVars.fold
(fun var (inv_vars, is_mod) ->
let inv_vars', is_mod' = process_var_once var inv_vars in
let inv_vars', is_mod' = process_var_once var inv_vars invalidated_vars in
(inv_vars', is_mod || is_mod') )
vars_in_loop (inv_vars, false)
in
if modified then find_fixpoint inv_vars' else inv_vars'
in
debug "\n>>> Invalidated vars: %a\n" InvalidatedVars.pp invalidated_vars ;
find_fixpoint InvariantVars.empty

@ -7,10 +7,17 @@
class HoistIndirect {
public static int svar = 0;
int[] array;
class Test {
int a = 0;
int foo(int x) {
return x + 10;
}
void set_test(Test test) {
test.a = 5;
}
@ -19,15 +26,132 @@ class HoistIndirect {
return test.a;
}
int indirect_modification_dont_hoist_FP(int size) {
int d = 0;
Test t = new Test();
int get_sum_test(Test test, int x) {
return test.a + x;
}
Test return_only(Test t) {
return t;
}
int indirect_modification_dont_hoist(int size, Test t) {
int d = 0;
for (int i = 0; i < size; i++) {
set_test(t);
d = get_test(t); // don't hoist since t changes
}
return d;
}
void variant_arg_dont_hoist(int size, Test t) {
for (int i = 0; i < size; i++) {
set_test(t); // t is invalidated
get_sum_test(
return_only(t),
size); // foo' and return_only's arguments are variant, hence don't hoist
}
;
}
// t changes deep in the call stack
int deep_modification_dont_hoist(int size) {
int d = 0;
Test t = new Test();
for (int i = 0; i < size; i++) {
indirect_modification_dont_hoist(size, t);
}
return d;
}
// foo(3) is ok to hoist, but can't detect this right now
int indirect_modification_hoist_FN(int size) {
int d = 0;
Test t = new Test();
for (int i = 0; i < size; i++) {
set_test(t); // this (and t) is invalidated here
d = foo(3); // foo becomes variant due to implicit arg. this being invalidated above
}
return d;
}
}
void set() {
svar = 5;
}
int get() {
return svar;
}
int indirect_this_modification_dont_hoist(int size) {
int d = 0;
for (int i = 0; i < size; i++) {
d = get(); // don't hoist since this.svar changes in the loop
set();
}
return d;
}
int direct_this_modification_dont_hoist_FP(int size) {
int d = 0;
for (int i = 0; i < size; i++) {
d += get(); // don't hoist since this.svar changes in the loop
svar = i;
}
return d;
}
int this_modification_outside_hoist(int size) {
int d = 0;
set();
for (int i = 0; i < size; i++) {
d += get(); // ok to hoist since set is outside
}
return d;
}
int arg_modification_hoist(int size, Test t) {
int d = 0;
for (int i = 0; i < size; i++) {
d += get(); // ok to hoist since set_test doesn't modify this
t.set_test(t);
}
return d;
}
void set_ith(int i, int[] array) {
array[i] = 0;
}
int get_ith(int i, int[] array) {
return array[i];
}
int modified_array_dont_hoist(int size, Test t) {
int d = 0;
for (int i = 0; i < size; i++) {
set_ith(i, array);
d += get_ith(size, array); // don't hoist since array changes
}
return d;
}
static int regionFirst(int[] region) {
return region[0];
}
static void incrDest(int[] source, int[] dest) {
dest[0] = source[0] + 1;
}
void nested_change_dont_hoist_FP(int[][] nextRegionM, int p, int[] tempRegion) {
for (int i = 0; i < 10; i++) {
if (i < regionFirst(nextRegionM[p])) {
incrDest(tempRegion, nextRegionM[p]);
}
}
}
}

@ -21,5 +21,14 @@ codetoanalyze/java/hoisting/Hoist.java, Hoist.used_in_loop_body_before_def_temp_
codetoanalyze/java/hoisting/Hoist.java, Hoist.void_hoist(int):void, 0, PURE_FUNCTION, no_bucket, ERROR, [Side-effect free function void Hoist.void_hoist(int)]
codetoanalyze/java/hoisting/Hoist.java, Hoist.void_hoist(int):void, 2, INVARIANT_CALL, no_bucket, ERROR, [Loop-invariant call to void Hoist.dumb_foo() at line 183]
codetoanalyze/java/hoisting/Hoist.java, Hoist.x_not_invariant_dont_hoist(int,int,int):void, 0, PURE_FUNCTION, no_bucket, ERROR, [Side-effect free function void Hoist.x_not_invariant_dont_hoist(int,int,int)]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect$Test.foo(int):int, 0, PURE_FUNCTION, no_bucket, ERROR, [Side-effect free function int HoistIndirect$Test.foo(int)]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect$Test.get_sum_test(HoistIndirect$Test,int):int, 0, PURE_FUNCTION, no_bucket, ERROR, [Side-effect free function int HoistIndirect$Test.get_sum_test(HoistIndirect$Test,int)]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect$Test.get_test(HoistIndirect$Test):int, 0, PURE_FUNCTION, no_bucket, ERROR, [Side-effect free function int HoistIndirect$Test.get_test(HoistIndirect$Test)]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect$Test.indirect_modification_dont_hoist_FP(int):int, 6, INVARIANT_CALL, no_bucket, ERROR, [Loop-invariant call to int HoistIndirect$Test.get_test(HoistIndirect$Test) at line 28]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect$Test.return_only(HoistIndirect$Test):HoistIndirect$Test, 0, PURE_FUNCTION, no_bucket, ERROR, [Side-effect free function HoistIndirect$Test HoistIndirect$Test.return_only(HoistIndirect$Test)]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect.arg_modification_hoist(int,HoistIndirect$Test):int, 3, INVARIANT_CALL, no_bucket, ERROR, [Loop-invariant call to int HoistIndirect.get() at line 119]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect.direct_this_modification_dont_hoist_FP(int):int, 4, INVARIANT_CALL, no_bucket, ERROR, [Loop-invariant call to int HoistIndirect.get() at line 101]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect.get():int, 0, PURE_FUNCTION, no_bucket, ERROR, [Side-effect free function int HoistIndirect.get()]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect.get_ith(int,int[]):int, 0, PURE_FUNCTION, no_bucket, ERROR, [Side-effect free function int HoistIndirect.get_ith(int,int[])]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect.nested_change_dont_hoist_FP(int[][],int,int[]):void, 2, INVARIANT_CALL, no_bucket, ERROR, [Loop-invariant call to int HoistIndirect.regionFirst(int[]) at line 152]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect.regionFirst(int[]):int, 0, PURE_FUNCTION, no_bucket, ERROR, [Side-effect free function int HoistIndirect.regionFirst(int[])]
codetoanalyze/java/hoisting/HoistIndirect.java, HoistIndirect.this_modification_outside_hoist(int):int, 4, INVARIANT_CALL, no_bucket, ERROR, [Loop-invariant call to int HoistIndirect.get() at line 111]

Loading…
Cancel
Save