[thread-safety] refactor ThreadSafetyDomain.Access to make it easier to add new access kinds

Summary: Making it simple to add a new access type for "un-annotated interface call" in an upcoming diff.

Reviewed By: da319

Differential Revision: D5445914

fbshipit-source-id: f29e342
master
Sam Blackshear 8 years ago committed by Facebook Github Bot
parent 30e1f4295b
commit 03120a337e

@ -29,23 +29,28 @@ let container_write_string = "infer.dummy.__CONTAINERWRITE__"
(* return (name of container, name of mutating function call) pair *) (* return (name of container, name of mutating function call) pair *)
let get_container_write_desc sink = let get_container_write_desc sink =
let (base_var, _), access_list = fst (ThreadSafetyDomain.TraceElem.kind sink) in match ThreadSafetyDomain.TraceElem.kind sink with
let get_container_write_desc_ call_name container_name = | Write ((base_var, _), access_list)
match -> (
String.chop_prefix (Typ.Fieldname.to_string call_name) ~prefix:container_write_string let get_container_write_desc_ call_name container_name =
with match
| Some call_name String.chop_prefix (Typ.Fieldname.to_string call_name) ~prefix:container_write_string
-> Some (container_name, call_name) with
| None | Some call_name
-> None -> Some (container_name, call_name)
in | None
match List.rev access_list with -> None
| (FieldAccess call_name) :: (FieldAccess container_name) :: _ in
-> get_container_write_desc_ call_name (Typ.Fieldname.to_string container_name) match List.rev access_list with
| [(FieldAccess call_name)] | (FieldAccess call_name) :: (FieldAccess container_name) :: _
-> get_container_write_desc_ call_name (F.asprintf "%a" Var.pp base_var) -> get_container_write_desc_ call_name (Typ.Fieldname.to_string container_name)
| _ | [(FieldAccess call_name)]
-> None -> get_container_write_desc_ call_name (F.asprintf "%a" Var.pp base_var)
| _
-> None )
| Read _
-> (* TODO: support Read *)
None
let is_container_write_sink sink = Option.is_some (get_container_write_desc sink) let is_container_write_sink sink = Option.is_some (get_container_write_desc sink)
@ -259,7 +264,7 @@ module TransferFunctions (CFG : ProcCfg.S) = struct
let is_unprotected is_locked is_threaded pdesc = let is_unprotected is_locked is_threaded pdesc =
not is_locked && not is_threaded && not (Procdesc.is_java_synchronized pdesc) not is_locked && not is_threaded && not (Procdesc.is_java_synchronized pdesc)
let add_access exp loc access_kind accesses locks threads attribute_map let add_access exp loc ~is_write_access accesses locks threads attribute_map
(proc_data: FormalMap.t ProcData.t) = (proc_data: FormalMap.t ProcData.t) =
let open Domain in let open Domain in
(* we don't want to warn on accesses to the field if it is (a) thread-confined, or (* we don't want to warn on accesses to the field if it is (a) thread-confined, or
@ -282,9 +287,7 @@ module TransferFunctions (CFG : ProcCfg.S) = struct
| [] | []
-> access_acc -> access_acc
| access :: access_list' | access :: access_list'
-> let kind = -> let is_write = if List.is_empty access_list' then is_write_access else false in
if List.is_empty access_list' then access_kind else ThreadSafetyDomain.Access.Read
in
let access_path = (fst prefix_path, snd prefix_path @ [access]) in let access_path = (fst prefix_path, snd prefix_path @ [access]) in
let access_acc' = let access_acc' =
if is_owned prefix_path attribute_map if is_owned prefix_path attribute_map
@ -293,7 +296,9 @@ module TransferFunctions (CFG : ProcCfg.S) = struct
else else
(* TODO: I think there's a utility function for this somewhere *) (* TODO: I think there's a utility function for this somewhere *)
let accesses = AccessDomain.get_accesses pre access_acc in let accesses = AccessDomain.get_accesses pre access_acc in
let accesses' = PathDomain.add_sink (make_access access_path kind loc) accesses in let accesses' =
PathDomain.add_sink (make_access access_path ~is_write loc) accesses
in
AccessDomain.add pre accesses' access_acc AccessDomain.add pre accesses' access_acc
in in
add_field_accesses pre access_path access_acc' access_list' add_field_accesses pre access_path access_acc' access_list'
@ -467,7 +472,7 @@ module TransferFunctions (CFG : ProcCfg.S) = struct
(fst receiver_ap, snd receiver_ap @ [AccessPath.FieldAccess dummy_fieldname]) (fst receiver_ap, snd receiver_ap @ [AccessPath.FieldAccess dummy_fieldname])
in in
AccessDomain.add_access (Unprotected (Some 0)) AccessDomain.add_access (Unprotected (Some 0))
(make_access dummy_access_ap Write callee_loc) AccessDomain.empty (make_access dummy_access_ap ~is_write:true callee_loc) AccessDomain.empty
in in
(* TODO: for now all formals escape *) (* TODO: for now all formals escape *)
(* we need a more intelligent escape analysis, that branches on whether (* we need a more intelligent escape analysis, that branches on whether
@ -505,7 +510,8 @@ module TransferFunctions (CFG : ProcCfg.S) = struct
let add_reads exps loc accesses locks threads attribute_map proc_data = let add_reads exps loc accesses locks threads attribute_map proc_data =
List.fold List.fold
~f:(fun acc exp -> add_access exp loc Read acc locks threads attribute_map proc_data) ~f:(fun acc exp ->
add_access exp loc ~is_write_access:false acc locks threads attribute_map proc_data)
exps ~init:accesses exps ~init:accesses
let add_escapees_from_exp rhs_exp extras escapees = let add_escapees_from_exp rhs_exp extras escapees =
@ -726,7 +732,7 @@ module TransferFunctions (CFG : ProcCfg.S) = struct
-> astate_callee ) -> astate_callee )
| Assign (lhs_access_path, rhs_exp, loc) | Assign (lhs_access_path, rhs_exp, loc)
-> let rhs_accesses = -> let rhs_accesses =
add_access rhs_exp loc Read astate.accesses astate.locks astate.threads add_access rhs_exp loc ~is_write_access:false astate.accesses astate.locks astate.threads
astate.attribute_map proc_data astate.attribute_map proc_data
in in
let rhs_access_paths = HilExp.get_access_paths rhs_exp in let rhs_access_paths = HilExp.get_access_paths rhs_exp in
@ -743,8 +749,8 @@ module TransferFunctions (CFG : ProcCfg.S) = struct
report spurious read/write races *) report spurious read/write races *)
rhs_accesses rhs_accesses
else else
add_access (AccessPath lhs_access_path) loc Write rhs_accesses astate.locks add_access (AccessPath lhs_access_path) loc ~is_write_access:true rhs_accesses
astate.threads astate.attribute_map proc_data astate.locks astate.threads astate.attribute_map proc_data
in in
let attribute_map = let attribute_map =
propagate_attributes lhs_access_path rhs_exp astate.attribute_map extras propagate_attributes lhs_access_path rhs_exp astate.attribute_map extras
@ -794,8 +800,8 @@ module TransferFunctions (CFG : ProcCfg.S) = struct
{acc with threads} {acc with threads}
in in
let accesses = let accesses =
add_access assume_exp loc Read astate.accesses astate.locks astate.threads add_access assume_exp loc ~is_write_access:false astate.accesses astate.locks
astate.attribute_map proc_data astate.threads astate.attribute_map proc_data
in in
let astate' = let astate' =
match HilExp.get_access_paths assume_exp with match HilExp.get_access_paths assume_exp with
@ -993,11 +999,12 @@ let analyze_procedure {Callbacks.proc_desc; tenv; summary} =
else Summary.update_summary empty_post summary else Summary.update_summary empty_post summary
module AccessListMap = Caml.Map.Make (struct module AccessListMap = Caml.Map.Make (struct
type t = AccessPath.Raw.t type t = AccessPath.Raw.t option
(* TODO -- keep this compare to satisfy the order of tests, consider using Raw.compare *) (* TODO -- keep this compare to satisfy the order of tests, consider using Raw.compare *)
let compare access_path1 access_path2 = let compare =
List.compare AccessPath.compare_access (snd access_path1) (snd access_path2) Option.compare (fun access_path1 access_path2 ->
List.compare AccessPath.compare_access (snd access_path1) (snd access_path2) )
end) end)
let get_current_class_and_threadsafe_superclasses tenv pname = let get_current_class_and_threadsafe_superclasses tenv pname =
@ -1051,7 +1058,10 @@ let pp_access fmt sink =
| Some container_write_desc | Some container_write_desc
-> pp_container_access fmt container_write_desc -> pp_container_access fmt container_write_desc
| None | None
-> let access_path, _ = ThreadSafetyDomain.PathDomain.Sink.kind sink in -> let access_path =
match ThreadSafetyDomain.PathDomain.Sink.kind sink
with Read access_path | Write access_path -> access_path
in
F.fprintf fmt "%a" (MF.wrap_monospaced AccessPath.pp_access_list) (snd access_path) F.fprintf fmt "%a" (MF.wrap_monospaced AccessPath.pp_access_list) (snd access_path)
let desc_of_sink sink = let desc_of_sink sink =
@ -1216,19 +1226,19 @@ let report_unsafe_accesses aggregated_access_map =
let is_duplicate_report access pname {reported_sites; reported_writes; reported_reads} = let is_duplicate_report access pname {reported_sites; reported_writes; reported_reads} =
CallSite.Set.mem (TraceElem.call_site access) reported_sites CallSite.Set.mem (TraceElem.call_site access) reported_sites
|| Typ.Procname.Set.mem pname || Typ.Procname.Set.mem pname
( match snd (TraceElem.kind access) with ( match TraceElem.kind access with
| Access.Write | Access.Write _
-> reported_writes -> reported_writes
| Access.Read | Access.Read _
-> reported_reads ) -> reported_reads )
in in
let update_reported access pname reported = let update_reported access pname reported =
let reported_sites = CallSite.Set.add (TraceElem.call_site access) reported.reported_sites in let reported_sites = CallSite.Set.add (TraceElem.call_site access) reported.reported_sites in
match snd (TraceElem.kind access) with match TraceElem.kind access with
| Access.Write | Access.Write _
-> let reported_writes = Typ.Procname.Set.add pname reported.reported_writes in -> let reported_writes = Typ.Procname.Set.add pname reported.reported_writes in
{reported with reported_writes; reported_sites} {reported with reported_writes; reported_sites}
| Access.Read | Access.Read _
-> let reported_reads = Typ.Procname.Set.add pname reported.reported_reads in -> let reported_reads = Typ.Procname.Set.add pname reported.reported_reads in
{reported with reported_reads; reported_sites} {reported with reported_reads; reported_sites}
in in
@ -1236,8 +1246,8 @@ let report_unsafe_accesses aggregated_access_map =
let pname = Procdesc.get_proc_name pdesc in let pname = Procdesc.get_proc_name pdesc in
if is_duplicate_report access pname reported_acc then reported_acc if is_duplicate_report access pname reported_acc then reported_acc
else else
match (snd (TraceElem.kind access), pre) with match (TraceElem.kind access, pre) with
| Access.Write, AccessPrecondition.Unprotected _ -> ( | Access.Write _, AccessPrecondition.Unprotected _ -> (
match Procdesc.get_proc_name pdesc with match Procdesc.get_proc_name pdesc with
| Java _ | Java _
-> if threaded then reported_acc -> if threaded then reported_acc
@ -1249,10 +1259,10 @@ let report_unsafe_accesses aggregated_access_map =
| _ | _
-> (* Do not report unprotected writes for ObjC_Cpp *) -> (* Do not report unprotected writes for ObjC_Cpp *)
reported_acc ) reported_acc )
| Access.Write, AccessPrecondition.Protected _ | Access.Write _, AccessPrecondition.Protected _
-> (* protected write, do nothing *) -> (* protected write, do nothing *)
reported_acc reported_acc
| Access.Read, AccessPrecondition.Unprotected _ | Access.Read _, AccessPrecondition.Unprotected _
-> (* unprotected read. report all writes as conflicts for java *) -> (* unprotected read. report all writes as conflicts for java *)
(* for c++ filter out unprotected writes *) (* for c++ filter out unprotected writes *)
let is_cpp_protected_write pre = let is_cpp_protected_write pre =
@ -1276,7 +1286,7 @@ let report_unsafe_accesses aggregated_access_map =
~conflicts:(List.map ~f:(fun (access, _, _, _, _) -> access) all_writes) ~conflicts:(List.map ~f:(fun (access, _, _, _, _) -> access) all_writes)
access ; access ;
update_reported access pname reported_acc ) update_reported access pname reported_acc )
| Access.Read, AccessPrecondition.Protected excl | Access.Read _, AccessPrecondition.Protected excl
-> (* protected read. -> (* protected read.
report unprotected writes and opposite protected writes as conflicts report unprotected writes and opposite protected writes as conflicts
Thread and Lock are opposites of one another, and Thread and Lock are opposites of one another, and
@ -1403,16 +1413,31 @@ let quotient_access_map acc_map =
let rec aux acc m = let rec aux acc m =
if AccessListMap.is_empty m then acc if AccessListMap.is_empty m then acc
else else
let k, vals = AccessListMap.choose m in let k_opt, vals = AccessListMap.choose m in
let _, _, _, tenv, _ = let _, _, _, tenv, _ =
List.find_exn vals ~f:(fun (elem, _, _, _, _) -> List.find_exn vals ~f:(fun (elem, _, _, _, _) ->
AccessPath.Raw.equal k (ThreadSafetyDomain.TraceElem.kind elem |> fst) ) Option.equal
(fun e1 e2 -> AccessPath.Raw.equal e1 e2)
k_opt
(ThreadSafetyDomain.Access.get_access_path (ThreadSafetyDomain.TraceElem.kind elem))
)
in in
(* assumption: the tenv for k is sufficient for k' too *) (* assumption: the tenv for k is sufficient for k' too *)
let k_part, non_k_part = AccessListMap.partition (fun k' _ -> may_alias tenv k k') m in let k_part, non_k_part =
AccessListMap.partition
(fun k_opt' _ ->
match (k_opt', k_opt) with
| Some k', Some k
-> may_alias tenv k k'
| None, None
-> true
| _
-> false)
m
in
if AccessListMap.is_empty k_part then failwith "may_alias is not reflexive!" ; if AccessListMap.is_empty k_part then failwith "may_alias is not reflexive!" ;
let k_accesses = AccessListMap.fold (fun _ v acc' -> List.append v acc') k_part [] in let k_accesses = AccessListMap.fold (fun _ v acc' -> List.append v acc') k_part [] in
let new_acc = AccessListMap.add k k_accesses acc in let new_acc = AccessListMap.add k_opt k_accesses acc in
aux new_acc non_k_part aux new_acc non_k_part
in in
aux AccessListMap.empty acc_map aux AccessListMap.empty acc_map
@ -1439,14 +1464,14 @@ let make_results_table file_env =
(fun pre accesses acc -> (fun pre accesses acc ->
PathDomain.Sinks.fold PathDomain.Sinks.fold
(fun access acc -> (fun access acc ->
let access_path, _ = TraceElem.kind access in let access_path_opt = Access.get_access_path (TraceElem.kind access) in
if should_filter_access access_path then acc if Option.exists ~f:should_filter_access access_path_opt then acc
else else
let grouped_accesses = let grouped_accesses =
try AccessListMap.find access_path acc try AccessListMap.find access_path_opt acc
with Not_found -> [] with Not_found -> []
in in
AccessListMap.add access_path AccessListMap.add access_path_opt
((access, pre, threaded, tenv, pdesc) :: grouped_accesses) acc) ((access, pre, threaded, tenv, pdesc) :: grouped_accesses) acc)
(PathDomain.sinks accesses) acc) (PathDomain.sinks accesses) acc)
accesses acc accesses acc

@ -11,15 +11,16 @@ open! IStd
module F = Format module F = Format
module Access = struct module Access = struct
type kind = Read | Write [@@deriving compare] type t = Read of AccessPath.Raw.t | Write of AccessPath.Raw.t [@@deriving compare]
type t = AccessPath.Raw.t * kind [@@deriving compare] let make access_path ~is_write = if is_write then Write access_path else Read access_path
let pp fmt (access_path, access_kind) = let get_access_path = function Read access_path | Write access_path -> Some access_path
match access_kind with
| Read let pp fmt = function
| Read access_path
-> F.fprintf fmt "Read of %a" AccessPath.Raw.pp access_path -> F.fprintf fmt "Read of %a" AccessPath.Raw.pp access_path
| Write | Write access_path
-> F.fprintf fmt "Write to %a" AccessPath.Raw.pp access_path -> F.fprintf fmt "Write to %a" AccessPath.Raw.pp access_path
end end
@ -28,9 +29,9 @@ module TraceElem = struct
type t = {site: CallSite.t; kind: Kind.t} [@@deriving compare] type t = {site: CallSite.t; kind: Kind.t} [@@deriving compare]
let is_read {kind} = match snd kind with Read -> true | Write -> false let is_read {kind} = match kind with Read _ -> true | Write _ -> false
let is_write {kind} = match snd kind with Read -> false | Write -> true let is_write {kind} = match kind with Read _ -> false | Write _ -> true
let call_site {site} = site let call_site {site} = site
@ -51,9 +52,9 @@ module TraceElem = struct
end) end)
end end
let make_access access_path access_kind loc = let make_access access_path ~is_write loc =
let site = CallSite.make Typ.Procname.empty_block loc in let site = CallSite.make Typ.Procname.empty_block loc in
TraceElem.make (access_path, access_kind) site TraceElem.make (Access.make access_path ~is_write) site
(* In this domain true<=false. The intended denotations [[.]] are (* In this domain true<=false. The intended denotations [[.]] are
[[true]] = the set of all states where we know according, to annotations [[true]] = the set of all states where we know according, to annotations

@ -11,9 +11,9 @@ open! IStd
module F = Format module F = Format
module Access : sig module Access : sig
type kind = Read | Write [@@deriving compare] type t = Read of AccessPath.Raw.t | Write of AccessPath.Raw.t [@@deriving compare]
type t = AccessPath.Raw.t * kind [@@deriving compare] val get_access_path : t -> AccessPath.Raw.t option
val pp : F.formatter -> t -> unit val pp : F.formatter -> t -> unit
end end
@ -171,6 +171,6 @@ type summary =
include AbstractDomain.WithBottom with type astate := astate include AbstractDomain.WithBottom with type astate := astate
val make_access : AccessPath.Raw.t -> Access.kind -> Location.t -> TraceElem.t val make_access : AccessPath.Raw.t -> is_write:bool -> Location.t -> TraceElem.t
val pp_summary : F.formatter -> summary -> unit val pp_summary : F.formatter -> summary -> unit

Loading…
Cancel
Save