diff --git a/infer/src/IR/HilExp.ml b/infer/src/IR/HilExp.ml index f0e3e0366..1927427d0 100644 --- a/infer/src/IR/HilExp.ml +++ b/infer/src/IR/HilExp.ml @@ -207,20 +207,21 @@ module AccessExpression = struct let pp = pp_access_expr - let to_accesses ~f_array_offset ae = - let rec aux accesses = function + let to_accesses_fold access_expr ~init ~f_array_offset = + let rec aux accum accesses = function | Base base -> - (base, accesses) - | FieldOffset (ae, fld) -> - aux (Access.FieldAccess fld :: accesses) ae - | ArrayOffset (ae, typ, index) -> - aux (Access.ArrayAccess (typ, f_array_offset index) :: accesses) ae - | AddressOf ae -> - aux (Access.TakeAddress :: accesses) ae - | Dereference ae -> - aux (Access.Dereference :: accesses) ae + (accum, base, accesses) + | FieldOffset (access_expr, fld) -> + aux accum (Access.FieldAccess fld :: accesses) access_expr + | ArrayOffset (access_expr, typ, index) -> + let accum, index' = f_array_offset accum index in + aux accum (Access.ArrayAccess (typ, index') :: accesses) access_expr + | AddressOf access_expr -> + aux accum (Access.TakeAddress :: accesses) access_expr + | Dereference access_expr -> + aux accum (Access.Dereference :: accesses) access_expr in - aux [] ae + aux init [] access_expr (** convert to an AccessPath.t, ignoring AddressOf and Dereference for now *) diff --git a/infer/src/IR/HilExp.mli b/infer/src/IR/HilExp.mli index 52db9f620..78716049c 100644 --- a/infer/src/IR/HilExp.mli +++ b/infer/src/IR/HilExp.mli @@ -50,10 +50,11 @@ module AccessExpression : sig val dereference : access_expression -> access_expression (** guarantees that we never build [Dereference (AddressOf t)] expressions: these become [t] *) - val to_accesses : - f_array_offset:(t option -> 'array_index) - -> access_expression - -> AccessPath.base * 'array_index Access.t list + val to_accesses_fold : + access_expression + -> init:'accum + -> f_array_offset:('accum -> t option -> 'accum * 'array_index) + -> 'accum * AccessPath.base * 'array_index Access.t list val to_access_path : access_expression -> AccessPath.t diff --git a/infer/src/checkers/PulseDomain.ml b/infer/src/checkers/PulseDomain.ml index 0021c4504..c936fbb65 100644 --- a/infer/src/checkers/PulseDomain.ml +++ b/infer/src/checkers/PulseDomain.ml @@ -78,7 +78,8 @@ end module Attributes = AbstractDomain.FiniteSet (Attribute) module Memory : sig - module Access : PrettyPrintable.PrintableOrderedType with type t = unit HilExp.Access.t + module Access : + PrettyPrintable.PrintableOrderedType with type t = AbstractAddressSet.t HilExp.Access.t module Edges : PrettyPrintable.PPMap with type key = Access.t @@ -116,9 +117,9 @@ module Memory : sig val is_std_vector_reserved : AbstractAddressSet.t -> t -> bool end = struct module Access = struct - type t = unit HilExp.Access.t [@@deriving compare] + type t = AbstractAddressSet.t HilExp.Access.t [@@deriving compare] - let pp = HilExp.Access.pp (fun _ () -> ()) + let pp = HilExp.Access.pp AbstractAddressSet.pp end module Edges = PrettyPrintable.MakePPMap (Access) @@ -622,11 +623,28 @@ module Operations = struct {astate with stack} + let rec to_accesses location access_expr astate = + let exception Failed_fold of Diagnostic.t in + try + HilExp.AccessExpression.to_accesses_fold access_expr ~init:astate + ~f_array_offset:(fun astate hil_exp_opt -> + match hil_exp_opt with + | None -> + (astate, AbstractAddressSet.mk_fresh ()) + | Some hil_exp -> ( + match eval_hil_exp location hil_exp astate with + | Ok result -> + result + | Error diag -> + raise (Failed_fold diag) ) ) + |> Result.return + with Failed_fold diag -> Error diag + + (** add addresses to the state to give a address to the destination of the given access path *) - let walk_access_expr ~on_last astate access_expr location = - let (access_var, _), access_list = - HilExp.AccessExpression.to_accesses ~f_array_offset:(fun _ -> ()) access_expr - in + and walk_access_expr ~on_last astate access_expr location = + to_accesses location access_expr astate + >>= fun (astate, (access_var, _), access_list) -> if Config.write_html then L.d_printfln "Accessing %a -> [%a]" Var.pp access_var (Pp.seq ~sep:"," Memory.Access.pp) @@ -653,7 +671,28 @@ module Operations = struct Return an error state if it traverses some known invalid address or if the end destination is known to be invalid. *) - let materialize_address astate access_expr = walk_access_expr ~on_last:`Access astate access_expr + and materialize_address astate access_expr = walk_access_expr ~on_last:`Access astate access_expr + + and read location access_expr astate = + materialize_address astate access_expr location + >>= fun (astate, addr) -> + let actor = {access_expr; location} in + check_addr_access_set actor addr astate >>| fun astate -> (astate, addr) + + + and read_all location access_exprs astate = + List.fold_result access_exprs ~init:astate ~f:(fun astate access_expr -> + read location access_expr astate >>| fst ) + + + and eval_hil_exp location (hil_exp : HilExp.t) astate = + match hil_exp with + | AccessExpression access_expr -> + read location access_expr astate + | _ -> + read_all location (HilExp.get_access_exprs hil_exp) astate + >>| fun astate -> (astate, AbstractAddressSet.mk_fresh ()) + (** Use the stack and heap to walk the access path represented by the given expression down to an abstract address representing what the expression points to, and replace that with the given @@ -686,18 +725,6 @@ module Operations = struct >>| fst - let read location access_expr astate = - materialize_address astate access_expr location - >>= fun (astate, addr) -> - let actor = {access_expr; location} in - check_addr_access_set actor addr astate >>| fun astate -> (astate, addr) - - - let read_all location access_exprs astate = - List.fold_result access_exprs ~init:astate ~f:(fun astate access_expr -> - read location access_expr astate >>| fst ) - - let write location access_expr addr astate = overwrite_address astate access_expr addr location >>| fun (astate, _) -> astate @@ -708,6 +735,28 @@ module Operations = struct check_addr_access_set {access_expr; location} addr astate >>| mark_invalid_set cause addr + let invalidate_array_elements cause location access_expr astate = + materialize_address astate access_expr location + >>= fun (astate, addrs) -> + check_addr_access_set {access_expr; location} addrs astate + >>| fun astate -> + AbstractAddressSet.fold + (fun addr astate -> + match Memory.find_opt addr astate.heap with + | None -> + astate + | Some (edges, _) -> + Memory.Edges.fold + (fun access dest_addrs astate -> + match (access : Memory.Access.t) with + | ArrayAccess _ -> + mark_invalid_set cause dest_addrs astate + | _ -> + astate ) + edges astate ) + addrs astate + + let remove_vars vars astate = let stack = List.fold ~f:(fun var stack -> Stack.remove stack var) ~init:astate.stack vars in if phys_equal stack astate.stack then astate else {astate with stack} diff --git a/infer/src/checkers/PulseDomain.mli b/infer/src/checkers/PulseDomain.mli index 9a8097605..9ed2e7f57 100644 --- a/infer/src/checkers/PulseDomain.mli +++ b/infer/src/checkers/PulseDomain.mli @@ -71,4 +71,7 @@ val write : Location.t -> HilExp.AccessExpression.t -> AbstractAddressSet.t -> t val invalidate : PulseInvalidation.t -> Location.t -> HilExp.AccessExpression.t -> t -> t access_result +val invalidate_array_elements : + PulseInvalidation.t -> Location.t -> HilExp.AccessExpression.t -> t -> t access_result + val remove_vars : Var.t list -> t -> t diff --git a/infer/src/checkers/PulseModels.ml b/infer/src/checkers/PulseModels.ml index 7735d6e53..23a2a9b6c 100644 --- a/infer/src/checkers/PulseModels.ml +++ b/infer/src/checkers/PulseModels.ml @@ -97,16 +97,20 @@ module StdVector = struct let to_internal_array vector = HilExp.AccessExpression.field_offset vector internal_array let deref_internal_array vector = - HilExp.AccessExpression.(array_offset (dereference (to_internal_array vector)) Typ.void None) + HilExp.AccessExpression.(dereference (to_internal_array vector)) + + + let element_of_internal_array vector index = + HilExp.AccessExpression.array_offset (deref_internal_array vector) Typ.void index let reallocate_internal_array vector vector_f location astate = - let array = to_internal_array vector in - (* all elements should be invalidated *) - let array_elements = deref_internal_array vector in + let array_address = to_internal_array vector in + let array = deref_internal_array vector in let invalidation = PulseInvalidation.StdVector (vector_f, vector, location) in - PulseDomain.invalidate invalidation location array_elements astate - >>= PulseDomain.havoc location array + PulseDomain.invalidate_array_elements invalidation location array astate + >>= PulseDomain.invalidate invalidation location array + >>= PulseDomain.havoc location array_address let invalidate_references invalidation : model = @@ -121,8 +125,10 @@ module StdVector = struct let at : model = fun location ~ret ~actuals astate -> match actuals with - | [AccessExpression vector; _index] -> - PulseDomain.read location (deref_internal_array vector) astate + | [AccessExpression vector_access_expr; index_exp] -> + PulseDomain.read location + (element_of_internal_array vector_access_expr (Some index_exp)) + astate >>= fun (astate, loc) -> PulseDomain.write location (HilExp.AccessExpression.base ret) loc astate | _ ->