From 1a34e7eed2c27b5e82a52994cd5343146130ab90 Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Thu, 16 Apr 2020 03:38:43 -0700 Subject: [PATCH] [sledge] Add rational constants Reviewed By: jvillard Differential Revision: D20831348 fbshipit-source-id: 72790cbec --- sledge/lib/equality.ml | 10 +++-- sledge/lib/term.ml | 98 ++++++++++++++++++++++++++---------------- sledge/lib/term.mli | 1 + 3 files changed, 69 insertions(+), 40 deletions(-) diff --git a/sledge/lib/equality.ml b/sledge/lib/equality.ml index cebade643..eb0bffe67 100644 --- a/sledge/lib/equality.ml +++ b/sledge/lib/equality.ml @@ -19,7 +19,9 @@ let classify e = Interpreted | Ap2 ((Eq | Dq), _, _) -> Simplified | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted - | RecN _ | Var _ | Integer _ | Float _ | Nondet _ | Label _ -> Atomic + | RecN _ | Var _ | Integer _ | Rational _ | Float _ | Nondet _ | Label _ + -> + Atomic let interpreted e = equal_kind (classify e) Interpreted let non_interpreted e = not (interpreted e) @@ -264,7 +266,7 @@ and solve_ ?f d e s = (* e' = f' ==> true when e' ≡ f' *) | None -> Some s (* i = j ==> false when i ≠ j *) - | Some (Integer _, Integer _) -> None + | Some (Integer _, Integer _) | Some (Rational _, Rational _) -> None (* ⟨0,a⟩ = β ==> a = β = ⟨⟩ *) | Some (Ap2 (Memory, n, a), b) when Term.equal n Term.zero -> s |> solve_ ?f a (Term.concat [||]) >>= solve_ ?f b (Term.concat [||]) @@ -303,8 +305,8 @@ and solve_ ?f d e s = | Some (Ap3 (Extract, a, o, l), e) -> solve_extract ?f a o l e s (* p = q ==> p-q = 0 *) | Some - ( ((Add _ | Mul _ | Integer _) as p), q - | q, ((Add _ | Mul _ | Integer _) as p) ) -> + ( ((Add _ | Mul _ | Integer _ | Rational _) as p), q + | q, ((Add _ | Mul _ | Integer _ | Rational _) as p) ) -> solve_poly ?f p q s | Some (rep, var) -> assert (non_interpreted var) ; diff --git a/sledge/lib/term.ml b/sledge/lib/term.ml index eef8ff13c..003c31241 100644 --- a/sledge/lib/term.ml +++ b/sledge/lib/term.ml @@ -70,6 +70,7 @@ and T : sig | Nondet of {msg: string} | Float of {data: string} | Integer of {data: Z.t} + | Rational of {data: Q.t} [@@deriving compare, equal, hash, sexp] end = struct type qset = Qset.t [@@deriving compare, equal, hash, sexp] @@ -87,11 +88,12 @@ end = struct | Nondet of {msg: string} | Float of {data: string} | Integer of {data: Z.t} + | Rational of {data: Q.t} [@@deriving compare, equal, hash, sexp] (* Note: solve (and invariant) requires Qset.min_elt to return a - non-coefficient, so Integer terms must compare higher than any valid - monomial *) + non-coefficient, so Integer and Rational terms must compare higher than + any valid monomial *) let compare x y = match (x, y) with | Var {id= i; name= _}, Var {id= j; name= _} when i > 0 && j > 0 -> @@ -141,6 +143,7 @@ let rec ppx strength fs term = | Some `Existential -> Trace.pp_styled `Cyan "%%%s_%d" fs name id | Some `Anonymous -> Trace.pp_styled `Cyan "_" fs ) | Integer {data} -> Trace.pp_styled `Magenta "%a" fs Z.pp data + | Rational {data} -> Trace.pp_styled `Magenta "%a" fs Q.pp data | Float {data} -> pf "%s" data | Nondet {msg} -> pf "nondet \"%s\"" msg | Label {name} -> pf "%s" name @@ -218,9 +221,10 @@ let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y (** Invariant *) -(* an indeterminate (factor of a monomial) is any non-Add/Mul/Integer term *) +(* an indeterminate (factor of a monomial) is any + non-Add/Mul/Integer/Rational term *) let assert_indeterminate = function - | Integer _ | Add _ | Mul _ -> assert false + | Integer _ | Rational _ | Add _ | Mul _ -> assert false | _ -> assert true (* a monomial is a power product of factors, e.g. @@ -244,7 +248,7 @@ let assert_poly_term mono coeff = | Integer {data} -> assert (Z.equal Z.one data) | Mul args -> ( match Qset.min_elt args with - | None | Some (Integer _, _) -> assert false + | None | Some ((Integer _ | Rational _), _) -> assert false | Some (_, n) -> assert (Qset.length args > 1 || not (Q.equal Q.one n)) ) ; assert_monomial mono |> Fn.id @@ -259,7 +263,7 @@ let assert_polynomial poly = match poly with | Add args -> ( match Qset.min_elt args with - | None | Some (Integer _, _) -> assert false + | None | Some ((Integer _ | Rational _), _) -> assert false | Some (_, k) -> assert (Qset.length args > 1 || not (Q.equal Q.one k)) ) ; Qset.iter args ~f:(fun m c -> assert_poly_term m c |> Fn.id) @@ -291,6 +295,9 @@ let invariant e = assert ( not (Typ.equivalent src dst) (* avoid redundant representations *) ) + | Rational {data} -> + assert (Q.is_real data) ; + assert (not (Z.equal Z.one (Q.den data))) | _ -> () [@@warning "-9"] @@ -412,6 +419,12 @@ let var x = x (* constants *) let integer data = Integer {data} |> check invariant + +let rational data = + ( if Z.equal Z.one (Q.den data) then Integer {data= Q.num data} + else Rational {data} ) + |> check invariant + let null = integer Z.zero let zero = integer Z.zero let one = integer Z.one @@ -451,6 +464,7 @@ module Sum = struct match term with | Integer {data} when Z.equal Z.zero data -> sum | Integer {data} -> Qset.add sum one Q.(coeff * of_z data) + | Rational {data} -> Qset.add sum one Q.(coeff * data) | _ -> Qset.add sum term coeff let singleton ?(coeff = Q.one) term = add coeff term empty @@ -462,6 +476,16 @@ module Sum = struct assert (not (Q.equal Q.zero const)) ; if Q.equal Q.one const then sum else Qset.map_counts ~f:(fun _ -> Q.mul const) sum + + let to_term sum = + match Qset.length sum with + | 0 -> zero + | 1 -> ( + match Qset.min_elt sum with + | Some (Integer _, q) -> rational q + | Some (arg, q) when Q.equal Q.one q -> arg + | _ -> Add sum ) + | _ -> Add sum end (* Products of indeterminants represented by multisets. A product ∏ᵢ xᵢ^nᵢ @@ -471,26 +495,14 @@ module Prod = struct let empty = Qset.empty let add term prod = - assert (match term with Integer _ -> false | _ -> true) ; + assert (match term with Integer _ | Rational _ -> false | _ -> true) ; Qset.add prod term Q.one let singleton term = add term empty let union = Qset.union end -let rec sum_to_term sum = - match Qset.length sum with - | 0 -> zero - | 1 -> ( - match Qset.min_elt sum with - | Some (Integer _, q) -> rational q - | Some (arg, q) when Q.equal Q.one q -> arg - | _ -> Add sum ) - | _ -> Add sum - -and rational Q.{num; den} = simp_div (integer num) (integer den) - -and simp_add_ es poly = +let rec simp_add_ es poly = (* (coeff × term) + poly *) let f term coeff poly = match (term, poly) with @@ -501,12 +513,13 @@ and simp_add_ es poly = (* (c × cᵢ) + cⱼ ==> c×cᵢ+cⱼ *) | Integer {data= i}, Integer {data= j} -> rational Q.((coeff * of_z i) + of_z j) + | Rational {data= i}, Rational {data= j} -> rational Q.((coeff * i) + j) (* (c × ∑ᵢ cᵢ × Xᵢ) + s ==> (∑ᵢ (c × cᵢ) × Xᵢ) + s *) | Add args, _ -> simp_add_ (Sum.mul_const coeff args) poly (* (c₀ × X₀) + (∑ᵢ₌₁ⁿ cᵢ × Xᵢ) ==> ∑ᵢ₌₀ⁿ cᵢ × Xᵢ *) - | _, Add args -> sum_to_term (Sum.add coeff term args) + | _, Add args -> Sum.to_term (Sum.add coeff term args) (* (c₁ × X₁) + X₂ ==> ∑ᵢ₌₁² cᵢ × Xᵢ for c₂ = 1 *) - | _ -> sum_to_term (Sum.add coeff term (Sum.singleton poly)) + | _ -> Sum.to_term (Sum.add coeff term (Sum.singleton poly)) in Qset.fold ~f es ~init:poly @@ -514,24 +527,29 @@ and simp_mul2 e f = match (e, f) with (* c₁ × c₂ ==> c₁×c₂ *) | Integer {data= i}, Integer {data= j} -> integer (Z.mul i j) + | Rational {data= i}, Rational {data= j} -> rational (Q.mul i j) (* 0 × f ==> 0 *) | Integer {data}, _ when Z.equal Z.zero data -> e (* e × 0 ==> 0 *) | _, Integer {data} when Z.equal Z.zero data -> f (* c × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ c × cᵤ × ∏ⱼ yᵤⱼ *) | Integer {data}, Add args | Add args, Integer {data} -> - sum_to_term (Sum.mul_const (Q.of_z data) args) + Sum.to_term (Sum.mul_const (Q.of_z data) args) + | Rational {data}, Add args | Add args, Rational {data} -> + Sum.to_term (Sum.mul_const data args) (* c₁ × x₁ ==> ∑ᵢ₌₁ cᵢ × xᵢ *) | Integer {data= c}, x | x, Integer {data= c} -> - sum_to_term (Sum.singleton ~coeff:(Q.of_z c) x) + Sum.to_term (Sum.singleton ~coeff:(Q.of_z c) x) + | Rational {data= c}, x | x, Rational {data= c} -> + Sum.to_term (Sum.singleton ~coeff:c x) (* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> ∏ⱼ₌₀ⁿ xⱼ *) | Mul xs1, Mul xs2 -> Mul (Prod.union xs1 xs2) (* (∏ᵢ xᵢ) × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ cᵤ × ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *) | (Mul prod as m), Add sum | Add sum, (Mul prod as m) -> - sum_to_term + Sum.to_term (Sum.map sum ~f:(function | Mul args -> Mul (Prod.union prod args) - | Integer _ as c -> simp_mul2 c m + | (Integer _ | Rational _) as c -> simp_mul2 c m | mono -> Mul (Prod.add mono prod) )) (* x₀ × (∏ᵢ₌₁ⁿ xᵢ) ==> ∏ᵢ₌₀ⁿ xᵢ *) | Mul xs1, x | x, Mul xs1 -> Mul (Prod.add x xs1) @@ -541,18 +559,21 @@ and simp_mul2 e f = (* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *) | _ -> Mul (Prod.add e (Prod.singleton f)) -and simp_div x y = +let simp_div x y = match (x, y) with (* i / j *) | Integer {data= i}, Integer {data= j} when not (Z.equal Z.zero j) -> integer (Z.div i j) + | Rational {data= i}, Rational {data= j} -> rational (Q.div i j) (* e / 1 ==> e *) | e, Integer {data} when Z.equal Z.one data -> e (* e / -1 ==> -1×e *) | e, (Integer {data} as c) when Z.equal Z.minus_one data -> simp_mul2 e c (* (∑ᵢ cᵢ × Xᵢ) / z ==> ∑ᵢ cᵢ/z × Xᵢ *) | Add args, Integer {data} -> - sum_to_term (Sum.mul_const Q.(inv (of_z data)) args) + Sum.to_term (Sum.mul_const Q.(inv (of_z data)) args) + | Add args, Rational {data} -> + Sum.to_term (Sum.mul_const Q.(inv data) args) | _ -> Ap2 (Div, x, y) let simp_rem x y = @@ -572,6 +593,7 @@ let simp_sub x y = match (x, y) with (* i - j *) | Integer {data= i}, Integer {data= j} -> integer (Z.sub i j) + | Rational {data= i}, Rational {data= j} -> rational (Q.sub i j) (* x - y ==> x + (-1 * y) *) | _ -> simp_add2 x (simp_negate y) @@ -668,6 +690,8 @@ let partial_compare x y : pcmp = match simp_sub x y with | Integer {data} -> ( match Int.sign (Z.sign data) with Neg -> Lt | Zero -> Eq | Pos -> Gt ) + | Rational {data} -> ( + match Int.sign (Q.sign data) with Neg -> Lt | Zero -> Eq | Pos -> Gt ) | _ -> Unknown let partial_ge x y = @@ -776,11 +800,13 @@ and simp_concat xs = let simp_lt x y = match (x, y) with | Integer {data= i}, Integer {data= j} -> bool (Z.lt i j) + | Rational {data= i}, Rational {data= j} -> bool (Q.lt i j) | _ -> Ap2 (Lt, x, y) let simp_le x y = match (x, y) with | Integer {data= i}, Integer {data= j} -> bool (Z.leq i j) + | Rational {data= i}, Rational {data= j} -> bool (Q.leq i j) | _ -> Ap2 (Le, x, y) let simp_ord x y = Ap2 (Ord, x, y) @@ -798,7 +824,7 @@ let rec simp_eq x y = | Some (x, y) -> ( match (x, y) with (* i = j ==> false when i ≠ j *) - | Integer _, Integer _ -> bool false + | Integer _, Integer _ | Rational _, Rational _ -> bool false (* b = false ==> ¬b *) | b, Integer {data} when Z.is_false data && is_boolean b -> simp_not b (* b = true ==> b *) @@ -1075,7 +1101,7 @@ let map e ~f = xs == IArray.map_endo ~f xs || fail "Term.map does not support updating subterms of RecN." () ) ; e - | Var _ | Label _ | Nondet _ | Float _ | Integer _ -> e + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> e let fold_map e ~init ~f = let s = ref init in @@ -1130,7 +1156,7 @@ let iter e ~f = | Ap3 (_, x, y, z) -> f x ; f y ; f z | ApN (_, xs) | RecN (_, xs) -> IArray.iter ~f xs | Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args - | Var _ | Label _ | Nondet _ | Float _ | Integer _ -> () + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> () let exists e ~f = match e with @@ -1139,7 +1165,7 @@ let exists e ~f = | Ap3 (_, x, y, z) -> f x || f y || f z | ApN (_, xs) | RecN (_, xs) -> IArray.exists ~f xs | Add args | Mul args -> Qset.exists ~f:(fun arg _ -> f arg) args - | Var _ | Label _ | Nondet _ | Float _ | Integer _ -> false + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> false let fold e ~init:s ~f = match e with @@ -1149,7 +1175,7 @@ let fold e ~init:s ~f = | ApN (_, xs) | RecN (_, xs) -> IArray.fold ~f:(fun s x -> f x s) xs ~init:s | Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s - | Var _ | Label _ | Nondet _ | Float _ | Integer _ -> s + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> s let fold_terms e ~init ~f = let fold_terms_ fold_terms_ e s = @@ -1162,7 +1188,7 @@ let fold_terms e ~init ~f = IArray.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s | Add args | Mul args -> Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s) - | Var _ | Label _ | Nondet _ | Float _ | Integer _ -> s + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> s in f s e in @@ -1188,7 +1214,7 @@ let height e = 1 + IArray.fold v ~init:0 ~f:(fun m a -> max m (height_ a)) | Add qs | Mul qs -> 1 + Qset.fold qs ~init:0 ~f:(fun a _ m -> max m (height_ a)) - | Label _ | Nondet _ | Float _ | Integer _ -> 0 + | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> 0 in fix height_ (fun _ -> 0) e @@ -1206,7 +1232,7 @@ let solve_zero_eq ?for_ e = if Q.equal Q.zero q then None else Some (f, q) | None -> Some (Qset.min_elt_exn args) in - let n = sum_to_term (Qset.remove args c) in + let n = Sum.to_term (Qset.remove args c) in let d = rational (Q.neg q) in let r = div n d in (c, r) diff --git a/sledge/lib/term.mli b/sledge/lib/term.mli index 8e601fe76..0c77e3d74 100644 --- a/sledge/lib/term.mli +++ b/sledge/lib/term.mli @@ -92,6 +92,7 @@ and T : sig non-deterministic approximation of value described by [msg] *) | Float of {data: string} (** Floating-point constant *) | Integer of {data: Z.t} (** Integer constant *) + | Rational of {data: Q.t} (** Rational constant *) [@@deriving compare, equal, hash, sexp] end