diff --git a/infer/src/nullsafe/AssignmentRule.ml b/infer/src/nullsafe/AssignmentRule.ml index a22fa0f27..f88e67369 100644 --- a/infer/src/nullsafe/AssignmentRule.ml +++ b/infer/src/nullsafe/AssignmentRule.ml @@ -6,7 +6,7 @@ *) open! IStd -type violation = {lhs: Nullability.t; rhs: Nullability.t} [@@deriving compare] +type violation = {lhs: AnnotatedNullability.t; rhs: InferredNullability.t} [@@deriving compare] module ReportableViolation = struct type t = {nullsafe_mode: NullsafeMode.t; violation: violation} @@ -28,12 +28,12 @@ module ReportableViolation = struct let falls_under_optimistic_third_party = Config.nullsafe_optimistic_third_party_params_in_non_strict && NullsafeMode.equal nullsafe_mode Default - && Nullability.equal lhs ThirdPartyNonnull + && Nullability.equal (AnnotatedNullability.get_nullability lhs) ThirdPartyNonnull in let is_non_reportable = falls_under_optimistic_third_party || (* In certain modes, we trust rhs to be non-nullable and don't report violation *) - Nullability.is_considered_nonnull ~nullsafe_mode rhs + Nullability.is_considered_nonnull ~nullsafe_mode (InferredNullability.get_nullability rhs) in if is_non_reportable then None else Some {nullsafe_mode; violation} @@ -200,10 +200,11 @@ module ReportableViolation = struct (error_message, issue_type, assignment_location) - let get_description ~assignment_location assignment_type ~rhs_origin - {nullsafe_mode; violation= {rhs}} = + let get_description ~assignment_location assignment_type {nullsafe_mode; violation= {rhs}} = + let rhs_origin = InferredNullability.get_origin rhs in let user_friendly_nullable = - ErrorRenderingUtils.UserFriendlyNullable.from_nullability rhs + ErrorRenderingUtils.UserFriendlyNullable.from_nullability + (InferredNullability.get_nullability rhs) |> IOption.if_none_eval ~f:(fun () -> Logging.die InternalError "get_description:: Assignment violation should not be possible for non-nullable \ @@ -223,5 +224,9 @@ module ReportableViolation = struct end let check ~lhs ~rhs = - let is_subtype = Nullability.is_subtype ~supertype:lhs ~subtype:rhs in + let is_subtype = + Nullability.is_subtype + ~supertype:(AnnotatedNullability.get_nullability lhs) + ~subtype:(InferredNullability.get_nullability rhs) + in Result.ok_if_true is_subtype ~error:{lhs; rhs} diff --git a/infer/src/nullsafe/AssignmentRule.mli b/infer/src/nullsafe/AssignmentRule.mli index de9d0b3ee..54b9cdb91 100644 --- a/infer/src/nullsafe/AssignmentRule.mli +++ b/infer/src/nullsafe/AssignmentRule.mli @@ -12,7 +12,7 @@ open! IStd type violation [@@deriving compare] -val check : lhs:Nullability.t -> rhs:Nullability.t -> (unit, violation) result +val check : lhs:AnnotatedNullability.t -> rhs:InferredNullability.t -> (unit, violation) result (** If `null` can leak from a "less strict" type to "more strict" type, this is an Assignment Rule violation. *) @@ -41,11 +41,7 @@ module ReportableViolation : sig (** Severity of the violation to be reported *) val get_description : - assignment_location:Location.t - -> assignment_type - -> rhs_origin:TypeOrigin.t - -> t - -> string * IssueType.t * Location.t + assignment_location:Location.t -> assignment_type -> t -> string * IssueType.t * Location.t (** Given context around violation, return error message together with the info where to put this message *) end diff --git a/infer/src/nullsafe/eradicateChecks.ml b/infer/src/nullsafe/eradicateChecks.ml index ec6ee0ea1..f5f991563 100644 --- a/infer/src/nullsafe/eradicateChecks.ml +++ b/infer/src/nullsafe/eradicateChecks.ml @@ -143,12 +143,9 @@ let check_field_assignment in Annotations.ia_is_cleanup ret_annotation_deprecated in - let declared_nullability = - AnnotatedNullability.get_nullability annotated_field.annotated_type.nullability - in + let declared_nullability = annotated_field.annotated_type.nullability in let assignment_check_result = - AssignmentRule.check ~lhs:declared_nullability - ~rhs:(InferredNullability.get_nullability inferred_nullability_rhs) + AssignmentRule.check ~lhs:declared_nullability ~rhs:inferred_nullability_rhs in Result.iter_error assignment_check_result ~f:(fun assignment_violation -> let should_report = @@ -159,12 +156,10 @@ let check_field_assignment && not (field_is_in_cleanup_context ()) in if should_report then - let rhs_origin = InferredNullability.get_origin inferred_nullability_rhs in TypeErr.register_error analysis_data find_canonical_duplicate (TypeErr.Bad_assignment { assignment_violation ; assignment_location= loc - ; rhs_origin ; assignment_type= AssignmentRule.ReportableViolation.AssigningToField fname }) (Some instr_ref) ~nullsafe_mode loc ) ) @@ -341,16 +336,14 @@ let check_return_not_nullable ({IntraproceduralAnalysis.proc_desc= curr_pdesc; _ ~nullsafe_mode find_canonical_duplicate loc (ret_signature : AnnotatedSignature.ret_signature) ret_inferred_nullability = (* Returning from a function is essentially an assignment the actual return value to the formal `return` *) - let lhs = AnnotatedNullability.get_nullability ret_signature.ret_annotated_type.nullability in - let rhs = InferredNullability.get_nullability ret_inferred_nullability in + let lhs = ret_signature.ret_annotated_type.nullability in + let rhs = ret_inferred_nullability in Result.iter_error (AssignmentRule.check ~lhs ~rhs) ~f:(fun assignment_violation -> - let rhs_origin = InferredNullability.get_origin ret_inferred_nullability in let curr_pname = Procdesc.get_proc_name curr_pdesc in TypeErr.register_error analysis_data find_canonical_duplicate (Bad_assignment { assignment_violation ; assignment_location= loc - ; rhs_origin ; assignment_type= ReturningFromFunction curr_pname }) None ~nullsafe_mode loc ) @@ -434,12 +427,10 @@ let check_call_parameters ({IntraproceduralAnalysis.tenv; _} as analysis_data) ~ | None -> "formal parameter " ^ Mangled.to_string formal.mangled in - let rhs_origin = InferredNullability.get_origin nullability_actual in TypeErr.register_error analysis_data find_canonical_duplicate (Bad_assignment { assignment_violation ; assignment_location= loc - ; rhs_origin ; assignment_type= PassingParamToFunction { param_signature= formal @@ -452,8 +443,8 @@ let check_call_parameters ({IntraproceduralAnalysis.tenv; _} as analysis_data) ~ if PatternMatch.type_is_class formal.param_annotated_type.typ then (* Passing a param to a function is essentially an assignment the actual param value to the formal param *) - let lhs = AnnotatedNullability.get_nullability formal.param_annotated_type.nullability in - let rhs = InferredNullability.get_nullability nullability_actual in + let lhs = formal.param_annotated_type.nullability in + let rhs = nullability_actual in Result.iter_error (AssignmentRule.check ~lhs ~rhs) ~f:(report ~nullsafe_mode) in List.iter ~f:check resolved_params diff --git a/infer/src/nullsafe/typeErr.ml b/infer/src/nullsafe/typeErr.ml index 0963ad102..ba7058c0c 100644 --- a/infer/src/nullsafe/typeErr.ml +++ b/infer/src/nullsafe/typeErr.ml @@ -80,8 +80,7 @@ type err_instance = | Bad_assignment of { assignment_violation: AssignmentRule.violation ; assignment_location: Location.t - ; assignment_type: AssignmentRule.ReportableViolation.assignment_type - ; rhs_origin: TypeOrigin.t } + ; assignment_type: AssignmentRule.ReportableViolation.assignment_type } [@@deriving compare] let pp_err_instance fmt err_instance = @@ -96,8 +95,8 @@ let pp_err_instance fmt err_instance = F.pp_print_string fmt "Over_annotation" | Nullable_dereference _ -> F.pp_print_string fmt "Nullable_dereference" - | Bad_assignment {rhs_origin} -> - F.fprintf fmt "Bad_assignment: rhs %s" (TypeOrigin.to_string rhs_origin) + | Bad_assignment _ -> + F.fprintf fmt "Bad_assignment" module H = Hashtbl.Make (struct @@ -257,7 +256,7 @@ let get_error_info_if_reportable_lazy ~nullsafe_mode err_instance = , IssueType.eradicate_field_not_initialized , None , NullsafeMode.severity nullsafe_mode ) ) - | Bad_assignment {rhs_origin; assignment_location; assignment_type; assignment_violation} -> + | Bad_assignment {assignment_location; assignment_type; assignment_violation} -> (* If violation is reportable, create tuple, otherwise None *) let+ reportable_violation = AssignmentRule.ReportableViolation.from nullsafe_mode assignment_violation @@ -265,7 +264,7 @@ let get_error_info_if_reportable_lazy ~nullsafe_mode err_instance = lazy (let description, issue_type, error_location = AssignmentRule.ReportableViolation.get_description ~assignment_location assignment_type - ~rhs_origin reportable_violation + reportable_violation in let severity = AssignmentRule.ReportableViolation.get_severity reportable_violation in (description, issue_type, Some error_location, severity) ) diff --git a/infer/src/nullsafe/typeErr.mli b/infer/src/nullsafe/typeErr.mli index e83a876ff..bce876caa 100644 --- a/infer/src/nullsafe/typeErr.mli +++ b/infer/src/nullsafe/typeErr.mli @@ -54,8 +54,7 @@ type err_instance = | Bad_assignment of { assignment_violation: AssignmentRule.violation ; assignment_location: Location.t - ; assignment_type: AssignmentRule.ReportableViolation.assignment_type - ; rhs_origin: TypeOrigin.t } + ; assignment_type: AssignmentRule.ReportableViolation.assignment_type } [@@deriving compare] val pp_err_instance : Format.formatter -> err_instance -> unit