[sledge] Do not solve polynomials for Mul or Div terms

Summary:
This diff adds enough interpretation of Mul and Div terms to be able
to exclude them from the domain of solution substitutions. While
non-linear arithmetic is still treated very incompletely, this change
increases the propagation power of the equality constraints that are
deduced. Mainly, this appears to be enough to avoid operations that
are semantically equivalence-preserving such as solve_for_vars from
producing equalities that are unprovable from their inputs.

Reviewed By: jvillard

Differential Revision: D20863528

fbshipit-source-id: fca74cba3
master
Josh Berdine 5 years ago committed by Facebook GitHub Bot
parent eb750ba6f9
commit 0f50d3c248

@ -14,8 +14,10 @@ type kind = Interpreted | Simplified | Atomic | Uninterpreted
let classify e =
match (e : Term.t) with
| Add _ | Mul _ -> Interpreted
| Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) ->
| Add _ | Mul _
|Ap2 ((Div | Memory), _, _)
|Ap3 (Extract, _, _, _)
|ApN (Concat, _) ->
Interpreted
| Ap2 ((Eq | Dq), _, _) -> Simplified
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted
@ -308,6 +310,8 @@ and solve_ ?f d e s =
( ((Add _ | Mul _ | Integer _ | Rational _) as p), q
| q, ((Add _ | Mul _ | Integer _ | Rational _) as p) ) ->
solve_poly ?f p q s
(* e = n / d ==> e × d = n *)
| Some (rep, Ap2 (Div, num, den)) -> solve_ ?f (Term.mul rep den) num s
| Some (rep, var) ->
assert (non_interpreted var) ;
assert (non_interpreted rep) ;

@ -61,6 +61,9 @@ end) : S with type key = Key.t = struct
let pop m = choose m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k))
let pop_min_elt m =
min_elt m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k))
let find_and_remove m k =
let found = ref None in
let m =

@ -47,8 +47,17 @@ module type S = sig
-> 'r
val choose : 'a t -> (key * 'a) option
(** Find an unspecified element. [O(1)]. *)
val pop : 'a t -> (key * 'a * 'a t) option
(** Find and remove an unspecified element. [O(1)]. *)
val pop_min_elt : 'a t -> (key * 'a * 'a t) option
(** Find and remove minimum element. [O(log n)]. *)
val find_and_remove : 'a t -> key -> ('a * 'a t) option
(** Find and remove an element. *)
val pp : key pp -> 'a pp -> 'a t pp
val pp_diff :

