diff --git a/infer/src/nullsafe/typeState.ml b/infer/src/nullsafe/typeState.ml index 533c1ed52..02e4a6aa4 100644 --- a/infer/src/nullsafe/typeState.ml +++ b/infer/src/nullsafe/typeState.ml @@ -47,46 +47,17 @@ let range_add_locs (typ, ta, locs1) locs2 = (typ, ta, locs') -(** Only keep variables if they are present on both sides of the join. *) -let only_keep_intersection = true - -(** Join two maps. - If only_keep_intersection is true, keep only variables present on both sides. *) let map_join m1 m2 = - let tjoined = ref (if only_keep_intersection then M.empty else m1) in - let range_join (typ1, ta1, locs1) (typ2, ta2, locs2) = - match TypeAnnotation.join ta1 ta2 with - | None -> - None - | Some ta' -> - let typ' = type_join typ1 typ2 in - let locs' = locs_join locs1 locs2 in - Some (typ', ta', locs') - in - let extend_lhs exp2 range2 = - (* extend lhs if possible, otherwise return false *) - try - let range1 = M.find exp2 m1 in - match range_join range1 range2 with - | None -> - if only_keep_intersection then tjoined := M.add exp2 range1 !tjoined - | Some range' -> - tjoined := M.add exp2 range' !tjoined - with Caml.Not_found -> - if not only_keep_intersection then tjoined := M.add exp2 range2 !tjoined - in - let missing_rhs exp1 range1 = - (* handle elements missing in the rhs *) - try ignore (M.find exp1 m2) - with Caml.Not_found -> - let t1, ta1, locs1 = range1 in - let range1' = - let ta1' = TypeAnnotation.with_origin ta1 TypeOrigin.Undef in - (t1, ta1', locs1) - in - if not only_keep_intersection then tjoined := M.add exp1 range1' !tjoined + let range_join _exp range1_opt range2_opt = + Option.both range1_opt range2_opt + |> Option.map ~f:(fun (((typ1, ta1, locs1) as range1), (typ2, ta2, locs2)) -> + TypeAnnotation.join ta1 ta2 + |> Option.value_map ~default:range1 ~f:(fun ta' -> + let typ' = type_join typ1 typ2 in + let locs' = locs_join locs1 locs2 in + (typ', ta', locs') ) ) in - if phys_equal m1 m2 then m1 else ( M.iter extend_lhs m2 ; M.iter missing_rhs m1 ; !tjoined ) + if phys_equal m1 m2 then m1 else M.merge range_join m1 m2 let join t1 t2 =