[sledge] Do not solve polynomials for Mul or Div terms

Summary:
This diff adds enough interpretation of Mul and Div terms to be able
to exclude them from the domain of solution substitutions. While
non-linear arithmetic is still treated very incompletely, this change
increases the propagation power of the equality constraints that are
deduced. Mainly, this appears to be enough to avoid operations that
are semantically equivalence-preserving such as solve_for_vars from
producing equalities that are unprovable from their inputs.

Reviewed By: jvillard

Differential Revision: D20863528

fbshipit-source-id: fca74cba3
master
Josh Berdine 5 years ago committed by Facebook GitHub Bot
parent eb750ba6f9
commit 0f50d3c248

@ -14,8 +14,10 @@ type kind = Interpreted | Simplified | Atomic | Uninterpreted
let classify e = let classify e =
match (e : Term.t) with match (e : Term.t) with
| Add _ | Mul _ -> Interpreted | Add _ | Mul _
| Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) -> |Ap2 ((Div | Memory), _, _)
|Ap3 (Extract, _, _, _)
|ApN (Concat, _) ->
Interpreted Interpreted
| Ap2 ((Eq | Dq), _, _) -> Simplified | Ap2 ((Eq | Dq), _, _) -> Simplified
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted
@ -308,6 +310,8 @@ and solve_ ?f d e s =
( ((Add _ | Mul _ | Integer _ | Rational _) as p), q ( ((Add _ | Mul _ | Integer _ | Rational _) as p), q
| q, ((Add _ | Mul _ | Integer _ | Rational _) as p) ) -> | q, ((Add _ | Mul _ | Integer _ | Rational _) as p) ) ->
solve_poly ?f p q s solve_poly ?f p q s
(* e = n / d ==> e × d = n *)
| Some (rep, Ap2 (Div, num, den)) -> solve_ ?f (Term.mul rep den) num s
| Some (rep, var) -> | Some (rep, var) ->
assert (non_interpreted var) ; assert (non_interpreted var) ;
assert (non_interpreted rep) ; assert (non_interpreted rep) ;

@ -61,6 +61,9 @@ end) : S with type key = Key.t = struct
let pop m = choose m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k)) let pop m = choose m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k))
let pop_min_elt m =
min_elt m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k))
let find_and_remove m k = let find_and_remove m k =
let found = ref None in let found = ref None in
let m = let m =

@ -47,8 +47,17 @@ module type S = sig
-> 'r -> 'r
val choose : 'a t -> (key * 'a) option val choose : 'a t -> (key * 'a) option
(** Find an unspecified element. [O(1)]. *)
val pop : 'a t -> (key * 'a * 'a t) option val pop : 'a t -> (key * 'a * 'a t) option
(** Find and remove an unspecified element. [O(1)]. *)
val pop_min_elt : 'a t -> (key * 'a * 'a t) option
(** Find and remove minimum element. [O(log n)]. *)
val find_and_remove : 'a t -> key -> ('a * 'a t) option val find_and_remove : 'a t -> key -> ('a * 'a t) option
(** Find and remove an element. *)
val pp : key pp -> 'a pp -> 'a t pp val pp : key pp -> 'a pp -> 'a t pp
val pp_diff : val pp_diff :

