diff --git a/infer/src/checkers/Litho.ml b/infer/src/checkers/Litho.ml index 797c83b42..07c7fdf1b 100644 --- a/infer/src/checkers/Litho.ml +++ b/infer/src/checkers/Litho.ml @@ -90,7 +90,7 @@ module TransferFunctions (CFG : ProcCfg.S) = struct let return_access_path = Domain.LocalAccessPath.make (return_base, []) caller_pname in let return_calls = (try Domain.find return_access_path astate with Not_found -> Domain.CallSet.empty) - |> Domain.CallSet.add {receiver; procname= callee_procname} + |> Domain.CallSet.add (Domain.MethodCall.make receiver callee_procname) in Domain.add return_access_path return_calls astate else @@ -126,27 +126,6 @@ let report access_path call_chain summary = Reporting.log_error summary ~loc ~ltr exn -let unroll_call call astate summary = - let max_depth = Domain.cardinal astate in - let rec unroll_call_ ({receiver; procname}: Domain.MethodCall.t) (acc, depth) = - let acc' = procname :: acc in - let depth' = depth + 1 in - let is_cycle access_path = - (* detect direct cycles and cycles due to mutual recursion *) - Domain.LocalAccessPath.equal access_path receiver || depth' > max_depth - in - try - let calls' = Domain.find receiver astate in - Domain.CallSet.iter - (fun call -> - if not (is_cycle call.receiver) then unroll_call_ call (acc', depth') - else report receiver.access_path acc' summary ) - calls' - with Not_found -> report receiver.access_path acc' summary - in - unroll_call_ call ([], 0) - - let should_report proc_desc = match Procdesc.get_proc_name proc_desc with | Typ.Procname.Java java_pname -> ( @@ -155,10 +134,14 @@ let should_report proc_desc = false -let report_call_chains post summary = - Domain.iter - (fun _ call_set -> Domain.CallSet.iter (fun call -> unroll_call call post summary) call_set) - post +let report_graphql_getters summary access_path call_chain = + let call_strings = List.map ~f:(Typ.Procname.to_simplified_string ~withclass:false) call_chain in + let call_string = String.concat ~sep:"." call_strings in + let message = F.asprintf "%a.%s" AccessPath.pp access_path call_string in + let exn = Exceptions.Checkers (IssueType.graphql_field_access, Localise.verbatim_desc message) in + let loc = Specs.get_loc summary in + let ltr = [Errlog.make_trace_element 0 loc message []] in + Reporting.log_error summary ~loc ~ltr exn let postprocess astate proc_desc : Domain.astate = @@ -171,7 +154,9 @@ let checker {Callbacks.summary; proc_desc; tenv} = let proc_data = ProcData.make_default proc_desc tenv in match Analyzer.compute_post proc_data ~initial:Domain.empty with | Some post -> - if should_report proc_desc then report_call_chains post summary ; + ( if should_report proc_desc then + let f = report_graphql_getters summary in + Domain.iter_call_chains ~f post ) ; let payload = postprocess post proc_desc in Summary.update_summary payload summary | None -> diff --git a/infer/src/checkers/LithoDomain.ml b/infer/src/checkers/LithoDomain.ml index 6606f9c05..e494ff3b6 100644 --- a/infer/src/checkers/LithoDomain.ml +++ b/infer/src/checkers/LithoDomain.ml @@ -11,7 +11,6 @@ open! IStd module F = Format module L = Logging -(** Access path + its parent procedure *) module LocalAccessPath = struct type t = {access_path: AccessPath.t; parent: Typ.Procname.t} [@@deriving compare] @@ -30,10 +29,11 @@ module LocalAccessPath = struct let pp fmt t = AccessPath.pp fmt t.access_path end -(** Called procedure + it's receiver *) module MethodCall = struct type t = {receiver: LocalAccessPath.t; procname: Typ.Procname.t} [@@deriving compare] + let make receiver procname = {receiver; procname} + let pp fmt {receiver; procname} = F.fprintf fmt "%a.%a" LocalAccessPath.pp receiver Typ.Procname.pp procname end @@ -62,3 +62,31 @@ let substitute ~(f_sub: LocalAccessPath.t -> LocalAccessPath.t option) astate = in add access_path' call_set' acc ) astate empty + + +let iter_call_chains_with_suffix ~f call_suffix astate = + let max_depth = cardinal astate in + let rec unroll_call_ ({receiver; procname}: MethodCall.t) (acc, depth) = + let acc' = procname :: acc in + let depth' = depth + 1 in + let is_cycle access_path = + (* detect direct cycles and cycles due to mutual recursion *) + LocalAccessPath.equal access_path receiver || depth' > max_depth + in + try + let calls' = find receiver astate in + CallSet.iter + (fun call -> + if not (is_cycle call.receiver) then unroll_call_ call (acc', depth') + else f receiver.access_path acc' ) + calls' + with Not_found -> f receiver.access_path acc' + in + unroll_call_ call_suffix ([], 0) + + +let iter_call_chains ~f astate = + iter + (fun _ call_set -> + CallSet.iter (fun call -> iter_call_chains_with_suffix ~f call astate) call_set ) + astate diff --git a/infer/src/checkers/LithoDomain.mli b/infer/src/checkers/LithoDomain.mli new file mode 100644 index 000000000..fef9d06b9 --- /dev/null +++ b/infer/src/checkers/LithoDomain.mli @@ -0,0 +1,48 @@ +(* + * Copyright (c) 2018 - present Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + *) + +open! IStd +module F = Format + +(** Access path + its parent procedure *) +module LocalAccessPath : sig + type t = private {access_path: AccessPath.t; parent: Typ.Procname.t} [@@deriving compare] + + val make : AccessPath.t -> Typ.Procname.t -> t + + val to_formal_option : t -> FormalMap.t -> t option + + val pp : F.formatter -> t -> unit +end + +(** Called procedure + its receiver *) +module MethodCall : sig + type t = private {receiver: LocalAccessPath.t; procname: Typ.Procname.t} [@@deriving compare] + + val make : LocalAccessPath.t -> Typ.Procname.t -> t + + val pp : F.formatter -> t -> unit +end + +module CallSet : module type of AbstractDomain.FiniteSet (MethodCall) + +include module type of AbstractDomain.Map (LocalAccessPath) (CallSet) + +val substitute : f_sub:(LocalAccessPath.t -> LocalAccessPath.t option) -> astate -> astate +(** Substitute each access path in the domain using [f_sub]. If [f_sub] returns None, the + original access path is retained; otherwise, the new one is used *) + +val iter_call_chains_with_suffix : + f:(AccessPath.t -> Typ.Procname.t list -> unit) -> MethodCall.t -> astate -> unit +(** Unroll the domain to enumerate all the call chains ending in [call] and apply [f] to each + maximal chain. For example, if the domain encodes the chains foo().bar().goo() and foo().baz(), + [f] will be called once on foo().bar().goo() and once on foo().baz() *) + +val iter_call_chains : f:(AccessPath.t -> Typ.Procname.t list -> unit) -> astate -> unit +(** Apply [f] to each maximal call chain encoded in [astate] *)