[sledge] Minor simplification of polynomial representation

Reviewed By: bennostein

Differential Revision: D17665237

fbshipit-source-id: f9a082d26
master
Josh Berdine 5 years ago committed by Facebook Github Bot
parent 3bbb05216f
commit e87a0533be

@ -20,8 +20,8 @@ module rec T : sig
type t = type t =
(* nary: arithmetic, numeric and pointer *) (* nary: arithmetic, numeric and pointer *)
| Add of {args: qset} | Add of qset
| Mul of {args: qset} | Mul of qset
(* pointer and memory constants and operations *) (* pointer and memory constants and operations *)
| Splat of {byt: t; siz: t} | Splat of {byt: t; siz: t}
| Memory of {siz: t; arr: t} | Memory of {siz: t; arr: t}
@ -80,8 +80,8 @@ and T0 : sig
type qset = Qset.M(T).t [@@deriving compare, equal, hash, sexp] type qset = Qset.M(T).t [@@deriving compare, equal, hash, sexp]
type t = type t =
| Add of {args: qset} | Add of qset
| Mul of {args: qset} | Mul of qset
| Splat of {byt: t; siz: t} | Splat of {byt: t; siz: t}
| Memory of {siz: t; arr: t} | Memory of {siz: t; arr: t}
| Concat of {args: t vector} | Concat of {args: t vector}
@ -117,8 +117,8 @@ end = struct
type qset = Qset.M(T).t [@@deriving compare, equal, hash, sexp] type qset = Qset.M(T).t [@@deriving compare, equal, hash, sexp]
type t = type t =
| Add of {args: qset} | Add of qset
| Mul of {args: qset} | Mul of qset
| Splat of {byt: t; siz: t} | Splat of {byt: t; siz: t}
| Memory of {siz: t; arr: t} | Memory of {siz: t; arr: t}
| Concat of {args: t vector} | Concat of {args: t vector}
@ -212,7 +212,7 @@ let rec pp ?is_x fs term =
| Le -> pf "@<1>≤" | Le -> pf "@<1>≤"
| Ord -> pf "ord" | Ord -> pf "ord"
| Uno -> pf "uno" | Uno -> pf "uno"
| Add {args} -> | Add args ->
let pp_poly_term fs (monomial, coefficient) = let pp_poly_term fs (monomial, coefficient) =
match monomial with match monomial with
| Integer {data} when Z.equal Z.one data -> Q.pp fs coefficient | Integer {data} when Z.equal Z.one data -> Q.pp fs coefficient
@ -221,7 +221,7 @@ let rec pp ?is_x fs term =
Format.fprintf fs "%a @<1>× %a" Q.pp coefficient pp monomial Format.fprintf fs "%a @<1>× %a" Q.pp coefficient pp monomial
in in
pf "(%a)" (Qset.pp "@ + " pp_poly_term) args pf "(%a)" (Qset.pp "@ + " pp_poly_term) args
| Mul {args} -> | Mul args ->
let pp_mono_term fs (factor, exponent) = let pp_mono_term fs (factor, exponent) =
if Q.equal Q.one exponent then pp fs factor if Q.equal Q.one exponent then pp fs factor
else Format.fprintf fs "%a^%a" pp factor Q.pp exponent else Format.fprintf fs "%a^%a" pp factor Q.pp exponent
@ -311,7 +311,7 @@ let rec assert_indeterminate = function
*) *)
let assert_monomial mono = let assert_monomial mono =
match mono with match mono with
| Mul {args} -> | Mul args ->
Qset.iter args ~f:(fun factor exponent -> Qset.iter args ~f:(fun factor exponent ->
assert (Q.sign exponent > 0) ; assert (Q.sign exponent > 0) ;
assert_indeterminate factor |> Fn.id ) assert_indeterminate factor |> Fn.id )
@ -324,7 +324,7 @@ let assert_poly_term mono coeff =
assert (not (Q.equal Q.zero coeff)) ; assert (not (Q.equal Q.zero coeff)) ;
match mono with match mono with
| Integer {data} -> assert (Z.equal Z.one data) | Integer {data} -> assert (Z.equal Z.one data)
| Mul {args} -> | Mul args ->
( match Qset.min_elt args with ( match Qset.min_elt args with
| None | Some (Integer _, _) -> assert false | None | Some (Integer _, _) -> assert false
| Some (_, n) -> assert (Qset.length args > 1 || not (Q.equal Q.one n)) | Some (_, n) -> assert (Qset.length args > 1 || not (Q.equal Q.one n))
@ -339,7 +339,7 @@ let assert_poly_term mono coeff =
*) *)
let assert_polynomial poly = let assert_polynomial poly =
match poly with match poly with
| Add {args} -> | Add args ->
( match Qset.min_elt args with ( match Qset.min_elt args with
| None | Some (Integer _, _) -> assert false | None | Some (Integer _, _) -> assert false
| Some (_, k) -> assert (Qset.length args > 1 || not (Q.equal Q.one k)) | Some (_, k) -> assert (Qset.length args > 1 || not (Q.equal Q.one k))
@ -503,7 +503,7 @@ let fold_terms e ~init ~f =
|Splat {byt= x; siz= y} |Splat {byt= x; siz= y}
|Memory {siz= x; arr= y} -> |Memory {siz= x; arr= y} ->
fold_terms_ y (fold_terms_ x z) fold_terms_ y (fold_terms_ x z)
| Add {args} | Mul {args} -> | Add args | Mul args ->
Qset.fold args ~init:z ~f:(fun arg _ z -> fold_terms_ arg z) Qset.fold args ~init:z ~f:(fun arg _ z -> fold_terms_ arg z)
| Concat {args} | Struct_rec {elts= args} -> | Concat {args} | Struct_rec {elts= args} ->
Vector.fold args ~init:z ~f:(fun z elt -> fold_terms_ elt z) Vector.fold args ~init:z ~f:(fun z elt -> fold_terms_ elt z)
@ -574,8 +574,8 @@ let rec sum_to_term sum =
match Qset.min_elt sum with match Qset.min_elt sum with
| Some (Integer _, q) -> rational q | Some (Integer _, q) -> rational q
| Some (arg, q) when Q.equal Q.one q -> arg | Some (arg, q) when Q.equal Q.one q -> arg
| _ -> Add {args= sum} ) | _ -> Add sum )
| _ -> Add {args= sum} | _ -> Add sum
and rational Q.{num; den} = simp_div (integer num) (integer den) and rational Q.{num; den} = simp_div (integer num) (integer den)
@ -587,7 +587,7 @@ and simp_div x y =
(* e / 1 ==> e *) (* e / 1 ==> e *)
| e, Integer {data} when Z.equal Z.one data -> e | e, Integer {data} when Z.equal Z.one data -> e
(* (∑ᵢ cᵢ × Xᵢ) / z ==> ∑ᵢ cᵢ/z × Xᵢ *) (* (∑ᵢ cᵢ × Xᵢ) / z ==> ∑ᵢ cᵢ/z × Xᵢ *)
| Add {args}, Integer {data} -> | Add args, Integer {data} ->
sum_to_term (sum_mul_const Q.(inv (of_z data)) args) sum_to_term (sum_mul_const Q.(inv (of_z data)) args)
| _ -> App {op= App {op= Div; arg= x}; arg= y} | _ -> App {op= App {op= Div; arg= x}; arg= y}
@ -635,9 +635,9 @@ let rec simp_add_ es poly =
| Integer {data= i}, Integer {data= j} -> | Integer {data= i}, Integer {data= j} ->
rational Q.((coeff * of_z i) + of_z j) rational Q.((coeff * of_z i) + of_z j)
(* (c × ∑ᵢ cᵢ × Xᵢ) + s ==> (∑ᵢ (c × cᵢ) × Xᵢ) + s *) (* (c × ∑ᵢ cᵢ × Xᵢ) + s ==> (∑ᵢ (c × cᵢ) × Xᵢ) + s *)
| Add {args}, _ -> simp_add_ (Sum.mul_const coeff args) poly | Add args, _ -> simp_add_ (Sum.mul_const coeff args) poly
(* (c₀ × X₀) + (∑ᵢ₌₁ⁿ cᵢ × Xᵢ) ==> ∑ᵢ₌₀ⁿ cᵢ × Xᵢ *) (* (c₀ × X₀) + (∑ᵢ₌₁ⁿ cᵢ × Xᵢ) ==> ∑ᵢ₌₀ⁿ cᵢ × Xᵢ *)
| _, Add {args} -> Sum.to_term (Sum.add coeff term args) | _, Add args -> Sum.to_term (Sum.add coeff term args)
(* (c₁ × X₁) + X₂ ==> ∑ᵢ₌₁² cᵢ × Xᵢ for c₂ = 1 *) (* (c₁ × X₁) + X₂ ==> ∑ᵢ₌₁² cᵢ × Xᵢ for c₂ = 1 *)
| _ -> Sum.to_term (Sum.add coeff term (Sum.singleton poly)) | _ -> Sum.to_term (Sum.add coeff term (Sum.singleton poly))
in in
@ -669,28 +669,27 @@ let rec simp_mul2 e f =
(* e × 0 ==> 0 *) (* e × 0 ==> 0 *)
| _, Integer {data} when Z.equal Z.zero data -> f | _, Integer {data} when Z.equal Z.zero data -> f
(* c × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ c × cᵤ × ∏ⱼ yᵤⱼ *) (* c × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ c × cᵤ × ∏ⱼ yᵤⱼ *)
| Integer {data}, Add {args} | Add {args}, Integer {data} -> | Integer {data}, Add args | Add args, Integer {data} ->
Sum.to_term (Sum.mul_const (Q.of_z data) args) Sum.to_term (Sum.mul_const (Q.of_z data) args)
(* c₁ × x₁ ==> ∑ᵢ₌₁ cᵢ × xᵢ *) (* c₁ × x₁ ==> ∑ᵢ₌₁ cᵢ × xᵢ *)
| Integer {data= c}, x | x, Integer {data= c} -> | Integer {data= c}, x | x, Integer {data= c} ->
Sum.to_term (Sum.singleton ~coeff:(Q.of_z c) x) Sum.to_term (Sum.singleton ~coeff:(Q.of_z c) x)
(* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> ∏ⱼ₌₀ⁿ xⱼ *) (* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> ∏ⱼ₌₀ⁿ xⱼ *)
| Mul {args= xs1}, Mul {args= xs2} -> Mul {args= Prod.union xs1 xs2} | Mul xs1, Mul xs2 -> Mul (Prod.union xs1 xs2)
(* (∏ᵢ xᵢ) × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ cᵤ × ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *) (* (∏ᵢ xᵢ) × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ cᵤ × ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *)
| (Mul {args= prod} as m), Add {args= sum} | (Mul prod as m), Add sum | Add sum, (Mul prod as m) ->
|Add {args= sum}, (Mul {args= prod} as m) ->
Sum.to_term Sum.to_term
(Sum.map sum ~f:(function (Sum.map sum ~f:(function
| Mul {args} -> Mul {args= Prod.union prod args} | Mul args -> Mul (Prod.union prod args)
| Integer _ as c -> simp_mul2 c m | Integer _ as c -> simp_mul2 c m
| mono -> Mul {args= Prod.add mono prod} )) | mono -> Mul (Prod.add mono prod) ))
(* x₀ × (∏ᵢ₌₁ⁿ xᵢ) ==> ∏ᵢ₌₀ⁿ xᵢ *) (* x₀ × (∏ᵢ₌₁ⁿ xᵢ) ==> ∏ᵢ₌₀ⁿ xᵢ *)
| Mul {args= xs1}, x | x, Mul {args= xs1} -> Mul {args= Prod.add x xs1} | Mul xs1, x | x, Mul xs1 -> Mul (Prod.add x xs1)
(* e × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ e × cᵤ × ∏ⱼ yᵤⱼ *) (* e × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ e × cᵤ × ∏ⱼ yᵤⱼ *)
| Add {args}, e | e, Add {args} -> | Add args, e | e, Add args ->
simp_add (Sum.map ~f:(fun m -> simp_mul2 e m) args) simp_add (Sum.map ~f:(fun m -> simp_mul2 e m) args)
(* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *) (* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *)
| _ -> Mul {args= Prod.add e (Prod.singleton f)} | _ -> Mul (Prod.add e (Prod.singleton f))
let simp_mul es = let simp_mul es =
(* (bas ^ pwr) × term *) (* (bas ^ pwr) × term *)
@ -869,7 +868,7 @@ let iter e ~f =
| App {op= x; arg= y} | Splat {byt= x; siz= y} | Memory {siz= x; arr= y} | App {op= x; arg= y} | Splat {byt= x; siz= y} | Memory {siz= x; arr= y}
-> ->
f x ; f y f x ; f y
| Add {args} | Mul {args} -> Qset.iter ~f:(fun arg _ -> f arg) args | Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args
| Concat {args} | Struct_rec {elts= args} -> Vector.iter ~f args | Concat {args} | Struct_rec {elts= args} -> Vector.iter ~f args
| _ -> () | _ -> ()
@ -878,8 +877,7 @@ let fold e ~init:s ~f =
| App {op= x; arg= y} | Splat {byt= x; siz= y} | Memory {siz= x; arr= y} | App {op= x; arg= y} | Splat {byt= x; siz= y} | Memory {siz= x; arr= y}
-> ->
f y (f x s) f y (f x s)
| Add {args} | Mul {args} -> | Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s
Qset.fold ~f:(fun e _ s -> f e s) args ~init:s
| Concat {args} | Struct_rec {elts= args} -> | Concat {args} | Struct_rec {elts= args} ->
Vector.fold ~f:(fun s e -> f e s) args ~init:s Vector.fold ~f:(fun s e -> f e s) args ~init:s
| _ -> s | _ -> s
@ -1090,8 +1088,8 @@ let map e ~f =
in in
match e with match e with
| App {op; arg} -> map_bin (app1 ~partial:true) ~f op arg | App {op; arg} -> map_bin (app1 ~partial:true) ~f op arg
| Add {args} -> map_qset addN ~f args | Add args -> map_qset addN ~f args
| Mul {args} -> map_qset mulN ~f args | Mul args -> map_qset mulN ~f args
| Splat {byt; siz} -> map_bin simp_splat ~f byt siz | Splat {byt; siz} -> map_bin simp_splat ~f byt siz
| Memory {siz; arr} -> map_bin simp_memory ~f siz arr | Memory {siz; arr} -> map_bin simp_memory ~f siz arr
| Concat {args} -> map_vector simp_concat ~f args | Concat {args} -> map_vector simp_concat ~f args
@ -1120,7 +1118,7 @@ let rec is_constant e =
| App {op= x; arg= y} | Splat {byt= x; siz= y} | Memory {siz= x; arr= y} | App {op= x; arg= y} | Splat {byt= x; siz= y} | Memory {siz= x; arr= y}
-> ->
is_constant_bin x y is_constant_bin x y
| Add {args} | Mul {args} -> | Add args | Mul args ->
Qset.for_all ~f:(fun arg _ -> is_constant arg) args Qset.for_all ~f:(fun arg _ -> is_constant arg) args
| Concat {args} | Struct_rec {elts= args} -> | Concat {args} | Struct_rec {elts= args} ->
Vector.for_all ~f:is_constant args Vector.for_all ~f:is_constant args
@ -1157,7 +1155,7 @@ let solve e f =
match (e, f) with match (e, f) with
| (Add _ | Mul _ | Integer _), _ | _, (Add _ | Mul _ | Integer _) -> ( | (Add _ | Mul _ | Integer _), _ | _, (Add _ | Mul _ | Integer _) -> (
match sub e f with match sub e f with
| Add {args} -> | Add args ->
let c, q = Qset.min_elt_exn args in let c, q = Qset.min_elt_exn args in
let n = Sum.to_term (Qset.remove args c) in let n = Sum.to_term (Qset.remove args c) in
let d = rational (Q.neg q) in let d = rational (Q.neg q) in

@ -21,8 +21,8 @@ type comparator_witness
type qset = (t, comparator_witness) Qset.t type qset = (t, comparator_witness) Qset.t
and t = private and t = private
| Add of {args: qset} (** Addition *) | Add of qset (** Addition *)
| Mul of {args: qset} (** Multiplication *) | Mul of qset (** Multiplication *)
| Splat of {byt: t; siz: t} | Splat of {byt: t; siz: t}
(** Iterated concatenation of a single byte *) (** Iterated concatenation of a single byte *)
| Memory of {siz: t; arr: t} (** Size-tagged byte-array *) | Memory of {siz: t; arr: t} (** Size-tagged byte-array *)

Loading…
Cancel
Save