[pulse] streamline atom normalization

Summary:
This does a bunch of things at once (sorry):

- Refactor atom/term normalisation so that terms that are really just
  atoms become atoms.

- Use this to not bother adding special cases in the functions exported
  in the .mli: `and_less_than`, `and_equal_binop`, `prune_binop`, etc.
  all had special cases to avoid introducing terms that could be atoms.
  That's not great because the same smarts wasn't applied to terms that
  would only become atom-like after some normalisation, and led to weird
  and duplicated code. Now it's much cleaner: just add the most
  straighforward fact and normalise!

- Fix a bug: adding a new equality `x = linear` should *not* be done
  using `Normalizer.merge_var_linarith` as this is an internal function
  that assumes that `x` is the right representative in `x - linear`.
  Instead, for abitrary equations of that form, `solve_eq` should be used.

- When `normalize_linear_eqs` discovers new linear equalities, normalize
  again. Add fuel there too to avoid spending too much time doing that.
  It could be that we don't need/want fuel there but then we'd need to
  think very hard about why there's no infinite recursion possible and
  that seems harder.

Reviewed By: skcho

Differential Revision: D23241282

fbshipit-source-id: e5b8c4759
master
Jules Villard 5 years ago committed by Facebook GitHub Bot
parent 7df30b0c4e
commit ecdb153579

