[thread-safety] refactor trace generation logic

Summary: Prereq for reporting a call stack for both the read and write in a read/write race.

Reviewed By: peterogithub

Differential Revision: D4845603

fbshipit-source-id: ebfeb9b
master
Sam Blackshear 8 years ago committed by Facebook Github Bot
parent 7cef8ae3b5
commit ede9a31089

@ -15,14 +15,12 @@ module L = Logging
module type S = sig module type S = sig
include Trace.S include Trace.S
(** a path from some procedure via the given passthroughs to the given call stack, with
passthroughs for each callee *)
type sink_path = Passthroughs.t * (Sink.t * Passthroughs.t) list type sink_path = Passthroughs.t * (Sink.t * Passthroughs.t) list
(** get a path for each of the reportable flows to a sink in this trace *)
val get_reportable_sink_paths : t -> trace_of_pname:(Typ.Procname.t -> t) -> sink_path list val get_reportable_sink_paths : t -> trace_of_pname:(Typ.Procname.t -> t) -> sink_path list
(** update sink with the given call site *) val get_reportable_sink_path : Sink.t -> trace_of_pname:(Typ.Procname.t -> t) -> sink_path option
val with_callsite : t -> CallSite.t -> t val with_callsite : t -> CallSite.t -> t
val of_sink : Sink.t -> t val of_sink : Sink.t -> t
@ -72,6 +70,12 @@ module Make (TraceElem : TraceElem.S) = struct
let sinks = Sinks.add sink Sinks.empty in let sinks = Sinks.add sink Sinks.empty in
update_sinks empty sinks update_sinks empty sinks
let get_reportable_sink_path sink ~trace_of_pname =
match get_reportable_sink_paths (of_sink sink) ~trace_of_pname with
| [] -> None
| [report] -> Some report
| _ -> failwithf "Should not get >1 report for 1 sink"
let pp fmt t = let pp fmt t =
let pp_passthroughs_if_not_empty fmt p = let pp_passthroughs_if_not_empty fmt p =
if not (Passthroughs.is_empty p) then if not (Passthroughs.is_empty p) then

@ -20,6 +20,9 @@ module type S = sig
(** get a path for each of the reportable flows to a sink in this trace *) (** get a path for each of the reportable flows to a sink in this trace *)
val get_reportable_sink_paths : t -> trace_of_pname:(Typ.Procname.t -> t) -> sink_path list val get_reportable_sink_paths : t -> trace_of_pname:(Typ.Procname.t -> t) -> sink_path list
(** get a report for a single sink *)
val get_reportable_sink_path : Sink.t -> trace_of_pname:(Typ.Procname.t -> t) -> sink_path option
(** update sink with the given call site *) (** update sink with the given call site *)
val with_callsite : t -> CallSite.t -> t val with_callsite : t -> CallSite.t -> t

