From 03120a337e2195b2296a1d401ce4a88211ff88ba Mon Sep 17 00:00:00 2001 From: Sam Blackshear Date: Wed, 19 Jul 2017 13:29:22 -0700 Subject: [PATCH] [thread-safety] refactor ThreadSafetyDomain.Access to make it easier to add new access kinds Summary: Making it simple to add a new access type for "un-annotated interface call" in an upcoming diff. Reviewed By: da319 Differential Revision: D5445914 fbshipit-source-id: f29e342 --- infer/src/checkers/ThreadSafety.ml | 129 +++++++++++++--------- infer/src/checkers/ThreadSafetyDomain.ml | 21 ++-- infer/src/checkers/ThreadSafetyDomain.mli | 6 +- 3 files changed, 91 insertions(+), 65 deletions(-) diff --git a/infer/src/checkers/ThreadSafety.ml b/infer/src/checkers/ThreadSafety.ml index 000014279..db9fdcfd6 100644 --- a/infer/src/checkers/ThreadSafety.ml +++ b/infer/src/checkers/ThreadSafety.ml @@ -29,23 +29,28 @@ let container_write_string = "infer.dummy.__CONTAINERWRITE__" (* return (name of container, name of mutating function call) pair *) let get_container_write_desc sink = - let (base_var, _), access_list = fst (ThreadSafetyDomain.TraceElem.kind sink) in - let get_container_write_desc_ call_name container_name = - match - String.chop_prefix (Typ.Fieldname.to_string call_name) ~prefix:container_write_string - with - | Some call_name - -> Some (container_name, call_name) - | None - -> None - in - match List.rev access_list with - | (FieldAccess call_name) :: (FieldAccess container_name) :: _ - -> get_container_write_desc_ call_name (Typ.Fieldname.to_string container_name) - | [(FieldAccess call_name)] - -> get_container_write_desc_ call_name (F.asprintf "%a" Var.pp base_var) - | _ - -> None + match ThreadSafetyDomain.TraceElem.kind sink with + | Write ((base_var, _), access_list) + -> ( + let get_container_write_desc_ call_name container_name = + match + String.chop_prefix (Typ.Fieldname.to_string call_name) ~prefix:container_write_string + with + | Some call_name + -> Some (container_name, call_name) + | None + -> None + in + match List.rev access_list with + | (FieldAccess call_name) :: (FieldAccess container_name) :: _ + -> get_container_write_desc_ call_name (Typ.Fieldname.to_string container_name) + | [(FieldAccess call_name)] + -> get_container_write_desc_ call_name (F.asprintf "%a" Var.pp base_var) + | _ + -> None ) + | Read _ + -> (* TODO: support Read *) + None let is_container_write_sink sink = Option.is_some (get_container_write_desc sink) @@ -259,7 +264,7 @@ module TransferFunctions (CFG : ProcCfg.S) = struct let is_unprotected is_locked is_threaded pdesc = not is_locked && not is_threaded && not (Procdesc.is_java_synchronized pdesc) - let add_access exp loc access_kind accesses locks threads attribute_map + let add_access exp loc ~is_write_access accesses locks threads attribute_map (proc_data: FormalMap.t ProcData.t) = let open Domain in (* we don't want to warn on accesses to the field if it is (a) thread-confined, or @@ -282,9 +287,7 @@ module TransferFunctions (CFG : ProcCfg.S) = struct | [] -> access_acc | access :: access_list' - -> let kind = - if List.is_empty access_list' then access_kind else ThreadSafetyDomain.Access.Read - in + -> let is_write = if List.is_empty access_list' then is_write_access else false in let access_path = (fst prefix_path, snd prefix_path @ [access]) in let access_acc' = if is_owned prefix_path attribute_map @@ -293,7 +296,9 @@ module TransferFunctions (CFG : ProcCfg.S) = struct else (* TODO: I think there's a utility function for this somewhere *) let accesses = AccessDomain.get_accesses pre access_acc in - let accesses' = PathDomain.add_sink (make_access access_path kind loc) accesses in + let accesses' = + PathDomain.add_sink (make_access access_path ~is_write loc) accesses + in AccessDomain.add pre accesses' access_acc in add_field_accesses pre access_path access_acc' access_list' @@ -467,7 +472,7 @@ module TransferFunctions (CFG : ProcCfg.S) = struct (fst receiver_ap, snd receiver_ap @ [AccessPath.FieldAccess dummy_fieldname]) in AccessDomain.add_access (Unprotected (Some 0)) - (make_access dummy_access_ap Write callee_loc) AccessDomain.empty + (make_access dummy_access_ap ~is_write:true callee_loc) AccessDomain.empty in (* TODO: for now all formals escape *) (* we need a more intelligent escape analysis, that branches on whether @@ -505,7 +510,8 @@ module TransferFunctions (CFG : ProcCfg.S) = struct let add_reads exps loc accesses locks threads attribute_map proc_data = List.fold - ~f:(fun acc exp -> add_access exp loc Read acc locks threads attribute_map proc_data) + ~f:(fun acc exp -> + add_access exp loc ~is_write_access:false acc locks threads attribute_map proc_data) exps ~init:accesses let add_escapees_from_exp rhs_exp extras escapees = @@ -726,7 +732,7 @@ module TransferFunctions (CFG : ProcCfg.S) = struct -> astate_callee ) | Assign (lhs_access_path, rhs_exp, loc) -> let rhs_accesses = - add_access rhs_exp loc Read astate.accesses astate.locks astate.threads + add_access rhs_exp loc ~is_write_access:false astate.accesses astate.locks astate.threads astate.attribute_map proc_data in let rhs_access_paths = HilExp.get_access_paths rhs_exp in @@ -743,8 +749,8 @@ module TransferFunctions (CFG : ProcCfg.S) = struct report spurious read/write races *) rhs_accesses else - add_access (AccessPath lhs_access_path) loc Write rhs_accesses astate.locks - astate.threads astate.attribute_map proc_data + add_access (AccessPath lhs_access_path) loc ~is_write_access:true rhs_accesses + astate.locks astate.threads astate.attribute_map proc_data in let attribute_map = propagate_attributes lhs_access_path rhs_exp astate.attribute_map extras @@ -794,8 +800,8 @@ module TransferFunctions (CFG : ProcCfg.S) = struct {acc with threads} in let accesses = - add_access assume_exp loc Read astate.accesses astate.locks astate.threads - astate.attribute_map proc_data + add_access assume_exp loc ~is_write_access:false astate.accesses astate.locks + astate.threads astate.attribute_map proc_data in let astate' = match HilExp.get_access_paths assume_exp with @@ -993,11 +999,12 @@ let analyze_procedure {Callbacks.proc_desc; tenv; summary} = else Summary.update_summary empty_post summary module AccessListMap = Caml.Map.Make (struct - type t = AccessPath.Raw.t + type t = AccessPath.Raw.t option (* TODO -- keep this compare to satisfy the order of tests, consider using Raw.compare *) - let compare access_path1 access_path2 = - List.compare AccessPath.compare_access (snd access_path1) (snd access_path2) + let compare = + Option.compare (fun access_path1 access_path2 -> + List.compare AccessPath.compare_access (snd access_path1) (snd access_path2) ) end) let get_current_class_and_threadsafe_superclasses tenv pname = @@ -1051,7 +1058,10 @@ let pp_access fmt sink = | Some container_write_desc -> pp_container_access fmt container_write_desc | None - -> let access_path, _ = ThreadSafetyDomain.PathDomain.Sink.kind sink in + -> let access_path = + match ThreadSafetyDomain.PathDomain.Sink.kind sink + with Read access_path | Write access_path -> access_path + in F.fprintf fmt "%a" (MF.wrap_monospaced AccessPath.pp_access_list) (snd access_path) let desc_of_sink sink = @@ -1216,19 +1226,19 @@ let report_unsafe_accesses aggregated_access_map = let is_duplicate_report access pname {reported_sites; reported_writes; reported_reads} = CallSite.Set.mem (TraceElem.call_site access) reported_sites || Typ.Procname.Set.mem pname - ( match snd (TraceElem.kind access) with - | Access.Write + ( match TraceElem.kind access with + | Access.Write _ -> reported_writes - | Access.Read + | Access.Read _ -> reported_reads ) in let update_reported access pname reported = let reported_sites = CallSite.Set.add (TraceElem.call_site access) reported.reported_sites in - match snd (TraceElem.kind access) with - | Access.Write + match TraceElem.kind access with + | Access.Write _ -> let reported_writes = Typ.Procname.Set.add pname reported.reported_writes in {reported with reported_writes; reported_sites} - | Access.Read + | Access.Read _ -> let reported_reads = Typ.Procname.Set.add pname reported.reported_reads in {reported with reported_reads; reported_sites} in @@ -1236,8 +1246,8 @@ let report_unsafe_accesses aggregated_access_map = let pname = Procdesc.get_proc_name pdesc in if is_duplicate_report access pname reported_acc then reported_acc else - match (snd (TraceElem.kind access), pre) with - | Access.Write, AccessPrecondition.Unprotected _ -> ( + match (TraceElem.kind access, pre) with + | Access.Write _, AccessPrecondition.Unprotected _ -> ( match Procdesc.get_proc_name pdesc with | Java _ -> if threaded then reported_acc @@ -1249,10 +1259,10 @@ let report_unsafe_accesses aggregated_access_map = | _ -> (* Do not report unprotected writes for ObjC_Cpp *) reported_acc ) - | Access.Write, AccessPrecondition.Protected _ + | Access.Write _, AccessPrecondition.Protected _ -> (* protected write, do nothing *) reported_acc - | Access.Read, AccessPrecondition.Unprotected _ + | Access.Read _, AccessPrecondition.Unprotected _ -> (* unprotected read. report all writes as conflicts for java *) (* for c++ filter out unprotected writes *) let is_cpp_protected_write pre = @@ -1276,7 +1286,7 @@ let report_unsafe_accesses aggregated_access_map = ~conflicts:(List.map ~f:(fun (access, _, _, _, _) -> access) all_writes) access ; update_reported access pname reported_acc ) - | Access.Read, AccessPrecondition.Protected excl + | Access.Read _, AccessPrecondition.Protected excl -> (* protected read. report unprotected writes and opposite protected writes as conflicts Thread and Lock are opposites of one another, and @@ -1403,16 +1413,31 @@ let quotient_access_map acc_map = let rec aux acc m = if AccessListMap.is_empty m then acc else - let k, vals = AccessListMap.choose m in + let k_opt, vals = AccessListMap.choose m in let _, _, _, tenv, _ = List.find_exn vals ~f:(fun (elem, _, _, _, _) -> - AccessPath.Raw.equal k (ThreadSafetyDomain.TraceElem.kind elem |> fst) ) + Option.equal + (fun e1 e2 -> AccessPath.Raw.equal e1 e2) + k_opt + (ThreadSafetyDomain.Access.get_access_path (ThreadSafetyDomain.TraceElem.kind elem)) + ) in (* assumption: the tenv for k is sufficient for k' too *) - let k_part, non_k_part = AccessListMap.partition (fun k' _ -> may_alias tenv k k') m in + let k_part, non_k_part = + AccessListMap.partition + (fun k_opt' _ -> + match (k_opt', k_opt) with + | Some k', Some k + -> may_alias tenv k k' + | None, None + -> true + | _ + -> false) + m + in if AccessListMap.is_empty k_part then failwith "may_alias is not reflexive!" ; let k_accesses = AccessListMap.fold (fun _ v acc' -> List.append v acc') k_part [] in - let new_acc = AccessListMap.add k k_accesses acc in + let new_acc = AccessListMap.add k_opt k_accesses acc in aux new_acc non_k_part in aux AccessListMap.empty acc_map @@ -1439,14 +1464,14 @@ let make_results_table file_env = (fun pre accesses acc -> PathDomain.Sinks.fold (fun access acc -> - let access_path, _ = TraceElem.kind access in - if should_filter_access access_path then acc + let access_path_opt = Access.get_access_path (TraceElem.kind access) in + if Option.exists ~f:should_filter_access access_path_opt then acc else let grouped_accesses = - try AccessListMap.find access_path acc + try AccessListMap.find access_path_opt acc with Not_found -> [] in - AccessListMap.add access_path + AccessListMap.add access_path_opt ((access, pre, threaded, tenv, pdesc) :: grouped_accesses) acc) (PathDomain.sinks accesses) acc) accesses acc diff --git a/infer/src/checkers/ThreadSafetyDomain.ml b/infer/src/checkers/ThreadSafetyDomain.ml index 2b566731e..589f5f76e 100644 --- a/infer/src/checkers/ThreadSafetyDomain.ml +++ b/infer/src/checkers/ThreadSafetyDomain.ml @@ -11,15 +11,16 @@ open! IStd module F = Format module Access = struct - type kind = Read | Write [@@deriving compare] + type t = Read of AccessPath.Raw.t | Write of AccessPath.Raw.t [@@deriving compare] - type t = AccessPath.Raw.t * kind [@@deriving compare] + let make access_path ~is_write = if is_write then Write access_path else Read access_path - let pp fmt (access_path, access_kind) = - match access_kind with - | Read + let get_access_path = function Read access_path | Write access_path -> Some access_path + + let pp fmt = function + | Read access_path -> F.fprintf fmt "Read of %a" AccessPath.Raw.pp access_path - | Write + | Write access_path -> F.fprintf fmt "Write to %a" AccessPath.Raw.pp access_path end @@ -28,9 +29,9 @@ module TraceElem = struct type t = {site: CallSite.t; kind: Kind.t} [@@deriving compare] - let is_read {kind} = match snd kind with Read -> true | Write -> false + let is_read {kind} = match kind with Read _ -> true | Write _ -> false - let is_write {kind} = match snd kind with Read -> false | Write -> true + let is_write {kind} = match kind with Read _ -> false | Write _ -> true let call_site {site} = site @@ -51,9 +52,9 @@ module TraceElem = struct end) end -let make_access access_path access_kind loc = +let make_access access_path ~is_write loc = let site = CallSite.make Typ.Procname.empty_block loc in - TraceElem.make (access_path, access_kind) site + TraceElem.make (Access.make access_path ~is_write) site (* In this domain true<=false. The intended denotations [[.]] are [[true]] = the set of all states where we know according, to annotations diff --git a/infer/src/checkers/ThreadSafetyDomain.mli b/infer/src/checkers/ThreadSafetyDomain.mli index 4ae6b1148..921e243f8 100644 --- a/infer/src/checkers/ThreadSafetyDomain.mli +++ b/infer/src/checkers/ThreadSafetyDomain.mli @@ -11,9 +11,9 @@ open! IStd module F = Format module Access : sig - type kind = Read | Write [@@deriving compare] + type t = Read of AccessPath.Raw.t | Write of AccessPath.Raw.t [@@deriving compare] - type t = AccessPath.Raw.t * kind [@@deriving compare] + val get_access_path : t -> AccessPath.Raw.t option val pp : F.formatter -> t -> unit end @@ -171,6 +171,6 @@ type summary = include AbstractDomain.WithBottom with type astate := astate -val make_access : AccessPath.Raw.t -> Access.kind -> Location.t -> TraceElem.t +val make_access : AccessPath.Raw.t -> is_write:bool -> Location.t -> TraceElem.t val pp_summary : F.formatter -> summary -> unit