@ -65,10 +65,6 @@ module LinArith : sig
val of_var : Var.t -> t
val of_intlit : IntLit.t -> t
val of_operand : operand -> t
val get_as_const : t -> Q.t option
(** [get_as_const l] is [Some c] if [l=c], else [None] *)
@ -158,10 +154,6 @@ end = struct
let of_q q = (Var.Map.empty, q)
let of_intlit i = IntLit.to_big_int i |> Q.of_bigint |> of_q
let of_operand = function AbstractValueOperand v -> of_var v | LiteralOperand i -> of_intlit i
let get_as_const (vs, c) = if Var.Map.is_empty vs then Some c else None
let get_as_var (vs, c) =
@ -725,8 +717,6 @@ module Atom = struct
| NotEqual of Term.t * Term.t
[@@deriving compare]
type atom = t
let pp_with_pp_var pp_var fmt atom =
(* add parens around terms that look like atoms to disambiguate *)
let needs_paren (t : Term.t) =
@ -770,6 +760,23 @@ module Atom = struct
(acc, t')
let equal t1 t2 = Equal (t1, t2)
let less_equal t1 t2 = LessEqual (t1, t2)
let less_than t1 t2 = LessThan (t1, t2)
let nnot = function
| Equal (t1, t2) ->
NotEqual (t1, t2)
| NotEqual (t1, t2) ->
Equal (t1, t2)
| LessEqual (t1, t2) ->
LessThan (t2, t1)
| LessThan (t1, t2) ->
LessEqual (t2, t1)
let map_terms atom ~f = fold_map_terms atom ~init:() ~f:(fun () t -> ((), f t)) |> snd
let to_term : t -> Term.t = function
@ -796,30 +803,40 @@ module Atom = struct
to_term atom
let rec eval_term t =
let t =
Term.map_direct_subterms ~f:eval_term t
|> Term.simplify_shallow |> Term.eval_const_shallow |> Term.linearize |> Term.simplify_linear
in
match (t : Term.t) with
let atom_of_term : Term.t -> t option = function
(* terms that are atoms can be simplified in [eval_atom] *)
| LessEqual (t1, t2) ->
eval_atom (LessEqual (t1, t2) : atom) |> term_of_eval_result
Some (LessEqual (t1, t2))
| LessThan (t1, t2) ->
eval_atom (LessThan (t1, t2) : atom) |> term_of_eval_result
Some (LessThan (t1, t2))
| Equal (t1, t2) ->
eval_atom (Equal (t1, t2) : atom) |> term_of_eval_result
Some (Equal (t1, t2))
| NotEqual (t1, t2) ->
eval_atom (NotEqual (t1, t2) : atom) |> term_of_eval_result
Some (NotEqual (t1, t2))
| _ ->
None
let term_is_atom t = atom_of_term t |> Option.is_some
let rec eval_term t =
let t =
Term.map_direct_subterms ~f:eval_term t
|> Term.simplify_shallow |> Term.eval_const_shallow |> Term.linearize |> Term.simplify_linear
in
match atom_of_term t with
| Some atom ->
(* terms that are atoms can be simplified in [eval_atom] *)
eval_atom atom |> term_of_eval_result
| None ->
t
(** This assumes that the terms in the atom have been normalized/evaluated already. *)
and eval_atom (atom : t) =
let t1, t2 = get_terms atom in
match (t1, t2) with
| Const c1, Const c2 -> (
match (atom, t1, t2) with
| _, Const c1, Const c2 -> (
match atom with
| Equal _ ->
eval_result_of_bool (Q.equal c1 c2)
@ -829,7 +846,7 @@ module Atom = struct
eval_result_of_bool (Q.leq c1 c2)
| LessThan _ ->
eval_result_of_bool (Q.lt c1 c2) )
| Linear l1, Linear l2 ->
| _, Linear l1, Linear l2 ->
let l = LinArith.subtract l1 l2 in
let t = Term.simplify_linear (Linear l) in
eval_atom
@ -842,18 +859,29 @@ module Atom = struct
LessEqual (t, Term.zero)
| LessThan _ ->
LessThan (t, Term.zero) )
| (Equal _ | NotEqual _), _, _
when (Term.is_zero t1 && term_is_atom t2) || (Term.is_zero t2 && term_is_atom t1) -> (
(* [atom = 0] or [atom ≠ 0] can be interpreted as [!atom] and [atom], respectively *)
match (atom, atom_of_term t1, atom_of_term t2) with
| _, Some _, Some _ | _, None, None | LessEqual _, _, _ | LessThan _, _, _ ->
(* impossible thanks to the match pattern and guard *)
assert false
| NotEqual _, Some atom, None | NotEqual _, None, Some atom ->
(* [atom] is true *) eval_atom atom
| Equal _, Some atom, None | Equal _, None, Some atom ->
(* [atom] is false *) eval_atom (nnot atom) )
| _ when Term.equal_syntax t1 t2 -> (
match atom with
| Equal _ ->
True
| NotEqual _ ->
False
| LessEqual _ ->
True
| LessThan _ ->
False )
| _ ->
if Term.equal_syntax t1 t2 then
match atom with
| Equal _ ->
True
| NotEqual _ ->
False
| LessEqual _ ->
True
| LessThan _ ->
False
else Atom atom
Atom atom
let eval (atom : t) = map_terms atom ~f:eval_term |> eval_atom
@ -937,6 +965,8 @@ module Normalizer : sig
val and_var_var : Var.t -> Var.t -> t -> t normalized
val and_atom : Atom.t -> t -> t normalized
val normalize : t -> t normalized
end = struct
(* Use the monadic notations when normalizing formulas. *)
@ -1057,16 +1087,31 @@ end = struct
(** an arbitrary value *)
let fuel = 5
let and_var_linarith v l phi = merge_var_linarith ~fuel v l phi
let and_var_linarith v l phi = solve_eq ~fuel l (LinArith.of_var v) phi
let and_var_var v1 v2 phi = merge_vars ~fuel v1 v2 phi
let normalize_linear_eqs phi0 =
(* reconstruct the relation from scratch *)
Var.Map.fold
(fun v l phi -> bind_normalized (and_var_linarith v (apply phi0 l)) phi)
phi0.linear_eqs
(Sat {phi0 with linear_eqs= Var.Map.empty})
let rec normalize_linear_eqs ~fuel phi0 =
let* changed, phi' =
(* reconstruct the relation from scratch *)
Var.Map.fold
(fun v l acc ->
let* changed, phi = acc in
let l' = apply phi0 l in
let+ phi' = and_var_linarith v l' phi in
(changed || not (phys_equal l l'), phi') )
phi0.linear_eqs
(Sat (false, {phi0 with linear_eqs= Var.Map.empty}))
in
if changed then
if fuel > 0 then (
L.d_printfln "going around one more time normalizing the linear equalities" ;
(* do another pass if we can affort it *)
normalize_linear_eqs ~fuel:(fuel - 1) phi' )
else (
L.d_printfln "ran out of fuel normalizing the linear equalities" ;
Sat phi' )
else Sat phi0
let normalize_atom phi (atom : Atom.t) =
@ -1083,142 +1128,55 @@ end = struct
Atom.eval atom' |> sat_of_eval_result
let normalize_atoms phi =
let+ phi, atoms =
IContainer.fold_of_pervasives_set_fold Atom.Set.fold phi.atoms
~init:(Sat (phi, Atom.Set.empty))
~f:(fun acc atom ->
let* phi, facts = acc in
normalize_atom phi atom
>>= function
| None ->
acc
| Some (Atom.Equal (Linear l, Const c)) | Some (Atom.Equal (Const c, Linear l)) ->
(* NOTE: {!normalize_atom} calls {!Atom.eval}, which normalizes linear equalities so
they end up only on one side, hence only this match case is needed to detect linear
equalities *)
let+ phi = solve_eq ~fuel:5 l (LinArith.of_q c) phi in
(phi, facts)
| Some atom' ->
Sat (phi, Atom.Set.add atom' facts) )
in
{phi with atoms}
let and_atom atom phi =
normalize_atom phi atom
>>= function
| None ->
Sat phi
| Some (Atom.Equal (Linear l, Const c)) | Some (Atom.Equal (Const c, Linear l)) ->
(* NOTE: {!normalize_atom} calls {!Atom.eval}, which normalizes linear equalities so
they end up only on one side, hence only this match case is needed to detect linear
equalities *)
solve_eq ~fuel l (LinArith.of_q c) phi
| Some atom' ->
Sat {phi with atoms= Atom.Set.add atom' phi.atoms}
let normalize phi = normalize_linear_eqs phi >>= normalize_atoms
end
let normalize_atoms phi =
let atoms0 = phi.atoms in
let init = Sat {phi with atoms= Atom.Set.empty} in
IContainer.fold_of_pervasives_set_fold Atom.Set.fold atoms0 ~init ~f:(fun acc atom ->
let* phi = acc in
and_atom atom phi )
let and_equal op1 op2 phi =
match (op1, op2) with
| LiteralOperand i1, LiteralOperand i2 ->
if IntLit.eq i1 i2 then Sat phi else Unsat
| AbstractValueOperand v, LiteralOperand i | LiteralOperand i, AbstractValueOperand v ->
Normalizer.and_var_linarith v (LinArith.of_intlit i) phi
| AbstractValueOperand v1, AbstractValueOperand v2 ->
Normalizer.and_var_var v1 v2 phi
let normalize phi = normalize_linear_eqs ~fuel phi >>= normalize_atoms
end
let and_atom atom phi = {phi with atoms= Atom.Set.add atom phi.atoms}
let and_mk_atom mk_atom op1 op2 phi =
Normalizer.and_atom (mk_atom (Term.of_operand op1) (Term.of_operand op2)) phi
let and_less_equal op1 op2 phi =
match (op1, op2) with
| LiteralOperand i1, LiteralOperand i2 ->
if IntLit.leq i1 i2 then Sat phi else Unsat
| _ ->
Sat (and_atom (LessEqual (Term.of_operand op1, Term.of_operand op2)) phi)
let and_equal = and_mk_atom Atom.equal
let and_less_than op1 op2 phi =
match (op1, op2) with
| LiteralOperand i1, LiteralOperand i2 ->
if IntLit.lt i1 i2 then Sat phi else Unsat
| _ ->
Sat (and_atom (LessThan (Term.of_operand op1, Term.of_operand op2)) phi)
let and_less_equal = and_mk_atom Atom.less_equal
let and_less_than = and_mk_atom Atom.less_than
let and_equal_unop v (op : Unop.t) x phi =
match op with
| Neg ->
Normalizer.and_var_linarith v LinArith.(minus (of_operand x)) phi
| BNot | LNot ->
Sat (and_atom (Equal (Term.Var v, Term.of_unop op (Term.of_operand x))) phi)
Normalizer.and_atom (Equal (Var v, Term.of_unop op (Term.of_operand x))) phi
let and_equal_binop v (bop : Binop.t) x y phi =
let and_linear_eq l = Normalizer.and_var_linarith v l phi in
match bop with
| PlusA _ | PlusPI ->
LinArith.(add (of_operand x) (of_operand y)) |> and_linear_eq
| MinusA _ | MinusPI | MinusPP ->
LinArith.(subtract (of_operand x) (of_operand y)) |> and_linear_eq
(* TODO: some of the below could become linear arithmetic after simplifications (e.g. up to constants) *)
| Mult _
| Div
| Mod
| Shiftlt
| Shiftrt
| BAnd
| BXor
| BOr
(* TODO: (most) logical operators should be translated into the formula structure *)
| Lt
| Gt
| Le
| Ge
| Eq
| Ne
| LAnd
| LOr ->
Sat
(and_atom
(Equal (Term.Var v, Term.of_binop bop (Term.of_operand x) (Term.of_operand y)))
phi)
Normalizer.and_atom (Equal (Var v, Term.of_binop bop (Term.of_operand x) (Term.of_operand y))) phi
let prune_binop ~negated (bop : Binop.t) x y phi =
let atom op x y =
let tx = Term.of_operand x in
let ty = Term.of_operand y in
let atom : Atom.t =
match op with
| `LessThan ->
LessThan (tx, ty)
| `LessEqual ->
LessEqual (tx, ty)
| `NotEqual ->
NotEqual (tx, ty)
in
Sat (and_atom atom phi)
in
match (bop, negated) with
| Eq, false | Ne, true ->
and_equal x y phi
| Ne, false | Eq, true ->
atom `NotEqual x y
| Lt, false | Ge, true ->
atom `LessThan x y
| Le, false | Gt, true ->
atom `LessEqual x y
| Ge, false | Lt, true ->
atom `LessEqual y x
| Gt, false | Le, true ->
atom `LessThan y x
| LAnd, _
| LOr, _
| Mult _, _
| Div, _
| Mod, _
| Shiftlt, _
| Shiftrt, _
| BAnd, _
| BXor, _
| BOr, _
| PlusA _, _
| PlusPI, _
| MinusA _, _
| MinusPI, _
| MinusPP, _ ->
Sat phi
let tx = Term.of_operand x in
let ty = Term.of_operand y in
let t = Term.of_binop bop tx ty in
let atom = if negated then Atom.Equal (t, Term.zero) else Atom.NotEqual (t, Term.zero) in
Normalizer.and_atom atom phi
let normalize phi = Normalizer.normalize phi
@ -1260,7 +1218,7 @@ let and_fold_map_variables phi0 ~up_to_f:phi_foreign ~init ~f =
let acc_f, v' = f acc_f v in
(acc_f, VarSubst v') )
in
let phi = and_atom atom phi in
let phi = Normalizer.and_atom atom phi |> sat_value_exn in
(acc_f, phi) )
in
try Sat (and_var_eqs (init, phi0) |> and_linear_eqs |> and_atoms) with Contradiction -> Unsat

@ -176,9 +176,9 @@ let%test_module "normalization" =
{|
true (no var=var)
&&
v7 = x + v6 v8 = x + v6 +1 v10 = 0
x = -v6 + v8 -1 v7 = v8 -1 v10 = 0
&&
{0 = [v9]÷[w]}{[v6] = [v]×[y]}{[v9] = [z]×[x + v6 +1]} |}]
{0 = [v9]÷[w]}{[v6] = [v]×[y]}{[v9] = [z]×[v8]} |}]
(* check that this becomes all linear equalities *)
let%expect_test _ =
@ -187,7 +187,7 @@ let%test_module "normalization" =
{|
true (no var=var)
&&
x = -v6 + 1/12·v9 -1 y = 1/3·v6 v7 = x + v6 v8 = x + v6 +1 v9 = 0 v10 = 0
x = -v6 -1 y = 1/3·v6 v7 = -1 v8 = 0 v9 = 0 v10 = 0
&&
true (no atoms)|}]
@ -198,8 +198,8 @@ let%test_module "normalization" =
{|
true (no var=var)
&&
x = -v6 + 1/12·v9 -1 y = 1/3·v6 z = 12 w = 7 v = 3 v7 = x + v6
v8 = x + v6 +1 v9 = 0 v10 = 0
x = -v6 + v8 -1 y = 1/3·v6 z = 12 w = 7 v = 3 v7 = v8 -1
v8 = 1/12·v9 v9 = 0 v10 = 0
&&
true (no atoms)|}]
end )
@ -221,17 +221,12 @@ let%test_module "variable elimination" =
(* should keep most of this or realize that [w = z] hence this boils down to [z+1 = 0] *)
let%expect_test _ =
simplify ~keep:[y_var; z_var] (x = y + z && w = x - y && v = w + i 1 && v = i 0) ;
[%expect {|x=v6 z=w=v7 && x = y -1 z = -1 && true (no atoms)|}]
[%expect {|x=v6 && x = y -1 z = -1 && true (no atoms)|}]
let%expect_test _ =
simplify ~keep:[x_var; y_var] (x = y + z && w + x + y = i 0 && v = w + i 1) ;
[%expect
{|
x=v6 v=v9
&&
x = 1/2·z + -1/2·w y = -1/2·z + -1/2·w v = w +1 v7 = 1/2·z + 1/2·w
&&
true (no atoms)|}]
{|x=v6 v=v9 && x = -v + v7 +1 y = -v7 z = -v + 2·v7 +1 w = v -1 && true (no atoms)|}]
let%expect_test _ =
simplify ~keep:[x_var; y_var] (x = y + i 4 && x = w && y = z) ;
@ -242,9 +237,9 @@ let%test_module "non-linear simplifications" =
( module struct
let%expect_test "zero propagation" =
simplify ~keep:[w_var] (((i 0 / (x * z)) & v) * v mod y = w) ;
[%expect {|w=v10 && w = 0 && true (no atoms)|}]
[%expect {|true (no var=var) && w = 0 && true (no atoms)|}]
let%expect_test "constant propagation: bitshift" =
simplify ~keep:[x_var] (of_binop Shiftlt (of_binop Shiftrt (i 0b111) (i 2)) (i 2) = x) ;
[%expect {|x=v7 && x = 4 && true (no atoms)|}]
[%expect {|true (no var=var) && x = 4 && true (no atoms)|}]
end )

Loading…
Cancel
Save