diff --git a/infer/src/pulse/PulseFormula.ml b/infer/src/pulse/PulseFormula.ml index f0b8684bc..c9ff0ee50 100644 --- a/infer/src/pulse/PulseFormula.ml +++ b/infer/src/pulse/PulseFormula.ml @@ -792,6 +792,14 @@ module Atom = struct type eval_result = True | False | Atom of t + module EvalResultMonad = struct + let bind_eval_result eval_result f = + match eval_result with True | False -> eval_result | Atom atom -> f atom + + + let ( let* ) x f = bind_eval_result x f + end + let eval_result_of_bool b = if b then True else False let term_of_eval_result = function @@ -819,37 +827,31 @@ module Atom = struct let term_is_atom t = atom_of_term t |> Option.is_some - let rec eval_term t = - let t = - Term.map_direct_subterms ~f:eval_term t - |> Term.simplify_shallow |> Term.eval_const_shallow |> Term.linearize |> Term.simplify_linear + let eval_const_shallow atom = + let on_const f = + match get_terms atom with + | Const c1, Const c2 -> + f c1 c2 |> eval_result_of_bool + | _ -> + Atom atom in - match atom_of_term t with - | Some atom -> - (* terms that are atoms can be simplified in [eval_atom] *) - eval_atom atom |> term_of_eval_result - | None -> - t + match atom with + | Equal _ -> + on_const Q.equal + | NotEqual _ -> + on_const Q.not_equal + | LessEqual _ -> + on_const Q.leq + | LessThan _ -> + on_const Q.lt - (** This assumes that the terms in the atom have been normalized/evaluated already. *) - and eval_atom (atom : t) = - let t1, t2 = get_terms atom in - match (atom, t1, t2) with - | _, Const c1, Const c2 -> ( - match atom with - | Equal _ -> - eval_result_of_bool (Q.equal c1 c2) - | NotEqual _ -> - eval_result_of_bool (Q.not_equal c1 c2) - | LessEqual _ -> - eval_result_of_bool (Q.leq c1 c2) - | LessThan _ -> - eval_result_of_bool (Q.lt c1 c2) ) - | _, Linear l1, Linear l2 -> + let get_as_linear atom = + match get_terms atom with + | Linear l1, Linear l2 -> let l = LinArith.subtract l1 l2 in let t = Term.simplify_linear (Linear l) in - eval_atom + Some ( match atom with | Equal _ -> Equal (t, Term.zero) @@ -859,18 +861,39 @@ module Atom = struct LessEqual (t, Term.zero) | LessThan _ -> LessThan (t, Term.zero) ) - | (Equal _ | NotEqual _), _, _ - when (Term.is_zero t1 && term_is_atom t2) || (Term.is_zero t2 && term_is_atom t1) -> ( - (* [atom = 0] or [atom ≠ 0] can be interpreted as [!atom] and [atom], respectively *) - match (atom, atom_of_term t1, atom_of_term t2) with - | _, Some _, Some _ | _, None, None | LessEqual _, _, _ | LessThan _, _, _ -> - (* impossible thanks to the match pattern and guard *) - assert false - | NotEqual _, Some atom, None | NotEqual _, None, Some atom -> - (* [atom] is true *) eval_atom atom - | Equal _, Some atom, None | Equal _, None, Some atom -> - (* [atom] is false *) eval_atom (nnot atom) ) - | _ when Term.equal_syntax t1 t2 -> ( + | _ -> + None + + + let get_as_embedded_atom atom = + let of_terms is_equal t1 t2 = + match (atom_of_term t1, t2) with + | Some atom, Term.Const c -> + (* [atom = 0] or [atom ≠ 1] means [atom] is false, [atom ≠ 0] or [atom = 1] means [atom] + is true *) + let positive = (is_equal && Q.is_one c) || ((not is_equal) && Q.is_zero c) in + if positive then Some atom else Some (nnot atom) + | _ -> + None + in + (* [of_terms] is written for only one side, the one where [t1] is the potential atom *) + let of_terms_symmetry is_equal atom = + let t1, t2 = get_terms atom in + let t1, t2 = if term_is_atom t1 then (t1, t2) else (t2, t1) in + of_terms is_equal t1 t2 + in + match atom with + | Equal (Const _, _) | Equal (_, Const _) -> + of_terms_symmetry true atom + | NotEqual (Const _, _) | NotEqual (_, Const _) -> + of_terms_symmetry false atom + | _ -> + None + + + let eval_syntactically_equal_terms atom = + let t1, t2 = get_terms atom in + if Term.equal_syntax t1 t2 then match atom with | Equal _ -> True @@ -879,9 +902,36 @@ module Atom = struct | LessEqual _ -> True | LessThan _ -> - False ) - | _ -> - Atom atom + False + else Atom atom + + + (** This assumes that the terms in the atom have been normalized/evaluated already. *) + let rec eval_atom (atom : t) = + let open EvalResultMonad in + let* atom = eval_const_shallow atom in + match get_as_linear atom with + | Some atom' -> + eval_atom atom' + | None -> ( + match get_as_embedded_atom atom with + | Some atom' -> + eval_atom atom' + | None -> + eval_syntactically_equal_terms atom ) + + + let rec eval_term t = + let t = + Term.map_direct_subterms ~f:eval_term t + |> Term.simplify_shallow |> Term.eval_const_shallow |> Term.linearize |> Term.simplify_linear + in + match atom_of_term t with + | Some atom -> + (* terms that are atoms can be simplified in [eval_atom] *) + eval_atom atom |> term_of_eval_result + | None -> + t let eval (atom : t) = map_terms atom ~f:eval_term |> eval_atom