[litho] utility function for iterating over all call chains encoded in the domain

Summary:
and add mli. We already had the logic for iterating over call chains, but it was overfitted to the should-update analysis.
Will use the generalized version in a follow-up.

Reviewed By: jvillard

Differential Revision: D6740692

fbshipit-source-id: 8c0d89f
master
Sam Blackshear 7 years ago committed by Facebook Github Bot
parent 4ad80615ef
commit 360151eb10

@ -90,7 +90,7 @@ module TransferFunctions (CFG : ProcCfg.S) = struct
let return_access_path = Domain.LocalAccessPath.make (return_base, []) caller_pname in
let return_calls =
(try Domain.find return_access_path astate with Not_found -> Domain.CallSet.empty)
|> Domain.CallSet.add {receiver; procname= callee_procname}
|> Domain.CallSet.add (Domain.MethodCall.make receiver callee_procname)
in
Domain.add return_access_path return_calls astate
else
@ -126,27 +126,6 @@ let report access_path call_chain summary =
Reporting.log_error summary ~loc ~ltr exn
let unroll_call call astate summary =
let max_depth = Domain.cardinal astate in
let rec unroll_call_ ({receiver; procname}: Domain.MethodCall.t) (acc, depth) =
let acc' = procname :: acc in
let depth' = depth + 1 in
let is_cycle access_path =
(* detect direct cycles and cycles due to mutual recursion *)
Domain.LocalAccessPath.equal access_path receiver || depth' > max_depth
in
try
let calls' = Domain.find receiver astate in
Domain.CallSet.iter
(fun call ->
if not (is_cycle call.receiver) then unroll_call_ call (acc', depth')
else report receiver.access_path acc' summary )
calls'
with Not_found -> report receiver.access_path acc' summary
in
unroll_call_ call ([], 0)
let should_report proc_desc =
match Procdesc.get_proc_name proc_desc with
| Typ.Procname.Java java_pname -> (
@ -155,10 +134,14 @@ let should_report proc_desc =
false
let report_call_chains post summary =
Domain.iter
(fun _ call_set -> Domain.CallSet.iter (fun call -> unroll_call call post summary) call_set)
post
let report_graphql_getters summary access_path call_chain =
let call_strings = List.map ~f:(Typ.Procname.to_simplified_string ~withclass:false) call_chain in
let call_string = String.concat ~sep:"." call_strings in
let message = F.asprintf "%a.%s" AccessPath.pp access_path call_string in
let exn = Exceptions.Checkers (IssueType.graphql_field_access, Localise.verbatim_desc message) in
let loc = Specs.get_loc summary in
let ltr = [Errlog.make_trace_element 0 loc message []] in
Reporting.log_error summary ~loc ~ltr exn
let postprocess astate proc_desc : Domain.astate =
@ -171,7 +154,9 @@ let checker {Callbacks.summary; proc_desc; tenv} =
let proc_data = ProcData.make_default proc_desc tenv in
match Analyzer.compute_post proc_data ~initial:Domain.empty with
| Some post ->
if should_report proc_desc then report_call_chains post summary ;
( if should_report proc_desc then
let f = report_graphql_getters summary in
Domain.iter_call_chains ~f post ) ;
let payload = postprocess post proc_desc in
Summary.update_summary payload summary
| None ->

@ -11,7 +11,6 @@ open! IStd
module F = Format
module L = Logging
(** Access path + its parent procedure *)
module LocalAccessPath = struct
type t = {access_path: AccessPath.t; parent: Typ.Procname.t} [@@deriving compare]
@ -30,10 +29,11 @@ module LocalAccessPath = struct
let pp fmt t = AccessPath.pp fmt t.access_path
end
(** Called procedure + it's receiver *)
module MethodCall = struct
type t = {receiver: LocalAccessPath.t; procname: Typ.Procname.t} [@@deriving compare]
let make receiver procname = {receiver; procname}
let pp fmt {receiver; procname} =
F.fprintf fmt "%a.%a" LocalAccessPath.pp receiver Typ.Procname.pp procname
end
@ -62,3 +62,31 @@ let substitute ~(f_sub: LocalAccessPath.t -> LocalAccessPath.t option) astate =
in
add access_path' call_set' acc )
astate empty
let iter_call_chains_with_suffix ~f call_suffix astate =
let max_depth = cardinal astate in
let rec unroll_call_ ({receiver; procname}: MethodCall.t) (acc, depth) =
let acc' = procname :: acc in
let depth' = depth + 1 in
let is_cycle access_path =
(* detect direct cycles and cycles due to mutual recursion *)
LocalAccessPath.equal access_path receiver || depth' > max_depth
in
try
let calls' = find receiver astate in
CallSet.iter
(fun call ->
if not (is_cycle call.receiver) then unroll_call_ call (acc', depth')
else f receiver.access_path acc' )
calls'
with Not_found -> f receiver.access_path acc'
in
unroll_call_ call_suffix ([], 0)
let iter_call_chains ~f astate =
iter
(fun _ call_set ->
CallSet.iter (fun call -> iter_call_chains_with_suffix ~f call astate) call_set )
astate

@ -0,0 +1,48 @@
(*
* Copyright (c) 2018 - present Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*)
open! IStd
module F = Format
(** Access path + its parent procedure *)
module LocalAccessPath : sig
type t = private {access_path: AccessPath.t; parent: Typ.Procname.t} [@@deriving compare]
val make : AccessPath.t -> Typ.Procname.t -> t
val to_formal_option : t -> FormalMap.t -> t option
val pp : F.formatter -> t -> unit
end
(** Called procedure + its receiver *)
module MethodCall : sig
type t = private {receiver: LocalAccessPath.t; procname: Typ.Procname.t} [@@deriving compare]
val make : LocalAccessPath.t -> Typ.Procname.t -> t
val pp : F.formatter -> t -> unit
end
module CallSet : module type of AbstractDomain.FiniteSet (MethodCall)
include module type of AbstractDomain.Map (LocalAccessPath) (CallSet)
val substitute : f_sub:(LocalAccessPath.t -> LocalAccessPath.t option) -> astate -> astate
(** Substitute each access path in the domain using [f_sub]. If [f_sub] returns None, the
original access path is retained; otherwise, the new one is used *)
val iter_call_chains_with_suffix :
f:(AccessPath.t -> Typ.Procname.t list -> unit) -> MethodCall.t -> astate -> unit
(** Unroll the domain to enumerate all the call chains ending in [call] and apply [f] to each
maximal chain. For example, if the domain encodes the chains foo().bar().goo() and foo().baz(),
[f] will be called once on foo().bar().goo() and once on foo().baz() *)
val iter_call_chains : f:(AccessPath.t -> Typ.Procname.t list -> unit) -> astate -> unit
(** Apply [f] to each maximal call chain encoded in [astate] *)
Loading…
Cancel
Save