diff --git a/infer/src/checkers/SiofTrace.ml b/infer/src/checkers/SiofTrace.ml index df5af7cb9..a6b719c7f 100644 --- a/infer/src/checkers/SiofTrace.ml +++ b/infer/src/checkers/SiofTrace.ml @@ -15,6 +15,8 @@ module L = Logging module GlobalVar = struct include Pvar + let matches ~caller ~callee = Pvar.equal caller callee + let pp fmt v = F.fprintf fmt "%a|%a" Mangled.pp (Pvar.get_name v) Pvar.pp_translation_unit (Pvar.get_translation_unit v) diff --git a/infer/src/checkers/Source.ml b/infer/src/checkers/Source.ml index fcc51d5cb..1d2283e70 100644 --- a/infer/src/checkers/Source.ml +++ b/infer/src/checkers/Source.ml @@ -96,6 +96,8 @@ module Dummy = struct let compare = compare + let matches ~caller ~callee = Int.equal 0 (compare caller callee) + let pp = pp end diff --git a/infer/src/checkers/ThreadSafety.ml b/infer/src/checkers/ThreadSafety.ml index bdb28aee2..898f94688 100644 --- a/infer/src/checkers/ThreadSafety.ml +++ b/infer/src/checkers/ThreadSafety.ml @@ -1108,7 +1108,7 @@ let trace_of_pname orig_sink orig_pdesc callee_pname = match Summary.read_summary orig_pdesc callee_pname with | Some {accesses} -> get_all_accesses - (fun access -> Int.equal (Access.compare (PathDomain.Sink.kind access) orig_access) 0) + (fun access -> Access.matches ~caller:orig_access ~callee:(PathDomain.Sink.kind access)) accesses | _ -> PathDomain.empty diff --git a/infer/src/checkers/ThreadSafetyDomain.ml b/infer/src/checkers/ThreadSafetyDomain.ml index f1090b703..03d06ff14 100644 --- a/infer/src/checkers/ThreadSafetyDomain.ml +++ b/infer/src/checkers/ThreadSafetyDomain.ml @@ -19,6 +19,25 @@ module Access = struct | InterfaceCall of Typ.Procname.t [@@deriving compare] + let suffix_matches (_, accesses1) (_, accesses2) = + match (List.rev accesses1, List.rev accesses2) with + | access1 :: _, access2 :: _ + -> AccessPath.equal_access access1 access2 + | _ + -> false + + let matches ~caller ~callee = + match (caller, callee) with + | Read ap1, Read ap2 | Write ap1, Write ap2 + -> suffix_matches ap1 ap2 + | ContainerRead (ap1, pname1), ContainerRead (ap2, pname2) + | ContainerWrite (ap1, pname1), ContainerWrite (ap2, pname2) + -> Typ.Procname.equal pname1 pname2 && suffix_matches ap1 ap2 + | InterfaceCall pname1, InterfaceCall pname2 + -> Typ.Procname.equal pname1 pname2 + | _ + -> false + let make_field_access access_path ~is_write = if is_write then Write access_path else Read access_path diff --git a/infer/src/checkers/ThreadSafetyDomain.mli b/infer/src/checkers/ThreadSafetyDomain.mli index f08c9df94..948fdd3cc 100644 --- a/infer/src/checkers/ThreadSafetyDomain.mli +++ b/infer/src/checkers/ThreadSafetyDomain.mli @@ -20,6 +20,10 @@ module Access : sig (** Call to method of interface not annotated with @ThreadSafe *) [@@deriving compare] + val matches : caller:t -> callee:t -> bool + (** returns true if the caller access matches the callee access after accounting for mismatch + between the formals and actuals *) + val get_access_path : t -> AccessPath.t option val equal : t -> t -> bool diff --git a/infer/src/checkers/Trace.ml b/infer/src/checkers/Trace.ml index a0689ebf6..94a9d1974 100644 --- a/infer/src/checkers/Trace.ml +++ b/infer/src/checkers/Trace.ml @@ -128,10 +128,10 @@ end module Expander (TraceElem : TraceElem.S) = struct let expand elem0 ~elems_passthroughs_of_pname ~filter_passthroughs = let rec expand_ elem (elems_passthroughs_acc, seen_acc) = - let elem_site = TraceElem.call_site elem in - let elem_kind = TraceElem.kind elem in - let seen_acc' = CallSite.Set.add elem_site seen_acc in - let elems, passthroughs = elems_passthroughs_of_pname (CallSite.pname elem_site) in + let caller_elem_site = TraceElem.call_site elem in + let caller_elem_kind = TraceElem.kind elem in + let seen_acc' = CallSite.Set.add caller_elem_site seen_acc in + let elems, passthroughs = elems_passthroughs_of_pname (CallSite.pname caller_elem_site) in let is_recursive callee_elem seen = CallSite.Set.mem (TraceElem.call_site callee_elem) seen in @@ -139,7 +139,7 @@ module Expander (TraceElem : TraceElem.S) = struct let matching_elems = List.filter ~f:(fun callee_elem -> - [%compare.equal : TraceElem.Kind.t] (TraceElem.kind callee_elem) elem_kind + TraceElem.Kind.matches ~caller:caller_elem_kind ~callee:(TraceElem.kind callee_elem) && not (is_recursive callee_elem seen_acc')) elems in @@ -148,7 +148,7 @@ module Expander (TraceElem : TraceElem.S) = struct | callee_elem :: _ -> (* TODO: pick the shortest path to a sink here instead (t14242809) *) let filtered_passthroughs = - filter_passthroughs elem_site (TraceElem.call_site callee_elem) passthroughs + filter_passthroughs caller_elem_site (TraceElem.call_site callee_elem) passthroughs in expand_ callee_elem ((elem, filtered_passthroughs) :: elems_passthroughs_acc, seen_acc') | _ diff --git a/infer/src/checkers/TraceElem.ml b/infer/src/checkers/TraceElem.ml index f223b139b..adc4e4dac 100644 --- a/infer/src/checkers/TraceElem.ml +++ b/infer/src/checkers/TraceElem.ml @@ -13,6 +13,12 @@ module F = Format module type Kind = sig type t [@@deriving compare] + val matches : caller:t -> callee:t -> bool + (** Return true if the [caller] element kind matches the [callee] element kind. Used during trace + expansion; we will only consider expanding the trace from caller into callee if this + evaluates to true. This can normally just be [equal], but something fuzzier may be required + if [t] is a type that contains identifiers from the caller/callee *) + val pp : F.formatter -> t -> unit end diff --git a/infer/src/quandary/ClangTrace.ml b/infer/src/quandary/ClangTrace.ml index c2a731001..4552d7e46 100644 --- a/infer/src/quandary/ClangTrace.ml +++ b/infer/src/quandary/ClangTrace.ml @@ -20,6 +20,8 @@ module SourceKind = struct | Other (** for testing or uncategorized sources *) [@@deriving compare] + let matches ~caller ~callee = Int.equal 0 (compare caller callee) + let of_string = function | "CommandLineFlag" -> L.die UserError "User-specified CommandLineFlag sources are not supported" @@ -146,6 +148,8 @@ module SinkKind = struct | Other (** for testing or uncategorized sinks *) [@@deriving compare] + let matches ~caller ~callee = Int.equal 0 (compare caller callee) + let of_string = function | "Allocation" -> Allocation diff --git a/infer/src/quandary/JavaTrace.ml b/infer/src/quandary/JavaTrace.ml index 71bc3ea68..1587fbb32 100644 --- a/infer/src/quandary/JavaTrace.ml +++ b/infer/src/quandary/JavaTrace.ml @@ -20,6 +20,8 @@ module SourceKind = struct | UserControlledURI (** resource locator from the browser bar *) [@@deriving compare] + let matches ~caller ~callee = Int.equal 0 (compare caller callee) + let of_string = function | "Intent" -> Intent @@ -182,6 +184,8 @@ module SinkKind = struct | Other (** for testing or uncategorized sinks *) [@@deriving compare] + let matches ~caller ~callee = Int.equal 0 (compare caller callee) + let of_string = function | "CreateFile" -> CreateFile diff --git a/infer/src/unit/TaintTests.ml b/infer/src/unit/TaintTests.ml index 91fe34eb6..d0021c98c 100644 --- a/infer/src/unit/TaintTests.ml +++ b/infer/src/unit/TaintTests.ml @@ -11,7 +11,11 @@ open! IStd module F = Format module MockTrace = Trace.Make (struct - module MockTraceElem = CallSite + module MockTraceElem = struct + include CallSite + + let matches ~caller ~callee = equal caller callee + end module Source = Source.Make (struct include MockTraceElem diff --git a/infer/src/unit/TraceTests.ml b/infer/src/unit/TraceTests.ml index e91da2997..c40e420e1 100644 --- a/infer/src/unit/TraceTests.ml +++ b/infer/src/unit/TraceTests.ml @@ -14,6 +14,8 @@ module F = Format module MockTraceElem = struct type t = Kind1 | Kind2 | Footprint [@@deriving compare] + let matches ~caller ~callee = Int.equal 0 (compare caller callee) + let call_site _ = CallSite.dummy let kind t = t @@ -33,6 +35,8 @@ module MockTraceElem = struct let compare = compare + let matches = matches + let pp = pp end