diff --git a/sledge/lib/equality.ml b/sledge/lib/equality.ml index eb0bffe67..733d42f28 100644 --- a/sledge/lib/equality.ml +++ b/sledge/lib/equality.ml @@ -14,8 +14,10 @@ type kind = Interpreted | Simplified | Atomic | Uninterpreted let classify e = match (e : Term.t) with - | Add _ | Mul _ -> Interpreted - | Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) -> + | Add _ | Mul _ + |Ap2 ((Div | Memory), _, _) + |Ap3 (Extract, _, _, _) + |ApN (Concat, _) -> Interpreted | Ap2 ((Eq | Dq), _, _) -> Simplified | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted @@ -308,6 +310,8 @@ and solve_ ?f d e s = ( ((Add _ | Mul _ | Integer _ | Rational _) as p), q | q, ((Add _ | Mul _ | Integer _ | Rational _) as p) ) -> solve_poly ?f p q s + (* e = n / d ==> e × d = n *) + | Some (rep, Ap2 (Div, num, den)) -> solve_ ?f (Term.mul rep den) num s | Some (rep, var) -> assert (non_interpreted var) ; assert (non_interpreted rep) ; diff --git a/sledge/lib/import/map.ml b/sledge/lib/import/map.ml index 57f093710..dc755027c 100644 --- a/sledge/lib/import/map.ml +++ b/sledge/lib/import/map.ml @@ -61,6 +61,9 @@ end) : S with type key = Key.t = struct let pop m = choose m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k)) + let pop_min_elt m = + min_elt m |> Option.map ~f:(fun (k, v) -> (k, v, remove m k)) + let find_and_remove m k = let found = ref None in let m = diff --git a/sledge/lib/import/map_intf.ml b/sledge/lib/import/map_intf.ml index 747de3d8d..5bdf92a6f 100644 --- a/sledge/lib/import/map_intf.ml +++ b/sledge/lib/import/map_intf.ml @@ -47,8 +47,17 @@ module type S = sig -> 'r val choose : 'a t -> (key * 'a) option + (** Find an unspecified element. [O(1)]. *) + val pop : 'a t -> (key * 'a * 'a t) option + (** Find and remove an unspecified element. [O(1)]. *) + + val pop_min_elt : 'a t -> (key * 'a * 'a t) option + (** Find and remove minimum element. [O(log n)]. *) + val find_and_remove : 'a t -> key -> ('a * 'a t) option + (** Find and remove an element. *) + val pp : key pp -> 'a pp -> 'a t pp val pp_diff : diff --git a/sledge/lib/import/qset.ml b/sledge/lib/import/qset.ml index 1a162536f..23437806a 100644 --- a/sledge/lib/import/qset.ml +++ b/sledge/lib/import/qset.ml @@ -55,6 +55,7 @@ struct M.change m x ~f:(function Some j -> if_nz Q.(i + j) | None -> if_nz i) let remove m x = M.remove m x + let find_and_remove = M.find_and_remove let union m n = M.merge m n ~f:(fun ~key:_ -> function @@ -77,8 +78,8 @@ struct let count m x = match M.find m x with Some q -> q | None -> Q.zero let choose = M.choose let pop = M.pop - let min_elt_exn = M.min_elt_exn let min_elt = M.min_elt + let pop_min_elt = M.pop_min_elt let to_list m = M.to_alist m let iter m ~f = M.iteri ~f:(fun ~key ~data -> f key data) m let exists m ~f = M.existsi ~f:(fun ~key ~data -> f key data) m diff --git a/sledge/lib/import/qset_intf.ml b/sledge/lib/import/qset_intf.ml index 7f9e97754..f8fcd03f1 100644 --- a/sledge/lib/import/qset_intf.ml +++ b/sledge/lib/import/qset_intf.ml @@ -54,13 +54,19 @@ module type S = sig (** Multiplicity of an element. [O(log n)]. *) val choose : t -> (elt * Q.t) option - val pop : t -> (elt * Q.t * t) option + (** Find an unspecified element. [O(1)]. *) - val min_elt_exn : t -> elt * Q.t - (** Minimum element. *) + val pop : t -> (elt * Q.t * t) option + (** Find and remove an unspecified element. [O(1)]. *) val min_elt : t -> (elt * Q.t) option - (** Minimum element. *) + (** Minimum element. [O(log n)]. *) + + val pop_min_elt : t -> (elt * Q.t * t) option + (** Find and remove minimum element. [O(log n)]. *) + + val find_and_remove : t -> elt -> (Q.t * t) option + (** Find and remove an element. *) val to_list : t -> (elt * Q.t) list (** Convert to a list of elements in ascending order. *) diff --git a/sledge/lib/term.ml b/sledge/lib/term.ml index df385eea6..2c1a9454d 100644 --- a/sledge/lib/term.ml +++ b/sledge/lib/term.ml @@ -502,8 +502,16 @@ module Prod = struct assert (match term with Integer _ | Rational _ -> false | _ -> true) ; Qset.add prod term Q.one - let singleton term = add term empty + let of_ term = add term empty let union = Qset.union + + let to_term prod = + match Qset.pop prod with + | None -> one + | Some (factor, power, prod') + when Qset.is_empty prod' && Q.equal Q.one power -> + factor + | _ -> Mul prod end let rec simp_add_ es poly = @@ -547,21 +555,21 @@ and simp_mul2 e f = | Rational {data= c}, x | x, Rational {data= c} -> Sum.to_term (Sum.of_ ~coeff:c x) (* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> ∏ⱼ₌₀ⁿ xⱼ *) - | Mul xs1, Mul xs2 -> Mul (Prod.union xs1 xs2) + | Mul xs1, Mul xs2 -> Prod.to_term (Prod.union xs1 xs2) (* (∏ᵢ xᵢ) × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ cᵤ × ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *) | (Mul prod as m), Add sum | Add sum, (Mul prod as m) -> Sum.to_term (Sum.map sum ~f:(function - | Mul args -> Mul (Prod.union prod args) + | Mul args -> Prod.to_term (Prod.union prod args) | (Integer _ | Rational _) as c -> simp_mul2 c m - | mono -> Mul (Prod.add mono prod) )) + | mono -> Prod.to_term (Prod.add mono prod) )) (* x₀ × (∏ᵢ₌₁ⁿ xᵢ) ==> ∏ᵢ₌₀ⁿ xᵢ *) - | Mul xs1, x | x, Mul xs1 -> Mul (Prod.add x xs1) + | Mul xs1, x | x, Mul xs1 -> Prod.to_term (Prod.add x xs1) (* e × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ e × cᵤ × ∏ⱼ yᵤⱼ *) | Add args, e | e, Add args -> simp_add_ (Sum.map ~f:(fun m -> simp_mul2 e m) args) zero (* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *) - | _ -> Mul (Prod.add e (Prod.singleton f)) + | _ -> Prod.to_term (Prod.add e (Prod.of_ f)) let rec simp_div x y = match (x, y) with @@ -1222,6 +1230,11 @@ let fold_terms e ~init ~f = let iter_vars e ~f = iter_terms e ~f:(function Var _ as v -> f (v :> Var.t) | _ -> ()) +let exists_vars e ~f = + with_return (fun {return} -> + iter_vars e ~f:(fun v -> if f v then return true) ; + false ) + let fold_vars e ~init ~f = fold_terms e ~init ~f:(fun s -> function | Var _ as v -> f s (v :> Var.t) | _ -> s ) @@ -1248,42 +1261,69 @@ let height e = (** Solve *) -let find_for ?for_ args = - let exists_var args ~f = - with_return (fun {return} -> - Qset.iter args ~f:(fun arg _ -> - iter_vars arg ~f:(fun v -> if f v then return true) ) ; - false ) - in - let remove_if_non_occuring rejected args c q = - let args = Qset.remove args c in - let fv_c = fv c in - if exists_var ~f:(Var.Set.mem fv_c) args then None - else Some (c, q, Qset.union rejected args) - in - let rec find_for_ rejected args = - let* c, q = Qset.min_elt args in - remove_if_non_occuring rejected args c q - |> Option.or_else ~f:(fun () -> - find_for_ (Qset.add rejected c q) (Qset.remove args c) ) +let exists_fv_in vs qset = + Qset.exists qset ~f:(fun e _ -> exists_vars e ~f:(Var.Set.mem vs)) + +let exists_fv_in4 vs w x y z = + exists_fv_in vs w || exists_fv_in vs x || exists_fv_in vs y + || exists_fv_in vs z + +(* solve [0 = rejected_sum + (coeff × prod) + sum] *) +let solve_for_factor rejected_sum coeff prod sum = + let rec find_factor rejected_prod prod = + let* factor, power, prod = Qset.pop_min_elt prod in + if + (not (Q.equal Q.one power)) + || exists_fv_in4 (fv factor) rejected_sum rejected_prod prod sum + then find_factor (Qset.add rejected_prod factor power) prod + else Some (factor, Qset.union rejected_prod prod) in - match for_ with - | Some c -> - let q = Qset.count args c in - if Q.equal Q.zero q then None - else remove_if_non_occuring Qset.empty args c q - | None -> find_for_ Qset.empty args - + let+ factor, prod = find_factor Qset.empty prod in + (* solve [0 = rejected_sum + (coeff × factor × prod) + sum] yielding + [factor = (rejected_sum + sum) / (-coeff × prod)] *) + ( factor + , div + (Sum.to_term (Qset.union rejected_sum sum)) + (mul (rational (Q.neg coeff)) (Prod.to_term prod)) ) + +(* solve [0 = rejected_sum + (coeff × mono) + sum] *) +let solve_for_mono rejected_sum coeff mono sum = + match mono with + | Mul prod -> solve_for_factor rejected_sum coeff prod sum + | _ -> + if exists_fv_in (fv mono) sum then None + else + Some + ( mono + , Sum.to_term + (Sum.mul_const + (Q.inv (Q.neg coeff)) + (Qset.union rejected_sum sum)) ) + +(* solve [0 = rejected + sum] *) +let rec solve_sum rejected_sum sum = + let* mono, coeff, sum = Qset.pop_min_elt sum in + solve_for_mono rejected_sum coeff mono sum + |> Option.or_else ~f:(fun () -> + solve_sum (Qset.add rejected_sum mono coeff) sum ) + +let rec solve_div = function + (* [n / d = t] ==> [n = d × t] *) + | Some (Ap2 (Div, num, den), trm) -> solve_div (Some (num, mul den trm)) + | o -> o + +(* solve [0 = e] *) let solve_zero_eq ?for_ e = - [%Trace.call fun {pf} -> pf "%a%a" pp e (Option.pp " for %a" pp) for_] + [%Trace.call fun {pf} -> pf "0 = %a%a" pp e (Option.pp " for %a" pp) for_] ; ( match e with - | Add args -> - let+ c, q, args = find_for ?for_ args in - let n = Sum.to_term (Qset.remove args c) in - let d = rational (Q.neg q) in - let r = div n d in - (c, r) + | Add sum -> + ( match for_ with + | None -> solve_sum Qset.empty sum + | Some mono -> + let* coeff, sum = Qset.find_and_remove sum mono in + solve_for_mono Qset.empty coeff mono sum ) + |> solve_div | _ -> None ) |> [%Trace.retn fun {pf} s -> @@ -1292,5 +1332,7 @@ let solve_zero_eq ?for_ e = Format.fprintf fs "%a ↦ %a" pp c pp r )) s ; match (for_, s) with + | Some (Mul prod), Some (var, _) -> + assert (not (Q.equal Q.zero (Qset.count prod var))) | Some f, Some (c, _) -> assert (equal f c) | _ -> ()]