@ -55,6 +55,7 @@ struct
M.change m x ~f:(function Some j -> if_nz Q.(i + j) | None -> if_nz i) M.change m x ~f:(function Some j -> if_nz Q.(i + j) | None -> if_nz i)
let remove m x = M.remove m x let remove m x = M.remove m x
let find_and_remove = M.find_and_remove
let union m n = let union m n =
M.merge m n ~f:(fun ~key:_ -> function M.merge m n ~f:(fun ~key:_ -> function
@ -77,8 +78,8 @@ struct
let count m x = match M.find m x with Some q -> q | None -> Q.zero let count m x = match M.find m x with Some q -> q | None -> Q.zero
let choose = M.choose let choose = M.choose
let pop = M.pop let pop = M.pop
let min_elt_exn = M.min_elt_exn
let min_elt = M.min_elt let min_elt = M.min_elt
let pop_min_elt = M.pop_min_elt
let to_list m = M.to_alist m let to_list m = M.to_alist m
let iter m ~f = M.iteri ~f:(fun ~key ~data -> f key data) m let iter m ~f = M.iteri ~f:(fun ~key ~data -> f key data) m
let exists m ~f = M.existsi ~f:(fun ~key ~data -> f key data) m let exists m ~f = M.existsi ~f:(fun ~key ~data -> f key data) m

@ -54,13 +54,19 @@ module type S = sig
(** Multiplicity of an element. [O(log n)]. *) (** Multiplicity of an element. [O(log n)]. *)
val choose : t -> (elt * Q.t) option val choose : t -> (elt * Q.t) option
val pop : t -> (elt * Q.t * t) option (** Find an unspecified element. [O(1)]. *)
val min_elt_exn : t -> elt * Q.t val pop : t -> (elt * Q.t * t) option
(** Minimum element. *) (** Find and remove an unspecified element. [O(1)]. *)
val min_elt : t -> (elt * Q.t) option val min_elt : t -> (elt * Q.t) option
(** Minimum element. *) (** Minimum element. [O(log n)]. *)
val pop_min_elt : t -> (elt * Q.t * t) option
(** Find and remove minimum element. [O(log n)]. *)
val find_and_remove : t -> elt -> (Q.t * t) option
(** Find and remove an element. *)
val to_list : t -> (elt * Q.t) list val to_list : t -> (elt * Q.t) list
(** Convert to a list of elements in ascending order. *) (** Convert to a list of elements in ascending order. *)

@ -502,8 +502,16 @@ module Prod = struct
assert (match term with Integer _ | Rational _ -> false | _ -> true) ; assert (match term with Integer _ | Rational _ -> false | _ -> true) ;
Qset.add prod term Q.one Qset.add prod term Q.one
let singleton term = add term empty let of_ term = add term empty
let union = Qset.union let union = Qset.union
let to_term prod =
match Qset.pop prod with
| None -> one
| Some (factor, power, prod')
when Qset.is_empty prod' && Q.equal Q.one power ->
factor
| _ -> Mul prod
end end
let rec simp_add_ es poly = let rec simp_add_ es poly =
@ -547,21 +555,21 @@ and simp_mul2 e f =
| Rational {data= c}, x | x, Rational {data= c} -> | Rational {data= c}, x | x, Rational {data= c} ->
Sum.to_term (Sum.of_ ~coeff:c x) Sum.to_term (Sum.of_ ~coeff:c x)
(* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> ∏ⱼ₌₀ⁿ xⱼ *) (* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> ∏ⱼ₌₀ⁿ xⱼ *)
| Mul xs1, Mul xs2 -> Mul (Prod.union xs1 xs2) | Mul xs1, Mul xs2 -> Prod.to_term (Prod.union xs1 xs2)
(* (∏ᵢ xᵢ) × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ cᵤ × ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *) (* (∏ᵢ xᵢ) × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ cᵤ × ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *)
| (Mul prod as m), Add sum | Add sum, (Mul prod as m) -> | (Mul prod as m), Add sum | Add sum, (Mul prod as m) ->
Sum.to_term Sum.to_term
(Sum.map sum ~f:(function (Sum.map sum ~f:(function
| Mul args -> Mul (Prod.union prod args) | Mul args -> Prod.to_term (Prod.union prod args)
| (Integer _ | Rational _) as c -> simp_mul2 c m | (Integer _ | Rational _) as c -> simp_mul2 c m
| mono -> Mul (Prod.add mono prod) )) | mono -> Prod.to_term (Prod.add mono prod) ))
(* x₀ × (∏ᵢ₌₁ⁿ xᵢ) ==> ∏ᵢ₌₀ⁿ xᵢ *) (* x₀ × (∏ᵢ₌₁ⁿ xᵢ) ==> ∏ᵢ₌₀ⁿ xᵢ *)
| Mul xs1, x | x, Mul xs1 -> Mul (Prod.add x xs1) | Mul xs1, x | x, Mul xs1 -> Prod.to_term (Prod.add x xs1)
(* e × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ e × cᵤ × ∏ⱼ yᵤⱼ *) (* e × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ e × cᵤ × ∏ⱼ yᵤⱼ *)
| Add args, e | e, Add args -> | Add args, e | e, Add args ->
simp_add_ (Sum.map ~f:(fun m -> simp_mul2 e m) args) zero simp_add_ (Sum.map ~f:(fun m -> simp_mul2 e m) args) zero
(* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *) (* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *)
| _ -> Mul (Prod.add e (Prod.singleton f)) | _ -> Prod.to_term (Prod.add e (Prod.of_ f))
let rec simp_div x y = let rec simp_div x y =
match (x, y) with match (x, y) with
@ -1222,6 +1230,11 @@ let fold_terms e ~init ~f =
let iter_vars e ~f = let iter_vars e ~f =
iter_terms e ~f:(function Var _ as v -> f (v :> Var.t) | _ -> ()) iter_terms e ~f:(function Var _ as v -> f (v :> Var.t) | _ -> ())
let exists_vars e ~f =
with_return (fun {return} ->
iter_vars e ~f:(fun v -> if f v then return true) ;
false )
let fold_vars e ~init ~f = let fold_vars e ~init ~f =
fold_terms e ~init ~f:(fun s -> function fold_terms e ~init ~f:(fun s -> function
| Var _ as v -> f s (v :> Var.t) | _ -> s ) | Var _ as v -> f s (v :> Var.t) | _ -> s )
@ -1248,42 +1261,69 @@ let height e =
(** Solve *) (** Solve *)
let find_for ?for_ args = let exists_fv_in vs qset =
let exists_var args ~f = Qset.exists qset ~f:(fun e _ -> exists_vars e ~f:(Var.Set.mem vs))
with_return (fun {return} ->
Qset.iter args ~f:(fun arg _ -> let exists_fv_in4 vs w x y z =
iter_vars arg ~f:(fun v -> if f v then return true) ) ; exists_fv_in vs w || exists_fv_in vs x || exists_fv_in vs y
false ) || exists_fv_in vs z
in
let remove_if_non_occuring rejected args c q = (* solve [0 = rejected_sum + (coeff × prod) + sum] *)
let args = Qset.remove args c in let solve_for_factor rejected_sum coeff prod sum =
let fv_c = fv c in let rec find_factor rejected_prod prod =
if exists_var ~f:(Var.Set.mem fv_c) args then None let* factor, power, prod = Qset.pop_min_elt prod in
else Some (c, q, Qset.union rejected args) if
(not (Q.equal Q.one power))
|| exists_fv_in4 (fv factor) rejected_sum rejected_prod prod sum
then find_factor (Qset.add rejected_prod factor power) prod
else Some (factor, Qset.union rejected_prod prod)
in in
let rec find_for_ rejected args = let+ factor, prod = find_factor Qset.empty prod in
let* c, q = Qset.min_elt args in (* solve [0 = rejected_sum + (coeff × factor × prod) + sum] yielding
remove_if_non_occuring rejected args c q [factor = (rejected_sum + sum) / (-coeff × prod)] *)
( factor
, div
(Sum.to_term (Qset.union rejected_sum sum))
(mul (rational (Q.neg coeff)) (Prod.to_term prod)) )
(* solve [0 = rejected_sum + (coeff × mono) + sum] *)
let solve_for_mono rejected_sum coeff mono sum =
match mono with
| Mul prod -> solve_for_factor rejected_sum coeff prod sum
| _ ->
if exists_fv_in (fv mono) sum then None
else
Some
( mono
, Sum.to_term
(Sum.mul_const
(Q.inv (Q.neg coeff))
(Qset.union rejected_sum sum)) )
(* solve [0 = rejected + sum] *)
let rec solve_sum rejected_sum sum =
let* mono, coeff, sum = Qset.pop_min_elt sum in
solve_for_mono rejected_sum coeff mono sum
|> Option.or_else ~f:(fun () -> |> Option.or_else ~f:(fun () ->
find_for_ (Qset.add rejected c q) (Qset.remove args c) ) solve_sum (Qset.add rejected_sum mono coeff) sum )
in
match for_ with let rec solve_div = function
| Some c -> (* [n / d = t] ==> [n = d × t] *)
let q = Qset.count args c in | Some (Ap2 (Div, num, den), trm) -> solve_div (Some (num, mul den trm))
if Q.equal Q.zero q then None | o -> o
else remove_if_non_occuring Qset.empty args c q
| None -> find_for_ Qset.empty args
(* solve [0 = e] *)
let solve_zero_eq ?for_ e = let solve_zero_eq ?for_ e =
[%Trace.call fun {pf} -> pf "%a%a" pp e (Option.pp " for %a" pp) for_] [%Trace.call fun {pf} -> pf "0 = %a%a" pp e (Option.pp " for %a" pp) for_]
; ;
( match e with ( match e with
| Add args -> | Add sum ->
let+ c, q, args = find_for ?for_ args in ( match for_ with
let n = Sum.to_term (Qset.remove args c) in | None -> solve_sum Qset.empty sum
let d = rational (Q.neg q) in | Some mono ->
let r = div n d in let* coeff, sum = Qset.find_and_remove sum mono in
(c, r) solve_for_mono Qset.empty coeff mono sum )
|> solve_div
| _ -> None ) | _ -> None )
|> |>
[%Trace.retn fun {pf} s -> [%Trace.retn fun {pf} s ->
@ -1292,5 +1332,7 @@ let solve_zero_eq ?for_ e =
Format.fprintf fs "%a ↦ %a" pp c pp r )) Format.fprintf fs "%a ↦ %a" pp c pp r ))
s ; s ;
match (for_, s) with match (for_, s) with
| Some (Mul prod), Some (var, _) ->
assert (not (Q.equal Q.zero (Qset.count prod var)))
| Some f, Some (c, _) -> assert (equal f c) | Some f, Some (c, _) -> assert (equal f c)
| _ -> ()] | _ -> ()]

Loading…
Cancel
Save