From 65f38d68cc859c6da737196c5098b02eac69e8f3 Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Tue, 18 Feb 2020 07:50:41 -0800 Subject: [PATCH] [sledge] Refactor to allow more recursion between arithmetic canonizer cases Summary: No functional change. Reviewed By: jvillard Differential Revision: D19950367 fbshipit-source-id: 9d14e98bf --- sledge/src/llair/term.ml | 146 +++++++++++++++++++-------------------- 1 file changed, 71 insertions(+), 75 deletions(-) diff --git a/sledge/src/llair/term.ml b/sledge/src/llair/term.ml index 8057d7da6..657ab839a 100644 --- a/sledge/src/llair/term.ml +++ b/sledge/src/llair/term.ml @@ -463,44 +463,6 @@ let simp_convert src dst arg = (* arithmetic *) -let sum_mul_const const sum = - assert (not (Q.equal Q.zero const)) ; - if Q.equal Q.one const then sum - else Qset.map_counts ~f:(fun _ -> Q.mul const) sum - -let rec sum_to_term sum = - match Qset.length sum with - | 0 -> zero - | 1 -> ( - match Qset.min_elt sum with - | Some (Integer _, q) -> rational q - | Some (arg, q) when Q.equal Q.one q -> arg - | _ -> Add sum ) - | _ -> Add sum - -and rational Q.{num; den} = simp_div (integer num) (integer den) - -and simp_div x y = - match (x, y) with - (* i / j *) - | Integer {data= i}, Integer {data= j} when not (Z.equal Z.zero j) -> - integer (Z.div i j) - (* e / 1 ==> e *) - | e, Integer {data} when Z.equal Z.one data -> e - (* (∑ᵢ cᵢ × Xᵢ) / z ==> ∑ᵢ cᵢ/z × Xᵢ *) - | Add args, Integer {data} -> - sum_to_term (sum_mul_const Q.(inv (of_z data)) args) - | _ -> Ap2 (Div, x, y) - -let simp_rem x y = - match (x, y) with - (* i % j *) - | Integer {data= i}, Integer {data= j} when not (Z.equal Z.zero j) -> - integer (Z.rem i j) - (* e % 1 ==> 0 *) - | _, Integer {data} when Z.equal Z.one data -> zero - | _ -> Ap2 (Rem, x, y) - (* Sums of polynomial terms represented by multisets. A sum ∑ᵢ cᵢ × Xᵢ of monomials Xᵢ with coefficients cᵢ is represented by a multiset where the elements are Xᵢ with multiplicities cᵢ. A constant is treated as the @@ -520,11 +482,39 @@ module Sum = struct let map sum ~f = Qset.fold sum ~init:empty ~f:(fun e c sum -> add c (f e) sum) - let mul_const = sum_mul_const - let to_term = sum_to_term + let mul_const const sum = + assert (not (Q.equal Q.zero const)) ; + if Q.equal Q.one const then sum + else Qset.map_counts ~f:(fun _ -> Q.mul const) sum end -let rec simp_add_ es poly = +(* Products of indeterminants represented by multisets. A product ∏ᵢ xᵢ^nᵢ + of indeterminates xᵢ is represented by a multiset where the elements are + xᵢ and the multiplicities are the exponents nᵢ. *) +module Prod = struct + let empty = empty_qset + + let add term prod = + assert (match term with Integer _ -> false | _ -> true) ; + Qset.add prod term Q.one + + let singleton term = add term empty + let union = Qset.union +end + +let rec sum_to_term sum = + match Qset.length sum with + | 0 -> zero + | 1 -> ( + match Qset.min_elt sum with + | Some (Integer _, q) -> rational q + | Some (arg, q) when Q.equal Q.one q -> arg + | _ -> Add sum ) + | _ -> Add sum + +and rational Q.{num; den} = simp_div (integer num) (integer den) + +and simp_add_ es poly = (* (coeff × term) + poly *) let f term coeff poly = match (term, poly) with @@ -538,30 +528,13 @@ let rec simp_add_ es poly = (* (c × ∑ᵢ cᵢ × Xᵢ) + s ==> (∑ᵢ (c × cᵢ) × Xᵢ) + s *) | Add args, _ -> simp_add_ (Sum.mul_const coeff args) poly (* (c₀ × X₀) + (∑ᵢ₌₁ⁿ cᵢ × Xᵢ) ==> ∑ᵢ₌₀ⁿ cᵢ × Xᵢ *) - | _, Add args -> Sum.to_term (Sum.add coeff term args) + | _, Add args -> sum_to_term (Sum.add coeff term args) (* (c₁ × X₁) + X₂ ==> ∑ᵢ₌₁² cᵢ × Xᵢ for c₂ = 1 *) - | _ -> Sum.to_term (Sum.add coeff term (Sum.singleton poly)) + | _ -> sum_to_term (Sum.add coeff term (Sum.singleton poly)) in Qset.fold ~f es ~init:poly -let simp_add es = simp_add_ es zero -let simp_add2 e f = simp_add_ (Sum.singleton e) f - -(* Products of indeterminants represented by multisets. A product ∏ᵢ xᵢ^nᵢ - of indeterminates xᵢ is represented by a multiset where the elements are - xᵢ and the multiplicities are the exponents nᵢ. *) -module Prod = struct - let empty = empty_qset - - let add term prod = - assert (match term with Integer _ -> false | _ -> true) ; - Qset.add prod term Q.one - - let singleton term = add term empty - let union = Qset.union -end - -let rec simp_mul2 e f = +and simp_mul2 e f = match (e, f) with (* c₁ × c₂ ==> c₁×c₂ *) | Integer {data= i}, Integer {data= j} -> integer (Z.mul i j) @@ -571,15 +544,15 @@ let rec simp_mul2 e f = | _, Integer {data} when Z.equal Z.zero data -> f (* c × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ c × cᵤ × ∏ⱼ yᵤⱼ *) | Integer {data}, Add args | Add args, Integer {data} -> - Sum.to_term (Sum.mul_const (Q.of_z data) args) + sum_to_term (Sum.mul_const (Q.of_z data) args) (* c₁ × x₁ ==> ∑ᵢ₌₁ cᵢ × xᵢ *) | Integer {data= c}, x | x, Integer {data= c} -> - Sum.to_term (Sum.singleton ~coeff:(Q.of_z c) x) + sum_to_term (Sum.singleton ~coeff:(Q.of_z c) x) (* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> ∏ⱼ₌₀ⁿ xⱼ *) | Mul xs1, Mul xs2 -> Mul (Prod.union xs1 xs2) (* (∏ᵢ xᵢ) × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ cᵤ × ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *) | (Mul prod as m), Add sum | Add sum, (Mul prod as m) -> - Sum.to_term + sum_to_term (Sum.map sum ~f:(function | Mul args -> Mul (Prod.union prod args) | Integer _ as c -> simp_mul2 c m @@ -588,20 +561,33 @@ let rec simp_mul2 e f = | Mul xs1, x | x, Mul xs1 -> Mul (Prod.add x xs1) (* e × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ e × cᵤ × ∏ⱼ yᵤⱼ *) | Add args, e | e, Add args -> - simp_add (Sum.map ~f:(fun m -> simp_mul2 e m) args) + simp_add_ (Sum.map ~f:(fun m -> simp_mul2 e m) args) zero (* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *) | _ -> Mul (Prod.add e (Prod.singleton f)) -let simp_mul es = - (* (bas ^ pwr) × term *) - let rec mul_pwr bas pwr term = - if Q.equal Q.zero pwr then term - else mul_pwr bas Q.(pwr - one) (simp_mul2 bas term) - in - Qset.fold es ~init:one ~f:(fun bas pwr term -> - if Q.sign pwr >= 0 then mul_pwr bas pwr term - else simp_div term (mul_pwr bas (Q.neg pwr) one) ) +and simp_div x y = + match (x, y) with + (* i / j *) + | Integer {data= i}, Integer {data= j} when not (Z.equal Z.zero j) -> + integer (Z.div i j) + (* e / 1 ==> e *) + | e, Integer {data} when Z.equal Z.one data -> e + (* (∑ᵢ cᵢ × Xᵢ) / z ==> ∑ᵢ cᵢ/z × Xᵢ *) + | Add args, Integer {data} -> + sum_to_term (Sum.mul_const Q.(inv (of_z data)) args) + | _ -> Ap2 (Div, x, y) +let simp_rem x y = + match (x, y) with + (* i % j *) + | Integer {data= i}, Integer {data= j} when not (Z.equal Z.zero j) -> + integer (Z.rem i j) + (* e % 1 ==> 0 *) + | _, Integer {data} when Z.equal Z.one data -> zero + | _ -> Ap2 (Rem, x, y) + +let simp_add es = simp_add_ es zero +let simp_add2 e f = simp_add_ (Sum.singleton e) f let simp_negate x = simp_mul2 minus_one x let simp_sub x y = @@ -611,6 +597,16 @@ let simp_sub x y = (* x - y ==> x + (-1 * y) *) | _ -> simp_add2 x (simp_negate y) +let simp_mul es = + (* (bas ^ pwr) × term *) + let rec mul_pwr bas pwr term = + if Q.equal Q.zero pwr then term + else mul_pwr bas Q.(pwr - one) (simp_mul2 bas term) + in + Qset.fold es ~init:one ~f:(fun bas pwr term -> + if Q.sign pwr >= 0 then mul_pwr bas pwr term + else simp_div term (mul_pwr bas (Q.neg pwr) one) ) + (* if-then-else *) let simp_cond cnd thn els = @@ -1178,7 +1174,7 @@ let solve_zero_eq ?for_ e = if Q.equal Q.zero q then None else Some (f, q) | None -> Some (Qset.min_elt_exn args) in - let n = Sum.to_term (Qset.remove args c) in + let n = sum_to_term (Qset.remove args c) in let d = rational (Q.neg q) in let r = div n d in (c, r)