@ -902,85 +902,93 @@ let calculate_addendum_message tenv pname =
else "" else ""
| _ -> "" | _ -> ""
(* keep only the accesses of the given kind *) let filter_by_access access_filter trace =
let filter_by_kind access_kind trace =
let open ThreadSafetyDomain in let open ThreadSafetyDomain in
PathDomain.Sinks.filter PathDomain.Sinks.filter access_filter (PathDomain.sinks trace)
(fun sink -> phys_equal access_kind (snd (TraceElem.kind sink)))
(PathDomain.sinks trace)
|> PathDomain.update_sinks trace |> PathDomain.update_sinks trace
let get_all_accesses_with_pre pre_filter access_kind accesses = (* keep only the accesses of the given kind *)
let filter_by_kind access_kind trace =
filter_by_access
(fun sink -> phys_equal access_kind (snd (ThreadSafetyDomain.TraceElem.kind sink)))
trace
let get_all_accesses_with_pre pre_filter access_filter accesses =
let open ThreadSafetyDomain in let open ThreadSafetyDomain in
AccessDomain.fold AccessDomain.fold
(fun pre trace acc -> (fun pre trace acc ->
if pre_filter pre if pre_filter pre
then PathDomain.join (filter_by_kind access_kind trace) acc then PathDomain.join (filter_by_access access_filter trace) acc
else acc) else acc)
accesses accesses
PathDomain.empty PathDomain.empty
(* get all of the unprotected accesses of the given kind *)
let get_unprotected_accesses =
get_all_accesses_with_pre
(function ThreadSafetyDomain.AccessPrecondition.Unprotected _ -> true | Protected -> false)
let get_all_accesses = get_all_accesses_with_pre (fun _ -> true) let get_all_accesses = get_all_accesses_with_pre (fun _ -> true)
let get_possibly_unsafe_reads = get_unprotected_accesses Read let pp_sink fmt sink =
let sink_pname = CallSite.pname (ThreadSafetyDomain.PathDomain.Sink.call_site sink) in
let get_possibly_unsafe_writes = get_unprotected_accesses Write if Typ.Procname.equal sink_pname Typ.Procname.empty_block
then
let access_path, _ = ThreadSafetyDomain.PathDomain.Sink.kind sink in
F.fprintf fmt "access to %a" AccessPath.pp_access_list (snd access_path)
else
F.fprintf fmt "call to %a" Typ.Procname.pp sink_pname
let desc_of_sink final_sink_site sink =
if CallSite.equal (ThreadSafetyDomain.PathDomain.Sink.call_site sink) final_sink_site &&
is_container_write_sink sink
then
let access_path, _ = ThreadSafetyDomain.PathDomain.Sink.kind sink in
F.asprintf
"access to container %a"
AccessPath.pp_access_list (snd (AccessPath.Raw.truncate access_path))
else
F.asprintf "%a" pp_sink sink
let trace_of_pname orig_sink orig_pdesc callee_pname =
let open ThreadSafetyDomain in
let orig_access = PathDomain.Sink.kind orig_sink in
match Summary.read_summary orig_pdesc callee_pname with
| Some (_, _, access_map, _) ->
get_all_accesses
(fun access ->
Int.equal (Access.compare (PathDomain.Sink.kind access) orig_access) 0)
access_map
| _ ->
PathDomain.empty
(*A helper function used in the error reporting*) let make_trace ((_, sinks) as path) =
let pp_accesses_sink fmt ~is_write_access sink = let open ThreadSafetyDomain in
let access_path, _ = ThreadSafetyDomain.PathDomain.Sink.kind sink in let final_sink, _ = List.hd_exn sinks in
let container_write = is_write_access && is_container_write_sink sink in PathDomain.to_sink_loc_trace
F.fprintf fmt ~desc_of_sink:(desc_of_sink (PathDomain.Sink.call_site final_sink))
(if container_write then "container %a" else "%a") path
AccessPath.pp_access_list
(if container_write
then snd (AccessPath.Raw.truncate access_path)
else snd access_path)
(* trace is really a set of accesses*) let report_thread_safety_violation tenv pdesc ~make_description access =
let report_thread_safety_violations tenv pdesc ~get_unsafe_accesses ~make_description trace =
let open ThreadSafetyDomain in let open ThreadSafetyDomain in
let pname = Procdesc.get_proc_name pdesc in let pname = Procdesc.get_proc_name pdesc in
let trace_of_pname callee_pname =
match Summary.read_summary pdesc callee_pname with
| Some (_, _, accesses, _) -> get_unsafe_accesses accesses
| _ -> PathDomain.empty in
let report_one_path ((_, sinks) as path) = let report_one_path ((_, sinks) as path) =
let initial_sink, _ = List.last_exn sinks in let initial_sink, _ = List.last_exn sinks in
let final_sink, _ = List.hd_exn sinks in let final_sink, _ = List.hd_exn sinks in
let initial_sink_site = PathDomain.Sink.call_site initial_sink in let initial_sink_site = PathDomain.Sink.call_site initial_sink in
let final_sink_site = PathDomain.Sink.call_site final_sink in let final_sink_site = PathDomain.Sink.call_site final_sink in
let desc_of_sink sink = let loc = CallSite.loc initial_sink_site in
if let ltr = make_trace path in
CallSite.equal (PathDomain.Sink.call_site sink) final_sink_site
then
Format.asprintf "access to %a" (pp_accesses_sink ~is_write_access:true) sink
else
Format.asprintf
"call to %a" Typ.Procname.pp (CallSite.pname (PathDomain.Sink.call_site sink)) in
let loc = CallSite.loc (PathDomain.Sink.call_site initial_sink) in
let ltr = PathDomain.to_sink_loc_trace ~desc_of_sink path in
let msg = Localise.to_issue_id Localise.thread_safety_violation in let msg = Localise.to_issue_id Localise.thread_safety_violation in
let description = make_description tenv pname final_sink_site initial_sink_site final_sink in let description = make_description tenv pname final_sink_site initial_sink_site final_sink in
let exn = Exceptions.Checkers (msg, Localise.verbatim_desc description) in let exn = Exceptions.Checkers (msg, Localise.verbatim_desc description) in
Reporting.log_error pname ~loc ~ltr exn in Reporting.log_error pname ~loc ~ltr exn in
List.iter let trace_of_pname = trace_of_pname access pdesc in
~f:report_one_path Option.iter ~f:report_one_path (PathDomain.get_reportable_sink_path access ~trace_of_pname)
(PathDomain.get_reportable_sink_paths trace ~trace_of_pname)
let make_unprotected_write_description tenv pname final_sink_site initial_sink_site final_sink = let make_unprotected_write_description tenv pname final_sink_site initial_sink_site final_sink =
Format.asprintf Format.asprintf
"Unprotected write. Public method %a%s %s %a outside of synchronization.%s" "Unprotected write. Public method %a%s %s %a outside of synchronization.%s"
(MF.wrap_monospaced Typ.Procname.pp) pname (MF.wrap_monospaced Typ.Procname.pp) pname
(if CallSite.equal final_sink_site initial_sink_site then "" else " indirectly") (if CallSite.equal final_sink_site initial_sink_site then "" else " indirectly")
(if is_container_write_sink final_sink then "mutates" else "writes to field") (if is_container_write_sink final_sink then "mutates" else "writes to field")
(MF.wrap_monospaced (pp_accesses_sink ~is_write_access:true)) final_sink (MF.wrap_monospaced pp_sink) final_sink
(calculate_addendum_message tenv pname) (calculate_addendum_message tenv pname)
let make_read_write_race_description let make_read_write_race_description
@ -1006,7 +1014,7 @@ let make_read_write_race_description
Format.asprintf "Read/Write race. Public method %a%s reads from field %a. %s %s" Format.asprintf "Read/Write race. Public method %a%s reads from field %a. %s %s"
(MF.wrap_monospaced Typ.Procname.pp) pname (MF.wrap_monospaced Typ.Procname.pp) pname
(if CallSite.equal final_sink_site initial_sink_site then "" else " indirectly") (if CallSite.equal final_sink_site initial_sink_site then "" else " indirectly")
(MF.wrap_monospaced (pp_accesses_sink ~is_write_access:false)) final_sink (MF.wrap_monospaced pp_sink) final_sink
conflicts_description conflicts_description
(calculate_addendum_message tenv pname) (calculate_addendum_message tenv pname)
@ -1069,12 +1077,11 @@ let report_unsafe_accesses ~is_file_threadsafe aggregated_access_map =
else else
begin begin
(* unprotected write. warn. *) (* unprotected write. warn. *)
report_thread_safety_violations report_thread_safety_violation
tenv tenv
pdesc pdesc
~get_unsafe_accesses:get_possibly_unsafe_writes
~make_description:make_unprotected_write_description ~make_description:make_unprotected_write_description
(PathDomain.of_sink access); access;
update_reported access pname reported_acc update_reported access pname reported_acc
end end
| Access.Write, AccessPrecondition.Protected -> | Access.Write, AccessPrecondition.Protected ->
@ -1092,12 +1099,11 @@ let report_unsafe_accesses ~is_file_threadsafe aggregated_access_map =
reported_acc reported_acc
else else
begin begin
report_thread_safety_violations report_thread_safety_violation
tenv tenv
pdesc pdesc
~get_unsafe_accesses:get_possibly_unsafe_reads
~make_description:(make_read_write_race_description all_writes) ~make_description:(make_read_write_race_description all_writes)
(PathDomain.of_sink access); access;
update_reported access pname reported_acc update_reported access pname reported_acc
end end
| Access.Read, AccessPrecondition.Protected -> | Access.Read, AccessPrecondition.Protected ->
@ -1117,12 +1123,11 @@ let report_unsafe_accesses ~is_file_threadsafe aggregated_access_map =
else else
begin begin
(* protected read with conflicting unprotected write(s). warn. *) (* protected read with conflicting unprotected write(s). warn. *)
report_thread_safety_violations report_thread_safety_violation
tenv tenv
pdesc pdesc
~get_unsafe_accesses:(get_all_accesses Read)
~make_description:(make_read_write_race_description unprotected_writes) ~make_description:(make_read_write_race_description unprotected_writes)
(PathDomain.of_sink access); access;
update_reported access pname reported_acc update_reported access pname reported_acc
end in end in
AccessListMap.fold AccessListMap.fold

Loading…
Cancel
Save