diff --git a/infer/src/checkers/SinkTrace.ml b/infer/src/checkers/SinkTrace.ml index c51f3b063..c59aab475 100644 --- a/infer/src/checkers/SinkTrace.ml +++ b/infer/src/checkers/SinkTrace.ml @@ -15,14 +15,12 @@ module L = Logging module type S = sig include Trace.S - (** a path from some procedure via the given passthroughs to the given call stack, with - passthroughs for each callee *) type sink_path = Passthroughs.t * (Sink.t * Passthroughs.t) list - (** get a path for each of the reportable flows to a sink in this trace *) val get_reportable_sink_paths : t -> trace_of_pname:(Typ.Procname.t -> t) -> sink_path list - (** update sink with the given call site *) + val get_reportable_sink_path : Sink.t -> trace_of_pname:(Typ.Procname.t -> t) -> sink_path option + val with_callsite : t -> CallSite.t -> t val of_sink : Sink.t -> t @@ -72,6 +70,12 @@ module Make (TraceElem : TraceElem.S) = struct let sinks = Sinks.add sink Sinks.empty in update_sinks empty sinks + let get_reportable_sink_path sink ~trace_of_pname = + match get_reportable_sink_paths (of_sink sink) ~trace_of_pname with + | [] -> None + | [report] -> Some report + | _ -> failwithf "Should not get >1 report for 1 sink" + let pp fmt t = let pp_passthroughs_if_not_empty fmt p = if not (Passthroughs.is_empty p) then diff --git a/infer/src/checkers/SinkTrace.mli b/infer/src/checkers/SinkTrace.mli index 05ea5870f..5e6b468cc 100644 --- a/infer/src/checkers/SinkTrace.mli +++ b/infer/src/checkers/SinkTrace.mli @@ -20,6 +20,9 @@ module type S = sig (** get a path for each of the reportable flows to a sink in this trace *) val get_reportable_sink_paths : t -> trace_of_pname:(Typ.Procname.t -> t) -> sink_path list + (** get a report for a single sink *) + val get_reportable_sink_path : Sink.t -> trace_of_pname:(Typ.Procname.t -> t) -> sink_path option + (** update sink with the given call site *) val with_callsite : t -> CallSite.t -> t diff --git a/infer/src/checkers/ThreadSafety.ml b/infer/src/checkers/ThreadSafety.ml index 55f203ba7..386a3a5c0 100644 --- a/infer/src/checkers/ThreadSafety.ml +++ b/infer/src/checkers/ThreadSafety.ml @@ -902,85 +902,93 @@ let calculate_addendum_message tenv pname = else "" | _ -> "" -(* keep only the accesses of the given kind *) -let filter_by_kind access_kind trace = +let filter_by_access access_filter trace = let open ThreadSafetyDomain in - PathDomain.Sinks.filter - (fun sink -> phys_equal access_kind (snd (TraceElem.kind sink))) - (PathDomain.sinks trace) + PathDomain.Sinks.filter access_filter (PathDomain.sinks trace) |> PathDomain.update_sinks trace -let get_all_accesses_with_pre pre_filter access_kind accesses = +(* keep only the accesses of the given kind *) +let filter_by_kind access_kind trace = + filter_by_access + (fun sink -> phys_equal access_kind (snd (ThreadSafetyDomain.TraceElem.kind sink))) + trace + +let get_all_accesses_with_pre pre_filter access_filter accesses = let open ThreadSafetyDomain in AccessDomain.fold (fun pre trace acc -> if pre_filter pre - then PathDomain.join (filter_by_kind access_kind trace) acc + then PathDomain.join (filter_by_access access_filter trace) acc else acc) accesses PathDomain.empty -(* get all of the unprotected accesses of the given kind *) -let get_unprotected_accesses = - get_all_accesses_with_pre - (function ThreadSafetyDomain.AccessPrecondition.Unprotected _ -> true | Protected -> false) - let get_all_accesses = get_all_accesses_with_pre (fun _ -> true) -let get_possibly_unsafe_reads = get_unprotected_accesses Read - -let get_possibly_unsafe_writes = get_unprotected_accesses Write +let pp_sink fmt sink = + let sink_pname = CallSite.pname (ThreadSafetyDomain.PathDomain.Sink.call_site sink) in + if Typ.Procname.equal sink_pname Typ.Procname.empty_block + then + let access_path, _ = ThreadSafetyDomain.PathDomain.Sink.kind sink in + F.fprintf fmt "access to %a" AccessPath.pp_access_list (snd access_path) + else + F.fprintf fmt "call to %a" Typ.Procname.pp sink_pname + +let desc_of_sink final_sink_site sink = + if CallSite.equal (ThreadSafetyDomain.PathDomain.Sink.call_site sink) final_sink_site && + is_container_write_sink sink + then + let access_path, _ = ThreadSafetyDomain.PathDomain.Sink.kind sink in + F.asprintf + "access to container %a" + AccessPath.pp_access_list (snd (AccessPath.Raw.truncate access_path)) + else + F.asprintf "%a" pp_sink sink + +let trace_of_pname orig_sink orig_pdesc callee_pname = + let open ThreadSafetyDomain in + let orig_access = PathDomain.Sink.kind orig_sink in + match Summary.read_summary orig_pdesc callee_pname with + | Some (_, _, access_map, _) -> + get_all_accesses + (fun access -> + Int.equal (Access.compare (PathDomain.Sink.kind access) orig_access) 0) + access_map + | _ -> + PathDomain.empty -(*A helper function used in the error reporting*) -let pp_accesses_sink fmt ~is_write_access sink = - let access_path, _ = ThreadSafetyDomain.PathDomain.Sink.kind sink in - let container_write = is_write_access && is_container_write_sink sink in - F.fprintf fmt - (if container_write then "container %a" else "%a") - AccessPath.pp_access_list - (if container_write - then snd (AccessPath.Raw.truncate access_path) - else snd access_path) +let make_trace ((_, sinks) as path) = + let open ThreadSafetyDomain in + let final_sink, _ = List.hd_exn sinks in + PathDomain.to_sink_loc_trace + ~desc_of_sink:(desc_of_sink (PathDomain.Sink.call_site final_sink)) + path -(* trace is really a set of accesses*) -let report_thread_safety_violations tenv pdesc ~get_unsafe_accesses ~make_description trace = +let report_thread_safety_violation tenv pdesc ~make_description access = let open ThreadSafetyDomain in let pname = Procdesc.get_proc_name pdesc in - let trace_of_pname callee_pname = - match Summary.read_summary pdesc callee_pname with - | Some (_, _, accesses, _) -> get_unsafe_accesses accesses - | _ -> PathDomain.empty in let report_one_path ((_, sinks) as path) = let initial_sink, _ = List.last_exn sinks in let final_sink, _ = List.hd_exn sinks in let initial_sink_site = PathDomain.Sink.call_site initial_sink in let final_sink_site = PathDomain.Sink.call_site final_sink in - let desc_of_sink sink = - if - CallSite.equal (PathDomain.Sink.call_site sink) final_sink_site - then - Format.asprintf "access to %a" (pp_accesses_sink ~is_write_access:true) sink - else - Format.asprintf - "call to %a" Typ.Procname.pp (CallSite.pname (PathDomain.Sink.call_site sink)) in - let loc = CallSite.loc (PathDomain.Sink.call_site initial_sink) in - let ltr = PathDomain.to_sink_loc_trace ~desc_of_sink path in + let loc = CallSite.loc initial_sink_site in + let ltr = make_trace path in let msg = Localise.to_issue_id Localise.thread_safety_violation in let description = make_description tenv pname final_sink_site initial_sink_site final_sink in let exn = Exceptions.Checkers (msg, Localise.verbatim_desc description) in Reporting.log_error pname ~loc ~ltr exn in - List.iter - ~f:report_one_path - (PathDomain.get_reportable_sink_paths trace ~trace_of_pname) + let trace_of_pname = trace_of_pname access pdesc in + Option.iter ~f:report_one_path (PathDomain.get_reportable_sink_path access ~trace_of_pname) let make_unprotected_write_description tenv pname final_sink_site initial_sink_site final_sink = Format.asprintf "Unprotected write. Public method %a%s %s %a outside of synchronization.%s" (MF.wrap_monospaced Typ.Procname.pp) pname (if CallSite.equal final_sink_site initial_sink_site then "" else " indirectly") - (if is_container_write_sink final_sink then "mutates" else "writes to field") - (MF.wrap_monospaced (pp_accesses_sink ~is_write_access:true)) final_sink + (if is_container_write_sink final_sink then "mutates" else "writes to field") + (MF.wrap_monospaced pp_sink) final_sink (calculate_addendum_message tenv pname) let make_read_write_race_description @@ -1006,7 +1014,7 @@ let make_read_write_race_description Format.asprintf "Read/Write race. Public method %a%s reads from field %a. %s %s" (MF.wrap_monospaced Typ.Procname.pp) pname (if CallSite.equal final_sink_site initial_sink_site then "" else " indirectly") - (MF.wrap_monospaced (pp_accesses_sink ~is_write_access:false)) final_sink + (MF.wrap_monospaced pp_sink) final_sink conflicts_description (calculate_addendum_message tenv pname) @@ -1069,12 +1077,11 @@ let report_unsafe_accesses ~is_file_threadsafe aggregated_access_map = else begin (* unprotected write. warn. *) - report_thread_safety_violations + report_thread_safety_violation tenv pdesc - ~get_unsafe_accesses:get_possibly_unsafe_writes ~make_description:make_unprotected_write_description - (PathDomain.of_sink access); + access; update_reported access pname reported_acc end | Access.Write, AccessPrecondition.Protected -> @@ -1092,12 +1099,11 @@ let report_unsafe_accesses ~is_file_threadsafe aggregated_access_map = reported_acc else begin - report_thread_safety_violations + report_thread_safety_violation tenv pdesc - ~get_unsafe_accesses:get_possibly_unsafe_reads ~make_description:(make_read_write_race_description all_writes) - (PathDomain.of_sink access); + access; update_reported access pname reported_acc end | Access.Read, AccessPrecondition.Protected -> @@ -1117,12 +1123,11 @@ let report_unsafe_accesses ~is_file_threadsafe aggregated_access_map = else begin (* protected read with conflicting unprotected write(s). warn. *) - report_thread_safety_violations + report_thread_safety_violation tenv pdesc - ~get_unsafe_accesses:(get_all_accesses Read) ~make_description:(make_read_write_race_description unprotected_writes) - (PathDomain.of_sink access); + access; update_reported access pname reported_acc end in AccessListMap.fold