From 0f50d3c248497948dc553b95a34f4ce65d389cbc Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Thu, 16 Apr 2020 03:38:59 -0700 Subject: [PATCH] [sledge] Do not solve polynomials for Mul or Div terms Summary: This diff adds enough interpretation of Mul and Div terms to be able to exclude them from the domain of solution substitutions. While non-linear arithmetic is still treated very incompletely, this change increases the propagation power of the equality constraints that are deduced. Mainly, this appears to be enough to avoid operations that are semantically equivalence-preserving such as solve_for_vars from producing equalities that are unprovable from their inputs. Reviewed By: jvillard Differential Revision: D20863528 fbshipit-source-id: fca74cba3 --- sledge/lib/equality.ml | 8 ++- sledge/lib/import/map.ml | 3 + sledge/lib/import/map_intf.ml | 9 +++ sledge/lib/import/qset.ml | 3 +- sledge/lib/import/qset_intf.ml | 14 ++-- sledge/lib/term.ml | 118 ++++++++++++++++++++++----------- 6 files changed, 110 insertions(+), 45 deletions(-) diff --git a/sledge/lib/equality.ml b/sledge/lib/equality.ml index eb0bffe67..733d42f28 100644 --- a/sledge/lib/equality.ml +++ b/sledge/lib/equality.ml @@ -14,8 +14,10 @@ type kind = Interpreted | Simplified | Atomic | Uninterpreted let classify e = match (e : Term.t) with - | Add _ | Mul _ -> Interpreted - | Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) -> + | Add _ | Mul _ + |Ap2 ((Div | Memory), _, _) + |Ap3 (Extract, _, _, _) + |ApN (Concat, _) -> Interpreted | Ap2 ((Eq | Dq), _, _) -> Simplified | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted @@ -308,6 +310,8 @@ and solve_ ?f d e s = ( ((Add _ | Mul _ | Integer _ | Rational _) as p), q | q, ((Add _ | Mul _ | Integer _ | Rational _) as p) ) -> solve_poly ?f p q s + (* e = n / d ==> e × d = n *) + | Some (rep, Ap2 (Div, num, den)) -> solve_ ?f (Term.mul rep den) num s | Some (rep, var) -> assert (non_interpreted var) ; assert (non_interpreted rep) ; diff --git a/sledge/lib/import/map.ml b/sledge/lib/import/map.ml index 57f093710..dc755027c 100644 --- a/sledge/lib/import/map.ml +++ b/sledge/lib/import/map.ml @@ -61,6 +61,9 @@ end) : S with type key = Key.t = struct let pop m = choose m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k)) + let pop_min_elt m = + min_elt m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k)) + let find_and_remove m k = let found = ref None in let m = diff --git a/sledge/lib/import/map_intf.ml b/sledge/lib/import/map_intf.ml index 747de3d8d..5bdf92a6f 100644 --- a/sledge/lib/import/map_intf.ml +++ b/sledge/lib/import/map_intf.ml @@ -47,8 +47,17 @@ module type S = sig -> 'r val choose : 'a t -> (key * 'a) option + (** Find an unspecified element. [O(1)]. *) + val pop : 'a t -> (key * 'a * 'a t) option + (** Find and remove an unspecified element. [O(1)]. *) + + val pop_min_elt : 'a t -> (key * 'a * 'a t) option + (** Find and remove minimum element. [O(log n)]. *) + val find_and_remove : 'a t -> key -> ('a * 'a t) option + (** Find and remove an element. *) + val pp : key pp -> 'a pp -> 'a t pp val pp_diff : diff --git a/sledge/lib/import/qset.ml b/sledge/lib/import/qset.ml index 1a162536f..23437806a 100644 --- a/sledge/lib/import/qset.ml +++ b/sledge/lib/import/qset.ml @@ -55,6 +55,7 @@ struct M.change m x ~f:(function Some j -> if_nz Q.(i + j) | None -> if_nz i) let remove m x = M.remove m x + let find_and_remove = M.find_and_remove let union m n = M.merge m n ~f:(fun ~key:_ -> function @@ -77,8 +78,8 @@ struct let count m x = match M.find m x with Some q -> q | None -> Q.zero let choose = M.choose let pop = M.pop - let min_elt_exn = M.min_elt_exn let min_elt = M.min_elt + let pop_min_elt = M.pop_min_elt let to_list m = M.to_alist m let iter m ~f = M.iteri ~f:(fun ~key ~data -> f key data) m let exists m ~f = M.existsi ~f:(fun ~key ~data -> f key data) m diff --git a/sledge/lib/import/qset_intf.ml b/sledge/lib/import/qset_intf.ml index 7f9e97754..f8fcd03f1 100644 --- a/sledge/lib/import/qset_intf.ml +++ b/sledge/lib/import/qset_intf.ml @@ -54,13 +54,19 @@ module type S = sig (** Multiplicity of an element. [O(log n)]. *) val choose : t -> (elt * Q.t) option - val pop : t -> (elt * Q.t * t) option + (** Find an unspecified element. [O(1)]. *) - val min_elt_exn : t -> elt * Q.t - (** Minimum element. *) + val pop : t -> (elt * Q.t * t) option + (** Find and remove an unspecified element. [O(1)]. *) val min_elt : t -> (elt * Q.t) option - (** Minimum element. *) + (** Minimum element. [O(log n)]. *) + + val pop_min_elt : t -> (elt * Q.t * t) option + (** Find and remove minimum element. [O(log n)]. *) + + val find_and_remove : t -> elt -> (Q.t * t) option + (** Find and remove an element. *) val to_list : t -> (elt * Q.t) list (** Convert to a list of elements in ascending order. *) diff --git a/sledge/lib/term.ml b/sledge/lib/term.ml index df385eea6..2c1a9454d 100644 --- a/sledge/lib/term.ml +++ b/sledge/lib/term.ml @@ -502,8 +502,16 @@ module Prod = struct assert (match term with Integer _ | Rational _ -> false | _ -> true) ; Qset.add prod term Q.one - let singleton term = add term empty + let of_ term = add term empty let union = Qset.union + + let to_term prod = + match Qset.pop prod with + | None -> one + | Some (factor, power, prod') + when Qset.is_empty prod' && Q.equal Q.one power -> + factor + | _ -> Mul prod end let rec simp_add_ es poly = @@ -547,21 +555,21 @@ and simp_mul2 e f = | Rational {data= c}, x | x, Rational {data= c} -> Sum.to_term (Sum.of_ ~coeff:c x) (* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> ∏ⱼ₌₀ⁿ xⱼ *) - | Mul xs1, Mul xs2 -> Mul (Prod.union xs1 xs2) + | Mul xs1, Mul xs2 -> Prod.to_term (Prod.union xs1 xs2) (* (∏ᵢ xᵢ) × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ cᵤ × ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *) | (Mul prod as m), Add sum | Add sum, (Mul prod as m) -> Sum.to_term (Sum.map sum ~f:(function - | Mul args -> Mul (Prod.union prod args) + | Mul args -> Prod.to_term (Prod.union prod args) | (Integer _ | Rational _) as c -> simp_mul2 c m - | mono -> Mul (Prod.add mono prod) )) + | mono -> Prod.to_term (Prod.add mono prod) )) (* x₀ × (∏ᵢ₌₁ⁿ xᵢ) ==> ∏ᵢ₌₀ⁿ xᵢ *) - | Mul xs1, x | x, Mul xs1 -> Mul (Prod.add x xs1) + | Mul xs1, x | x, Mul xs1 -> Prod.to_term (Prod.add x xs1) (* e × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ e × cᵤ × ∏ⱼ yᵤⱼ *) | Add args, e | e, Add args -> simp_add_ (Sum.map ~f:(fun m -> simp_mul2 e m) args) zero (* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *) - | _ -> Mul (Prod.add e (Prod.singleton f)) + | _ -> Prod.to_term (Prod.add e (Prod.of_ f)) let rec simp_div x y = match (x, y) with @@ -1222,6 +1230,11 @@ let fold_terms e ~init ~f = let iter_vars e ~f = iter_terms e ~f:(function Var _ as v -> f (v :> Var.t) | _ -> ()) +let exists_vars e ~f = + with_return (fun {return} -> + iter_vars e ~f:(fun v -> if f v then return true) ; + false ) + let fold_vars e ~init ~f = fold_terms e ~init ~f:(fun s -> function | Var _ as v -> f s (v :> Var.t) | _ -> s ) @@ -1248,42 +1261,69 @@ let height e = (** Solve *) -let find_for ?for_ args = - let exists_var args ~f = - with_return (fun {return} -> - Qset.iter args ~f:(fun arg _ -> - iter_vars arg ~f:(fun v -> if f v then return true) ) ; - false ) - in - let remove_if_non_occuring rejected args c q = - let args = Qset.remove args c in - let fv_c = fv c in - if exists_var ~f:(Var.Set.mem fv_c) args then None - else Some (c, q, Qset.union rejected args) - in - let rec find_for_ rejected args = - let* c, q = Qset.min_elt args in - remove_if_non_occuring rejected args c q - |> Option.or_else ~f:(fun () -> - find_for_ (Qset.add rejected c q) (Qset.remove args c) ) +let exists_fv_in vs qset = + Qset.exists qset ~f:(fun e _ -> exists_vars e ~f:(Var.Set.mem vs)) + +let exists_fv_in4 vs w x y z = + exists_fv_in vs w || exists_fv_in vs x || exists_fv_in vs y + || exists_fv_in vs z + +(* solve [0 = rejected_sum + (coeff × prod) + sum] *) +let solve_for_factor rejected_sum coeff prod sum = + let rec find_factor rejected_prod prod = + let* factor, power, prod = Qset.pop_min_elt prod in + if + (not (Q.equal Q.one power)) + || exists_fv_in4 (fv factor) rejected_sum rejected_prod prod sum + then find_factor (Qset.add rejected_prod factor power) prod + else Some (factor, Qset.union rejected_prod prod) in - match for_ with - | Some c -> - let q = Qset.count args c in - if Q.equal Q.zero q then None - else remove_if_non_occuring Qset.empty args c q - | None -> find_for_ Qset.empty args - + let+ factor, prod = find_factor Qset.empty prod in + (* solve [0 = rejected_sum + (coeff × factor × prod) + sum] yielding + [factor = (rejected_sum + sum) / (-coeff × prod)] *) + ( factor + , div + (Sum.to_term (Qset.union rejected_sum sum)) + (mul (rational (Q.neg coeff)) (Prod.to_term prod)) ) + +(* solve [0 = rejected_sum + (coeff × mono) + sum] *) +let solve_for_mono rejected_sum coeff mono sum = + match mono with + | Mul prod -> solve_for_factor rejected_sum coeff prod sum + | _ -> + if exists_fv_in (fv mono) sum then None + else + Some + ( mono + , Sum.to_term + (Sum.mul_const + (Q.inv (Q.neg coeff)) + (Qset.union rejected_sum sum)) ) + +(* solve [0 = rejected + sum] *) +let rec solve_sum rejected_sum sum = + let* mono, coeff, sum = Qset.pop_min_elt sum in + solve_for_mono rejected_sum coeff mono sum + |> Option.or_else ~f:(fun () -> + solve_sum (Qset.add rejected_sum mono coeff) sum ) + +let rec solve_div = function + (* [n / d = t] ==> [n = d × t] *) + | Some (Ap2 (Div, num, den), trm) -> solve_div (Some (num, mul den trm)) + | o -> o + +(* solve [0 = e] *) let solve_zero_eq ?for_ e = - [%Trace.call fun {pf} -> pf "%a%a" pp e (Option.pp " for %a" pp) for_] + [%Trace.call fun {pf} -> pf "0 = %a%a" pp e (Option.pp " for %a" pp) for_] ; ( match e with - | Add args -> - let+ c, q, args = find_for ?for_ args in - let n = Sum.to_term (Qset.remove args c) in - let d = rational (Q.neg q) in - let r = div n d in - (c, r) + | Add sum -> + ( match for_ with + | None -> solve_sum Qset.empty sum + | Some mono -> + let* coeff, sum = Qset.find_and_remove sum mono in + solve_for_mono Qset.empty coeff mono sum ) + |> solve_div | _ -> None ) |> [%Trace.retn fun {pf} s -> @@ -1292,5 +1332,7 @@ let solve_zero_eq ?for_ e = Format.fprintf fs "%a ↦ %a" pp c pp r )) s ; match (for_, s) with + | Some (Mul prod), Some (var, _) -> + assert (not (Q.equal Q.zero (Qset.count prod var))) | Some f, Some (c, _) -> assert (equal f c) | _ -> ()]