[pulse] normalize any linear atom

Summary:
Linear arithmetic is able to simplify more atoms, eg `x+y <= x+y`
becomes `True` by normalising to "lhs - rhs <= 0". This does the first
step of normalisation, but to get True in this example we also need to
substitute inside atoms according to the linear equalities, which is the
next diff (for now we only substitute variables inside atoms for other
variables or for constants).

Reviewed By: skcho

Differential Revision: D23241457

fbshipit-source-id: 0da0b545c
master
Jules Villard 4 years ago committed by Facebook GitHub Bot
parent 69995cebb6
commit 36af901d79

@ -660,46 +660,55 @@ module Term = struct
(** more or less syntactic attempt at detecting when an arbitrary term is a linear formula; call (** more or less syntactic attempt at detecting when an arbitrary term is a linear formula; call
{!Atom.eval_term} first for best results *) {!Atom.eval_term} first for best results *)
let rec to_lin_arith t = let linearize t =
(* NOTE: don't duplicate simplifications between here and {!Atom.eval_term} *) let rec aux_linearize t =
let open IOption.Let_syntax in let open IOption.Let_syntax in
match t with match t with
| Var v -> | Var v ->
Some (LinArith.of_var v) Some (LinArith.of_var v)
| Const c -> | Const c ->
Some (LinArith.of_q c) Some (LinArith.of_q c)
| Linear l -> | Linear l ->
Some l Some l
| Minus t -> | Minus t ->
let+ l = to_lin_arith t in let+ l = aux_linearize t in
LinArith.minus l LinArith.minus l
| Add (t1, t2) -> | Add (t1, t2) ->
let* l1 = to_lin_arith t1 in let* l1 = aux_linearize t1 in
let+ l2 = to_lin_arith t2 in let+ l2 = aux_linearize t2 in
LinArith.add l1 l2 LinArith.add l1 l2
| Mult (Const c, t) | Mult (t, Const c) -> | Mult (Const c, t) | Mult (t, Const c) ->
let+ l = to_lin_arith t in let+ l = aux_linearize t in
LinArith.mult c l LinArith.mult c l
| Div (t, Const c) when Q.is_not_zero c -> | Div (t, Const c) when Q.is_not_zero c ->
let+ l = to_lin_arith t in let+ l = aux_linearize t in
LinArith.mult (Q.inv c) l LinArith.mult (Q.inv c) l
| Mult _ | Mult _
| Div _ | Div _
| Mod _ | Mod _
| BitNot _ | BitNot _
| BitAnd _ | BitAnd _
| BitOr _ | BitOr _
| BitShiftLeft _ | BitShiftLeft _
| BitShiftRight _ | BitShiftRight _
| BitXor _ | BitXor _
| Not _ | Not _
| And _ | And _
| Or _ | Or _
| LessThan _ | LessThan _
| LessEqual _ | LessEqual _
| Equal _ | Equal _
| NotEqual _ -> | NotEqual _ ->
None None
in
match aux_linearize t with None -> t | Some l -> Linear l
let simplify_linear = function
| Linear l -> (
match LinArith.get_as_const l with Some c -> Const c | None -> Linear l )
| t ->
t
end end
(** Basically boolean terms, used to build the part of a formula that is not equalities between (** Basically boolean terms, used to build the part of a formula that is not equalities between
@ -731,8 +740,6 @@ module Atom = struct
F.fprintf fmt "%a ≠ %a" pp_term t1 pp_term t2 F.fprintf fmt "%a ≠ %a" pp_term t1 pp_term t2
let pp = pp_with_pp_var Var.pp
let get_terms atom = let get_terms atom =
let (LessEqual (t1, t2) | LessThan (t1, t2) | Equal (t1, t2) | NotEqual (t1, t2)) = atom in let (LessEqual (t1, t2) | LessThan (t1, t2) | Equal (t1, t2) | NotEqual (t1, t2)) = atom in
(t1, t2) (t1, t2)
@ -787,7 +794,8 @@ module Atom = struct
let rec eval_term t = let rec eval_term t =
let t = let t =
Term.map_direct_subterms ~f:eval_term t |> Term.simplify_shallow |> Term.eval_const_shallow Term.map_direct_subterms ~f:eval_term t
|> Term.simplify_shallow |> Term.eval_const_shallow |> Term.linearize |> Term.simplify_linear
in in
match (t : Term.t) with match (t : Term.t) with
(* terms that are atoms can be simplified in [eval_atom] *) (* terms that are atoms can be simplified in [eval_atom] *)
@ -803,11 +811,7 @@ module Atom = struct
t t
(** This assumes that the terms in the atom have been normalized/evaluated already. (** This assumes that the terms in the atom have been normalized/evaluated already. *)
TODO: probably a better way to implement this would be to rely entirely on
[eval_term (term_of_atom (atom))], possibly implementing it as something about observing the
sign/zero-ness of [t1 - t2]. *)
and eval_atom (atom : t) = and eval_atom (atom : t) =
let t1, t2 = get_terms atom in let t1, t2 = get_terms atom in
match (t1, t2) with match (t1, t2) with
@ -821,6 +825,19 @@ module Atom = struct
eval_result_of_bool (Q.leq c1 c2) eval_result_of_bool (Q.leq c1 c2)
| LessThan _ -> | LessThan _ ->
eval_result_of_bool (Q.lt c1 c2) ) eval_result_of_bool (Q.lt c1 c2) )
| Linear l1, Linear l2 ->
let l = LinArith.subtract l1 l2 in
let t = Term.simplify_linear (Linear l) in
eval_atom
( match atom with
| Equal _ ->
Equal (t, Term.zero)
| NotEqual _ ->
NotEqual (t, Term.zero)
| LessEqual _ ->
LessEqual (t, Term.zero)
| LessThan _ ->
LessThan (t, Term.zero) )
| _ -> | _ ->
if Term.equal_syntax t1 t2 then if Term.equal_syntax t1 t2 then
match atom with match atom with
@ -865,6 +882,10 @@ module SatUnsatMonad = struct
let ( let* ) phi f = bind_normalized f phi let ( let* ) phi f = bind_normalized f phi
end end
let sat_of_eval_result (eval_result : Atom.eval_result) =
match eval_result with True -> Sat None | False -> Unsat | Atom atom -> Sat (Some atom)
module VarUF = module VarUF =
UnionFind.Make UnionFind.Make
(struct (struct
@ -1055,14 +1076,7 @@ end = struct
match q_opt with None -> VarSubst v_canon | Some q -> QSubst q ) match q_opt with None -> VarSubst v_canon | Some q -> QSubst q )
in in
let atom' = Atom.map_terms atom ~f:(fun t -> normalize_term phi t) in let atom' = Atom.map_terms atom ~f:(fun t -> normalize_term phi t) in
match Atom.eval atom' with Atom.eval atom' |> sat_of_eval_result
| True ->
Sat None
| False ->
L.d_printfln "Contradiction detected! %a ~> %a is False" Atom.pp atom Atom.pp atom' ;
Unsat
| Atom atom ->
Sat (Some atom)
let normalize_atoms phi = let normalize_atoms phi =
@ -1075,15 +1089,12 @@ end = struct
>>= function >>= function
| None -> | None ->
acc acc
| Some (Atom.Equal (t1, t2) as atom') -> ( | Some (Atom.Equal (Linear l, Const c)) | Some (Atom.Equal (Const c, Linear l)) ->
match Option.both (Term.to_lin_arith t1) (Term.to_lin_arith t2) with (* NOTE: {!normalize_atom} calls {!Atom.eval}, which normalizes linear equalities so
| None -> they end up only on one side, hence only this match case is needed to detect linear
Sat (phi, Atom.Set.add atom' facts) equalities *)
| Some (l1, l2) -> let+ phi = solve_eq ~fuel:5 l (LinArith.of_q c) phi in
(* an atom has been found to have become a linear equality, move it to the linear (phi, facts)
equalities *)
let+ phi = solve_eq ~fuel:5 l1 l2 phi in
(phi, facts) )
| Some atom' -> | Some atom' ->
Sat (phi, Atom.Set.add atom' facts) ) Sat (phi, Atom.Set.add atom' facts) )
in in

@ -130,7 +130,7 @@ let%test_module "normalization" =
( module struct ( module struct
let%expect_test _ = let%expect_test _ =
normalize (x < y) ; normalize (x < y) ;
[%expect {|true (no var=var) && true (no linear) && {x < y}|}] [%expect {|true (no var=var) && true (no linear) && {[x + -y] < 0}|}]
let%expect_test _ = let%expect_test _ =
normalize (x + i 1 - i 1 < x) ; normalize (x + i 1 - i 1 < x) ;
@ -166,6 +166,10 @@ let%test_module "normalization" =
normalize (x = i 0 && x < i 0) ; normalize (x = i 0 && x < i 0) ;
[%expect {|unsat|}] [%expect {|unsat|}]
let%expect_test _ =
normalize (x + y < x + y) ;
[%expect {|true (no var=var) && v6 = x + y v7 = x + y && {[v6 + -v7] < 0}|}]
let%expect_test "nonlinear arithmetic" = let%expect_test "nonlinear arithmetic" =
normalize (z * (x + (v * y) + i 1) / w = i 0) ; normalize (z * (x + (v * y) + i 1) / w = i 0) ;
[%expect [%expect
@ -174,7 +178,7 @@ let%test_module "normalization" =
&& &&
v7 = x + v6 v8 = x + v6 +1 v10 = 0 v7 = x + v6 v8 = x + v6 +1 v10 = 0
&& &&
{0 = v9÷w}{v6 = v×y}{v9 = z×v8} |}] {0 = [v9]÷[w]}{[v6] = [v]×[y]}{[v9] = [z]×[v8]} |}]
(* check that this becomes all linear equalities *) (* check that this becomes all linear equalities *)
let%expect_test _ = let%expect_test _ =

Loading…
Cancel
Save