@ -55,6 +55,7 @@ struct
M.change m x ~f:(function Some j -> if_nz Q.(i + j) | None -> if_nz i)
let remove m x = M.remove m x
let find_and_remove = M.find_and_remove
let union m n =
M.merge m n ~f:(fun ~key:_ -> function
@ -77,8 +78,8 @@ struct
let count m x = match M.find m x with Some q -> q | None -> Q.zero
let choose = M.choose
let pop = M.pop
let min_elt_exn = M.min_elt_exn
let min_elt = M.min_elt
let pop_min_elt = M.pop_min_elt
let to_list m = M.to_alist m
let iter m ~f = M.iteri ~f:(fun ~key ~data -> f key data) m
let exists m ~f = M.existsi ~f:(fun ~key ~data -> f key data) m

@ -54,13 +54,19 @@ module type S = sig
(** Multiplicity of an element. [O(log n)]. *)
val choose : t -> (elt * Q.t) option
val pop : t -> (elt * Q.t * t) option
(** Find an unspecified element. [O(1)]. *)
val min_elt_exn : t -> elt * Q.t
(** Minimum element. *)
val pop : t -> (elt * Q.t * t) option
(** Find and remove an unspecified element. [O(1)]. *)
val min_elt : t -> (elt * Q.t) option
(** Minimum element. *)
(** Minimum element. [O(log n)]. *)
val pop_min_elt : t -> (elt * Q.t * t) option
(** Find and remove minimum element. [O(log n)]. *)
val find_and_remove : t -> elt -> (Q.t * t) option
(** Find and remove an element. *)
val to_list : t -> (elt * Q.t) list
(** Convert to a list of elements in ascending order. *)

@ -502,8 +502,16 @@ module Prod = struct
assert (match term with Integer _ | Rational _ -> false | _ -> true) ;
Qset.add prod term Q.one
let singleton term = add term empty
let of_ term = add term empty
let union = Qset.union
let to_term prod =
match Qset.pop prod with
| None -> one
| Some (factor, power, prod')
when Qset.is_empty prod' && Q.equal Q.one power ->
factor
| _ -> Mul prod
end
let rec simp_add_ es poly =
@ -547,21 +555,21 @@ and simp_mul2 e f =
| Rational {data= c}, x | x, Rational {data= c} ->
Sum.to_term (Sum.of_ ~coeff:c x)
(* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> ∏ⱼ₌₀ⁿ xⱼ *)
| Mul xs1, Mul xs2 -> Mul (Prod.union xs1 xs2)
| Mul xs1, Mul xs2 -> Prod.to_term (Prod.union xs1 xs2)
(* (∏ᵢ xᵢ) × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ cᵤ × ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *)
| (Mul prod as m), Add sum | Add sum, (Mul prod as m) ->
Sum.to_term
(Sum.map sum ~f:(function
| Mul args -> Mul (Prod.union prod args)
| Mul args -> Prod.to_term (Prod.union prod args)
| (Integer _ | Rational _) as c -> simp_mul2 c m
| mono -> Mul (Prod.add mono prod) ))
| mono -> Prod.to_term (Prod.add mono prod) ))
(* x₀ × (∏ᵢ₌₁ⁿ xᵢ) ==> ∏ᵢ₌₀ⁿ xᵢ *)
| Mul xs1, x | x, Mul xs1 -> Mul (Prod.add x xs1)
| Mul xs1, x | x, Mul xs1 -> Prod.to_term (Prod.add x xs1)
(* e × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ e × cᵤ × ∏ⱼ yᵤⱼ *)
| Add args, e | e, Add args ->
simp_add_ (Sum.map ~f:(fun m -> simp_mul2 e m) args) zero
(* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *)
| _ -> Mul (Prod.add e (Prod.singleton f))
| _ -> Prod.to_term (Prod.add e (Prod.of_ f))
let rec simp_div x y =
match (x, y) with
@ -1222,6 +1230,11 @@ let fold_terms e ~init ~f =
let iter_vars e ~f =
iter_terms e ~f:(function Var _ as v -> f (v :> Var.t) | _ -> ())
let exists_vars e ~f =
with_return (fun {return} ->
iter_vars e ~f:(fun v -> if f v then return true) ;
false )
let fold_vars e ~init ~f =
fold_terms e ~init ~f:(fun s -> function
| Var _ as v -> f s (v :> Var.t) | _ -> s )
@ -1248,42 +1261,69 @@ let height e =
(** Solve *)
let find_for ?for_ args =
let exists_var args ~f =
with_return (fun {return} ->
Qset.iter args ~f:(fun arg _ ->
iter_vars arg ~f:(fun v -> if f v then return true) ) ;
false )
in
let remove_if_non_occuring rejected args c q =
let args = Qset.remove args c in
let fv_c = fv c in
if exists_var ~f:(Var.Set.mem fv_c) args then None
else Some (c, q, Qset.union rejected args)
in
let rec find_for_ rejected args =
let* c, q = Qset.min_elt args in
remove_if_non_occuring rejected args c q
|> Option.or_else ~f:(fun () ->
find_for_ (Qset.add rejected c q) (Qset.remove args c) )
let exists_fv_in vs qset =
Qset.exists qset ~f:(fun e _ -> exists_vars e ~f:(Var.Set.mem vs))
let exists_fv_in4 vs w x y z =
exists_fv_in vs w || exists_fv_in vs x || exists_fv_in vs y
|| exists_fv_in vs z
(* solve [0 = rejected_sum + (coeff × prod) + sum] *)
let solve_for_factor rejected_sum coeff prod sum =
let rec find_factor rejected_prod prod =
let* factor, power, prod = Qset.pop_min_elt prod in
if
(not (Q.equal Q.one power))
|| exists_fv_in4 (fv factor) rejected_sum rejected_prod prod sum
then find_factor (Qset.add rejected_prod factor power) prod
else Some (factor, Qset.union rejected_prod prod)
in
match for_ with
| Some c ->
let q = Qset.count args c in
if Q.equal Q.zero q then None
else remove_if_non_occuring Qset.empty args c q
| None -> find_for_ Qset.empty args
let+ factor, prod = find_factor Qset.empty prod in
(* solve [0 = rejected_sum + (coeff × factor × prod) + sum] yielding
[factor = (rejected_sum + sum) / (-coeff × prod)] *)
( factor
, div
(Sum.to_term (Qset.union rejected_sum sum))
(mul (rational (Q.neg coeff)) (Prod.to_term prod)) )
(* solve [0 = rejected_sum + (coeff × mono) + sum] *)
let solve_for_mono rejected_sum coeff mono sum =
match mono with
| Mul prod -> solve_for_factor rejected_sum coeff prod sum
| _ ->
if exists_fv_in (fv mono) sum then None
else
Some
( mono
, Sum.to_term
(Sum.mul_const
(Q.inv (Q.neg coeff))
(Qset.union rejected_sum sum)) )
(* solve [0 = rejected + sum] *)
let rec solve_sum rejected_sum sum =
let* mono, coeff, sum = Qset.pop_min_elt sum in
solve_for_mono rejected_sum coeff mono sum
|> Option.or_else ~f:(fun () ->
solve_sum (Qset.add rejected_sum mono coeff) sum )
let rec solve_div = function
(* [n / d = t] ==> [n = d × t] *)
| Some (Ap2 (Div, num, den), trm) -> solve_div (Some (num, mul den trm))
| o -> o
(* solve [0 = e] *)
let solve_zero_eq ?for_ e =
[%Trace.call fun {pf} -> pf "%a%a" pp e (Option.pp " for %a" pp) for_]
[%Trace.call fun {pf} -> pf "0 = %a%a" pp e (Option.pp " for %a" pp) for_]
;
( match e with
| Add args ->
let+ c, q, args = find_for ?for_ args in
let n = Sum.to_term (Qset.remove args c) in
let d = rational (Q.neg q) in
let r = div n d in
(c, r)
| Add sum ->
( match for_ with
| None -> solve_sum Qset.empty sum
| Some mono ->
let* coeff, sum = Qset.find_and_remove sum mono in
solve_for_mono Qset.empty coeff mono sum )
|> solve_div
| _ -> None )
|>
[%Trace.retn fun {pf} s ->
@ -1292,5 +1332,7 @@ let solve_zero_eq ?for_ e =
Format.fprintf fs "%a ↦ %a" pp c pp r ))
s ;
match (for_, s) with
| Some (Mul prod), Some (var, _) ->
assert (not (Q.equal Q.zero (Qset.count prod var)))
| Some f, Some (c, _) -> assert (equal f c)
| _ -> ()]

Loading…
Cancel
Save