diff --git a/infer/src/checkers/ThreadSafety.ml b/infer/src/checkers/ThreadSafety.ml index c418de6d1..f8b918dfb 100644 --- a/infer/src/checkers/ThreadSafety.ml +++ b/infer/src/checkers/ThreadSafety.ml @@ -146,16 +146,14 @@ module TransferFunctions (CFG : ProcCfg.S) = struct attribute_map let add_path_to_state exp typ loc path_state id_map attribute_map tenv = - (* remove the last field of the access path, if it has any *) - let truncate = function - | base, [] - | base, _ :: [] -> base, [] - | base, accesses -> base, IList.rev (List.tl_exn (IList.rev accesses)) in (* we don't want to warn on writes to the field if it is (a) thread-confined, or (b) volatile *) let is_safe_write access_path tenv = - let is_thread_safe_write accesses tenv = match IList.rev accesses with - | AccessPath.FieldAccess (fieldname, Typ.Tstruct typename) :: _ -> + let is_thread_safe_write accesses tenv = + match IList.rev accesses, + AccessPath.Raw.get_typ (AccessPath.Raw.truncate access_path) tenv with + | AccessPath.FieldAccess fieldname :: _, + Some (Typ.Tstruct typename | Tptr (Tstruct typename, _)) -> begin match Tenv.lookup tenv typename with | Some struct_typ -> @@ -177,7 +175,8 @@ module TransferFunctions (CFG : ProcCfg.S) = struct else IList.fold_left (fun acc rawpath -> - if not (is_owned (truncate rawpath) attribute_map) && not (is_safe_write rawpath tenv) + if not (is_owned (AccessPath.Raw.truncate rawpath) attribute_map) && + not (is_safe_write rawpath tenv) then Domain.PathDomain.add_sink (Domain.make_access rawpath loc) acc else acc) path_state diff --git a/infer/src/checkers/accessPath.ml b/infer/src/checkers/accessPath.ml index 4e82a3e93..e0187b976 100644 --- a/infer/src/checkers/accessPath.ml +++ b/infer/src/checkers/accessPath.ml @@ -21,7 +21,7 @@ let equal_base = [%compare.equal : base] type access = | ArrayAccess of Typ.t - | FieldAccess of Ident.fieldname * Typ.t + | FieldAccess of Ident.fieldname [@@deriving compare] let equal_access = [%compare.equal : access] @@ -30,7 +30,7 @@ let pp_base fmt (pvar, _) = Var.pp fmt pvar let pp_access fmt = function - | FieldAccess (field_name, _) -> Ident.pp_fieldname fmt field_name + | FieldAccess field_name -> Ident.pp_fieldname fmt field_name | ArrayAccess _ -> F.fprintf fmt "[_]" let pp_access_list fmt accesses = @@ -41,6 +41,26 @@ module Raw = struct type t = base * access list [@@deriving compare] let equal = [%compare.equal : t] + let truncate = function + | base, [] + | base, _ :: [] -> base, [] + | base, accesses -> base, List.rev (List.tl_exn (List.rev accesses)) + + let get_typ ((_, base_typ), accesses) tenv = + let rec accesses_get_typ last_typ tenv = function + | [] -> + Some last_typ + | FieldAccess field_name :: accesses -> + let lookup = Tenv.lookup tenv in + begin + match StructTyp.get_field_type_and_annotation ~lookup field_name last_typ with + | Some (field_typ, _) -> accesses_get_typ field_typ tenv accesses + | None -> None + end + | ArrayAccess array_typ :: accesses -> + accesses_get_typ array_typ tenv accesses in + accesses_get_typ base_typ tenv accesses + let pp fmt = function | base, [] -> pp_base fmt base | base, accesses -> F.fprintf fmt "%a.%a" pp_base base pp_access_list accesses @@ -84,7 +104,7 @@ let of_exp exp0 typ0 ~(f_resolve_id : Var.t -> Raw.t option) = | Exp.Lvar pvar -> (base_of_pvar pvar typ, accesses) :: acc | Exp.Lfield (root_exp, fld, root_exp_typ) -> - let field_access = FieldAccess (fld, typ) in + let field_access = FieldAccess fld in of_exp_ root_exp root_exp_typ (field_access :: accesses) acc | Exp.Lindex (root_exp, _) -> let array_access = ArrayAccess typ in diff --git a/infer/src/checkers/accessPath.mli b/infer/src/checkers/accessPath.mli index 6336b6686..7315f18b8 100644 --- a/infer/src/checkers/accessPath.mli +++ b/infer/src/checkers/accessPath.mli @@ -15,7 +15,7 @@ type base = Var.t * Typ.t [@@deriving compare] type access = | ArrayAccess of Typ.t (* array element type. index is unknown *) - | FieldAccess of Ident.fieldname * Typ.t (* field name * field type *) + | FieldAccess of Ident.fieldname (* field name *) [@@deriving compare] module Raw : sig @@ -23,7 +23,16 @@ module Raw : sig representedas (x, [f; g]) *) type t = base * access list [@@deriving compare] + (** remove the last access of the access path if the access list is non-empty. returns the + original access path if the access list is empty *) + val truncate : t -> t + + (** get the typ of the last access in the list of accesses if the list is non-empty, or the base + if the list is empty. that is, for x.f.g, return typ(g), and for x, return typ(x) *) + val get_typ : t -> Tenv.t -> Typ.t option + val equal : t -> t -> bool + val pp : Format.formatter -> t -> unit end diff --git a/infer/src/unit/accessPathTestUtils.ml b/infer/src/unit/accessPathTestUtils.ml index b7d757c02..a811e972c 100644 --- a/infer/src/unit/accessPathTestUtils.ml +++ b/infer/src/unit/accessPathTestUtils.ml @@ -17,7 +17,7 @@ let make_fieldname fld_str = Ident.create_fieldname (Mangled.from_string fld_str) 0 let make_field_access access_str = - AccessPath.FieldAccess (make_fieldname access_str, Typ.Tvoid) + AccessPath.FieldAccess (make_fieldname access_str) let make_array_access typ = AccessPath.ArrayAccess typ