[pulse] refactor Atom.eval_atom

Summary:
This function had become a bit hard to read and the part about embedded
atoms was not very clear and also a bit incomplete (need to handle "= 1"
and "≠ 1" too).

Reviewed By: skcho

Differential Revision: D23242216

fbshipit-source-id: 239fade97
master
Jules Villard 4 years ago committed by Facebook GitHub Bot
parent ecdb153579
commit 8b23fee8f8

@ -792,6 +792,14 @@ module Atom = struct
type eval_result = True | False | Atom of t
module EvalResultMonad = struct
let bind_eval_result eval_result f =
match eval_result with True | False -> eval_result | Atom atom -> f atom
let ( let* ) x f = bind_eval_result x f
end
let eval_result_of_bool b = if b then True else False
let term_of_eval_result = function
@ -819,37 +827,31 @@ module Atom = struct
let term_is_atom t = atom_of_term t |> Option.is_some
let rec eval_term t =
let t =
Term.map_direct_subterms ~f:eval_term t
|> Term.simplify_shallow |> Term.eval_const_shallow |> Term.linearize |> Term.simplify_linear
let eval_const_shallow atom =
let on_const f =
match get_terms atom with
| Const c1, Const c2 ->
f c1 c2 |> eval_result_of_bool
| _ ->
Atom atom
in
match atom_of_term t with
| Some atom ->
(* terms that are atoms can be simplified in [eval_atom] *)
eval_atom atom |> term_of_eval_result
| None ->
t
(** This assumes that the terms in the atom have been normalized/evaluated already. *)
and eval_atom (atom : t) =
let t1, t2 = get_terms atom in
match (atom, t1, t2) with
| _, Const c1, Const c2 -> (
match atom with
| Equal _ ->
eval_result_of_bool (Q.equal c1 c2)
on_const Q.equal
| NotEqual _ ->
eval_result_of_bool (Q.not_equal c1 c2)
on_const Q.not_equal
| LessEqual _ ->
eval_result_of_bool (Q.leq c1 c2)
on_const Q.leq
| LessThan _ ->
eval_result_of_bool (Q.lt c1 c2) )
| _, Linear l1, Linear l2 ->
on_const Q.lt
let get_as_linear atom =
match get_terms atom with
| Linear l1, Linear l2 ->
let l = LinArith.subtract l1 l2 in
let t = Term.simplify_linear (Linear l) in
eval_atom
Some
( match atom with
| Equal _ ->
Equal (t, Term.zero)
@ -859,18 +861,39 @@ module Atom = struct
LessEqual (t, Term.zero)
| LessThan _ ->
LessThan (t, Term.zero) )
| (Equal _ | NotEqual _), _, _
when (Term.is_zero t1 && term_is_atom t2) || (Term.is_zero t2 && term_is_atom t1) -> (
(* [atom = 0] or [atom ≠ 0] can be interpreted as [!atom] and [atom], respectively *)
match (atom, atom_of_term t1, atom_of_term t2) with
| _, Some _, Some _ | _, None, None | LessEqual _, _, _ | LessThan _, _, _ ->
(* impossible thanks to the match pattern and guard *)
assert false
| NotEqual _, Some atom, None | NotEqual _, None, Some atom ->
(* [atom] is true *) eval_atom atom
| Equal _, Some atom, None | Equal _, None, Some atom ->
(* [atom] is false *) eval_atom (nnot atom) )
| _ when Term.equal_syntax t1 t2 -> (
| _ ->
None
let get_as_embedded_atom atom =
let of_terms is_equal t1 t2 =
match (atom_of_term t1, t2) with
| Some atom, Term.Const c ->
(* [atom = 0] or [atom ≠ 1] means [atom] is false, [atom ≠ 0] or [atom = 1] means [atom]
is true *)
let positive = (is_equal && Q.is_one c) || ((not is_equal) && Q.is_zero c) in
if positive then Some atom else Some (nnot atom)
| _ ->
None
in
(* [of_terms] is written for only one side, the one where [t1] is the potential atom *)
let of_terms_symmetry is_equal atom =
let t1, t2 = get_terms atom in
let t1, t2 = if term_is_atom t1 then (t1, t2) else (t2, t1) in
of_terms is_equal t1 t2
in
match atom with
| Equal (Const _, _) | Equal (_, Const _) ->
of_terms_symmetry true atom
| NotEqual (Const _, _) | NotEqual (_, Const _) ->
of_terms_symmetry false atom
| _ ->
None
let eval_syntactically_equal_terms atom =
let t1, t2 = get_terms atom in
if Term.equal_syntax t1 t2 then
match atom with
| Equal _ ->
True
@ -879,9 +902,36 @@ module Atom = struct
| LessEqual _ ->
True
| LessThan _ ->
False )
| _ ->
Atom atom
False
else Atom atom
(** This assumes that the terms in the atom have been normalized/evaluated already. *)
let rec eval_atom (atom : t) =
let open EvalResultMonad in
let* atom = eval_const_shallow atom in
match get_as_linear atom with
| Some atom' ->
eval_atom atom'
| None -> (
match get_as_embedded_atom atom with
| Some atom' ->
eval_atom atom'
| None ->
eval_syntactically_equal_terms atom )
let rec eval_term t =
let t =
Term.map_direct_subterms ~f:eval_term t
|> Term.simplify_shallow |> Term.eval_const_shallow |> Term.linearize |> Term.simplify_linear
in
match atom_of_term t with
| Some atom ->
(* terms that are atoms can be simplified in [eval_atom] *)
eval_atom atom |> term_of_eval_result
| None ->
t
let eval (atom : t) = map_terms atom ~f:eval_term |> eval_atom

Loading…
Cancel
Save