[sledge] Change polynomial coefficients and powers to rationals

Summary:
Rational coefficients enable solving equations via Gaussian
elimination, and rational powers enable some normalization of
multiplication and division.

Reviewed By: ngorogiannis

Differential Revision: D14233213

fbshipit-source-id: acecb7f16
master
Josh Berdine 6 years ago committed by Facebook Github Bot
parent d01de4b0dd
commit 0ecee6a848

@ -265,19 +265,36 @@ module Set = struct
let to_tree = Using_comparator.to_tree
end
module Mset = struct
include Mset
module Qset = struct
include Qset
let pp sep pp_elt fs s = List.pp sep pp_elt fs (to_list s)
end
module Z = struct
include Z
module Q = struct
let pp = Q.pp_print
let hash = Hashtbl.hash
let hash_fold_t s q = Int.hash_fold_t s (hash q)
let sexp_of_t q = Sexp.Atom (Q.to_string q)
let t_of_sexp = function
| Sexp.Atom s -> Q.of_string s
| _ -> assert false
let of_z = Q.of_bigint
let hash_fold_t s z = Int.hash_fold_t s (Z.hash z)
include Q
end
module Z = struct
let pp = Z.pp_print
let hash = [%hash: Z.t]
let hash_fold_t s z = Int.hash_fold_t s (hash z)
let sexp_of_t z = Sexp.Atom (Z.to_string z)
let t_of_sexp = function
| Sexp.Atom s -> Z.of_string s
| _ -> assert false
include Z
end

@ -212,17 +212,31 @@ module Set : sig
val to_tree : ('e, 'c) t -> ('e, 'c) tree
end
module Mset : sig
include module type of Mset
module Qset : sig
include module type of Qset
val pp : (unit, unit) fmt -> ('a * Z.t) pp -> ('a, _) t pp
(** Pretty-print a multiset. *)
val pp : (unit, unit) fmt -> ('a * Q.t) pp -> ('a, _) t pp
end
module Q : sig
include module type of struct include Q end
val of_z : Z.t -> t
val compare : t -> t -> int
val hash : t -> int
val hash_fold_t : t Hash.folder
val t_of_sexp : Sexp.t -> t
val sexp_of_t : t -> Sexp.t
val pp : t pp
end
module Z : sig
include module type of struct include Z end
val compare : t -> t -> int
val hash : t -> int
val hash_fold_t : t Hash.folder
val t_of_sexp : Sexp.t -> t
val sexp_of_t : t -> Sexp.t
val pp : t pp
end

@ -5,12 +5,11 @@
* LICENSE file in the root directory of this source tree.
*)
(** Mset - Set with integer (positive, negative, or zero) multiplicity for
each element *)
(** Qset - Set with (signed) rational multiplicity for each element *)
open Base
type ('elt, 'cmp) t = ('elt, Z.t, 'cmp) Map.t
type ('elt, 'cmp) t = ('elt, Q.t, 'cmp) Map.t
module M (Elt : sig
type t
@ -33,23 +32,23 @@ end
module type Compare_m = sig end
module type Hash_fold_m = Hasher.S
let sexp_of_z z = Sexp.Atom (Z.to_string z)
let z_of_sexp = function Sexp.Atom s -> Z.of_string s | _ -> assert false
let hash_fold_z state z = Hash.fold_int state (Z.hash z)
let sexp_of_q q = Sexp.Atom (Q.to_string q)
let q_of_sexp = function Sexp.Atom s -> Q.of_string s | _ -> assert false
let hash_fold_q state q = Hash.fold_int state (Hashtbl.hash q)
let sexp_of_m__t (type elt) (module Elt : Sexp_of_m with type t = elt) t =
Map.sexp_of_m__t (module Elt) sexp_of_z t
Map.sexp_of_m__t (module Elt) sexp_of_q t
let m__t_of_sexp (type elt cmp)
(module Elt : M_of_sexp
with type t = elt and type comparator_witness = cmp) sexp =
Map.m__t_of_sexp (module Elt) z_of_sexp sexp
Map.m__t_of_sexp (module Elt) q_of_sexp sexp
let compare_m__t (module Elt : Compare_m) = Map.compare_direct Z.compare
let compare_m__t (module Elt : Compare_m) = Map.compare_direct Q.compare
let hash_fold_m__t (type elt) (module Elt : Hash_fold_m with type t = elt)
state =
Map.hash_fold_m__t (module Elt) hash_fold_z state
Map.hash_fold_m__t (module Elt) hash_fold_q state
let hash_m__t (type elt) (module Elt : Hash_fold_m with type t = elt) =
Hash.of_fold (hash_fold_m__t (module Elt))
@ -58,22 +57,22 @@ type ('elt, 'cmp) comparator =
(module Comparator.S with type t = 'elt and type comparator_witness = 'cmp)
let empty cmp = Map.empty cmp
let if_nz z = if Z.equal Z.zero z then None else Some z
let if_nz q = if Q.equal Q.zero q then None else Some q
let add m x i =
Map.change m x ~f:(function Some j -> if_nz Z.(i + j) | None -> if_nz i)
Map.change m x ~f:(function Some j -> if_nz Q.(i + j) | None -> if_nz i)
let remove m x = Map.remove m x
let union m n =
Map.merge m n ~f:(fun ~key:_ -> function
| `Both (i, j) -> if_nz Z.(i + j) | `Left i | `Right i -> Some i )
| `Both (i, j) -> if_nz Q.(i + j) | `Left i | `Right i -> Some i )
let length m = Map.length m
let count m x = match Map.find m x with Some z -> z | None -> Z.zero
let count m x = match Map.find m x with Some q -> q | None -> Q.zero
let count_and_remove m x =
let found = ref Z.zero in
let found = ref Q.zero in
let m =
Map.change m x ~f:(function
| None -> None
@ -81,7 +80,7 @@ let count_and_remove m x =
found := i ;
None )
in
if Z.equal !found Z.zero then None else Some (!found, m)
if Q.equal !found Q.zero then None else Some (!found, m)
let min_elt = Map.min_elt
let fold m ~f ~init = Map.fold m ~f:(fun ~key ~data s -> f key data s) ~init
@ -90,14 +89,14 @@ let map m ~f =
fold m ~init:m ~f:(fun x i m ->
let x', i' = f x i in
if phys_equal x' x then
if Z.equal i' i then m else Map.set m ~key:x ~data:i'
if Q.equal i' i then m else Map.set m ~key:x ~data:i'
else add (Map.remove m x) x' i' )
let fold_map m ~f ~init:s =
fold m ~init:(m, s) ~f:(fun x i (m, s) ->
let x', i', s = f x i s in
if phys_equal x' x then
if Z.equal i' i then (m, s) else (Map.set m ~key:x ~data:i', s)
if Q.equal i' i then (m, s) else (Map.set m ~key:x ~data:i', s)
else (add (Map.remove m x) x' i', s) )
let for_all m ~f = Map.for_alli m ~f:(fun ~key ~data -> f key data)

@ -5,8 +5,7 @@
* LICENSE file in the root directory of this source tree.
*)
(** Mset - Set with integer (positive, negative, or zero) multiplicity for
each element *)
(** Qset - Set with (signed) rational multiplicity for each element *)
open Base
@ -58,7 +57,7 @@ val hash_m__t :
val empty : ('elt, 'cmp) comparator -> ('elt, 'cmp) t
(** The empty multiset over the provided order. *)
val add : ('a, 'c) t -> 'a -> Z.t -> ('a, 'c) t
val add : ('a, 'c) t -> 'a -> Q.t -> ('a, 'c) t
(** Add to multiplicity of single element. [O(log n)] *)
val remove : ('a, 'c) t -> 'a -> ('a, 'c) t
@ -70,39 +69,39 @@ val union : ('a, 'c) t -> ('a, 'c) t -> ('a, 'c) t
val length : _ t -> int
(** Number of elements with non-zero multiplicity. [O(1)]. *)
val count : ('a, _) t -> 'a -> Z.t
val count : ('a, _) t -> 'a -> Q.t
(** Multiplicity of an element. [O(log n)]. *)
val count_and_remove : ('a, 'c) t -> 'a -> (Z.t * ('a, 'c) t) option
val count_and_remove : ('a, 'c) t -> 'a -> (Q.t * ('a, 'c) t) option
(** Multiplicity of an element, and remove it. [O(log n)]. *)
val map : ('a, 'c) t -> f:('a -> Z.t -> 'a * Z.t) -> ('a, 'c) t
val map : ('a, 'c) t -> f:('a -> Q.t -> 'a * Q.t) -> ('a, 'c) t
(** Map over the elements in ascending order. Preserves physical equality if
[f] does. *)
val map_counts : ('a, 'c) t -> f:('a -> Z.t -> Z.t) -> ('a, 'c) t
val map_counts : ('a, 'c) t -> f:('a -> Q.t -> Q.t) -> ('a, 'c) t
(** Map over the multiplicities of the elements in ascending order. *)
val fold : ('a, _) t -> f:('a -> Z.t -> 's -> 's) -> init:'s -> 's
val fold : ('a, _) t -> f:('a -> Q.t -> 's -> 's) -> init:'s -> 's
(** Fold over the elements in ascending order. *)
val fold_map :
('a, 'c) t
-> f:('a -> Z.t -> 's -> 'a * Z.t * 's)
-> f:('a -> Q.t -> 's -> 'a * Q.t * 's)
-> init:'s
-> ('a, 'c) t * 's
(** Folding map over the elements in ascending order. Preserves physical
equality if [f] does. *)
val for_all : ('a, _) t -> f:('a -> Z.t -> bool) -> bool
val for_all : ('a, _) t -> f:('a -> Q.t -> bool) -> bool
(** Universal property test. [O(n)] but returns as soon as a violation is
found, in ascending order. *)
val iter : ('a, _) t -> f:('a -> Z.t -> unit) -> unit
val iter : ('a, _) t -> f:('a -> Q.t -> unit) -> unit
(** Iterate over the elements in ascending order. *)
val min_elt : ('a, _) t -> ('a * Z.t) option
val min_elt : ('a, _) t -> ('a * Q.t) option
(** Minimum element. *)
val to_list : ('a, _) t -> ('a * Z.t) list
val to_list : ('a, _) t -> ('a * Q.t) list
(** Convert to a list of elements in ascending order. *)

@ -9,13 +9,7 @@
(** Z wrapped to treat bounded and unsigned operations *)
module Z = struct
type t = Z.t [@@deriving compare, hash, sexp]
include (Z : module type of Z with type t := t)
let pp = Z.pp_print
let is_zero = Z.equal zero
let is_one = Z.equal one
include Z
(* the signed 1-bit integers are -1 and 0 *)
let true_ = Z.minus_one
@ -43,7 +37,6 @@ module Z = struct
let bugeq ~bits x y = clamp_cmp ~signed:false bits Z.geq x y
let bult ~bits x y = clamp_cmp ~signed:false bits Z.lt x y
let bugt ~bits x y = clamp_cmp ~signed:false bits Z.gt x y
let badd ~bits x y = clamp_bop ~signed:true bits Z.add x y
let bsub ~bits x y = clamp_bop ~signed:true bits Z.sub x y
let bmul ~bits x y = clamp_bop ~signed:true bits Z.mul x y
let bdiv ~bits x y = clamp_bop ~signed:true bits Z.div x y
@ -61,13 +54,13 @@ module Z = struct
end
module rec T : sig
type mset = Mset.M(T).t [@@deriving compare, hash, sexp]
type qset = Qset.M(T).t [@@deriving compare, hash, sexp]
type t =
| App of {op: t; arg: t}
(* nary: arithmetic, numeric and pointer *)
| Add of {args: mset; typ: Typ.t}
| Mul of {args: mset; typ: Typ.t}
| Add of {args: qset; typ: Typ.t}
| Mul of {args: qset; typ: Typ.t}
| Var of {id: int; name: string}
| Nondet of {msg: string}
| Label of {parent: string; name: string}
@ -124,12 +117,12 @@ end
(* auxiliary definition for safe recursive module initialization *)
and T0 : sig
type mset = Mset.M(T).t [@@deriving compare, hash, sexp]
type qset = Qset.M(T).t [@@deriving compare, hash, sexp]
type t =
| App of {op: t; arg: t}
| Add of {args: mset; typ: Typ.t}
| Mul of {args: mset; typ: Typ.t}
| Add of {args: qset; typ: Typ.t}
| Mul of {args: qset; typ: Typ.t}
| Var of {id: int; name: string}
| Nondet of {msg: string}
| Label of {parent: string; name: string}
@ -168,12 +161,12 @@ and T0 : sig
| Convert of {signed: bool; dst: Typ.t; src: Typ.t}
[@@deriving compare, hash, sexp]
end = struct
type mset = Mset.M(T).t [@@deriving compare, hash, sexp]
type qset = Qset.M(T).t [@@deriving compare, hash, sexp]
type t =
| App of {op: t; arg: t}
| Add of {args: mset; typ: Typ.t}
| Mul of {args: mset; typ: Typ.t}
| Add of {args: qset; typ: Typ.t}
| Mul of {args: qset; typ: Typ.t}
| Var of {id: int; name: string}
| Nondet of {msg: string}
| Label of {parent: string; name: string}
@ -218,7 +211,7 @@ type _t = T0.t
include T
let empty_mset = Mset.empty (module T)
let empty_qset = Qset.empty (module T)
let equal = [%compare.equal: t]
let sorted e f = compare e f <= 0
let sort e f = if sorted e f then (e, f) else (f, e)
@ -257,7 +250,7 @@ let rec pp fs exp =
| Var {name; id} -> pf "%%%s_%d" name id
| Nondet {msg} -> pf "nondet \"%s\"" msg
| Label {name} -> pf "%s" name
| Integer {data; typ= Pointer _} when Z.is_zero data -> pf "null"
| Integer {data; typ= Pointer _} when Z.equal Z.zero data -> pf "null"
| Splat -> pf "^"
| Memory -> pf "⟨_,_⟩"
| App {op= Memory; arg= siz} -> pf "@<1>⟨%a,_@<1>⟩" pp siz
@ -281,18 +274,18 @@ let rec pp fs exp =
| Add {args} ->
let pp_poly_term fs (monomial, coefficient) =
match monomial with
| Integer {data} when Z.is_one data -> Z.pp fs coefficient
| _ when Z.is_one coefficient -> pp fs monomial
| Integer {data} when Z.equal Z.one data -> Q.pp fs coefficient
| _ when Q.equal Q.one coefficient -> pp fs monomial
| _ ->
Format.fprintf fs "%a @<1>× %a" Z.pp coefficient pp monomial
Format.fprintf fs "%a @<1>× %a" Q.pp coefficient pp monomial
in
pf "(%a)" (Mset.pp "@ + " pp_poly_term) args
pf "(%a)" (Qset.pp "@ + " pp_poly_term) args
| Mul {args} ->
let pp_mono_term fs (factor, exponent) =
if Z.is_one exponent then pp fs factor
else Format.fprintf fs "%a^%a" pp factor Z.pp exponent
if Q.equal Q.one exponent then pp fs factor
else Format.fprintf fs "%a^%a" pp factor Q.pp exponent
in
pf "(%a)" (Mset.pp "@ @<2>× " pp_mono_term) args
pf "(%a)" (Qset.pp "@ @<2>× " pp_mono_term) args
| Div -> pf "/"
| Udiv -> pf "udiv"
| Rem -> pf "rem"
@ -389,8 +382,8 @@ let assert_monomial add_typ mono =
| Mul {typ; args} ->
assert (Typ.castable add_typ typ) ;
assert (Option.exists ~f:(fun n -> 1 < n) (Typ.prim_bit_size_of typ)) ;
Mset.iter args ~f:(fun factor exponent ->
assert (Z.sign exponent > 0) ;
Qset.iter args ~f:(fun factor exponent ->
assert (Q.sign exponent > 0) ;
assert_indeterminate factor |> Fn.id )
| _ -> assert_indeterminate mono |> Fn.id
@ -398,12 +391,12 @@ let assert_monomial add_typ mono =
* c × x
*)
let assert_poly_term add_typ mono coeff =
assert (not (Z.is_zero coeff)) ;
assert (not (Q.equal Q.zero coeff)) ;
match mono with
| Integer {data} -> assert (Z.is_one data)
| Integer {data} -> assert (Z.equal Z.one data)
| Mul {args} ->
if Z.is_one coeff then assert (Mset.length args > 1)
else assert (Mset.length args > 0) ;
if Q.equal Q.one coeff then assert (Qset.length args > 1)
else assert (Qset.length args > 0) ;
assert_monomial add_typ mono |> Fn.id
| _ -> assert_monomial add_typ mono |> Fn.id
@ -415,15 +408,15 @@ let assert_poly_term add_typ mono coeff =
let assert_polynomial poly =
match poly with
| Add {typ; args} ->
( match Mset.length args with
( match Qset.length args with
| 0 -> assert false
| 1 -> (
match Mset.min_elt args with
match Qset.min_elt args with
| Some (Integer _, _) -> assert false
| Some (_, k) -> assert (not (Z.is_one k))
| Some (_, k) -> assert (not (Q.equal Q.one k))
| _ -> () )
| _ -> () ) ;
Mset.iter args ~f:(fun m c -> assert_poly_term typ m c |> Fn.id)
Qset.iter args ~f:(fun m c -> assert_poly_term typ m c |> Fn.id)
| _ -> assert false
let invariant ?(partial = false) e =
@ -616,7 +609,7 @@ let fold_exps e ~init ~f =
match e with
| App {op; arg} -> fold_exps_ op (fold_exps_ arg z)
| Add {args} | Mul {args} ->
Mset.fold args ~init:z ~f:(fun arg _ z -> fold_exps_ arg z)
Qset.fold args ~init:z ~f:(fun arg _ z -> fold_exps_ arg z)
| Struct_rec {elts} ->
Vector.fold elts ~init:z ~f:(fun z elt -> fold_exps_ elt z)
| _ -> z
@ -705,7 +698,7 @@ let simp_div x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.bdiv ~bits i j) typ
(* e / 1 ==> e *)
| e, Integer {data} when Z.is_one data -> e
| e, Integer {data} when Z.equal Z.one data -> e
| _ -> App {op= App {op= Div; arg= x}; arg= y}
let simp_udiv x y =
@ -715,7 +708,7 @@ let simp_udiv x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.budiv ~bits i j) typ
(* e u/ 1 ==> e *)
| e, Integer {data} when Z.is_one data -> e
| e, Integer {data} when Z.equal Z.one data -> e
| _ -> App {op= App {op= Udiv; arg= x}; arg= y}
let simp_rem x y =
@ -725,7 +718,7 @@ let simp_rem x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.brem ~bits i j) typ
(* e % 1 ==> 0 *)
| _, Integer {data; typ} when Z.is_one data -> integer Z.zero typ
| _, Integer {data; typ} when Z.equal Z.one data -> integer Z.zero typ
| _ -> App {op= App {op= Rem; arg= x}; arg= y}
let simp_urem x y =
@ -735,42 +728,45 @@ let simp_urem x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.burem ~bits i j) typ
(* e u% 1 ==> 0 *)
| _, Integer {data; typ} when Z.is_one data -> integer Z.zero typ
| _, Integer {data; typ} when Z.equal Z.one data -> integer Z.zero typ
| _ -> App {op= App {op= Urem; arg= x}; arg= y}
let rational Q.({num; den}) typ =
simp_div (integer num typ) (integer den typ)
(* Sums of polynomial terms represented by multisets. A sum ∑ᵢ cᵢ ×
X of monomials X with coefficients c is represented by a
multiset where the elements are X with multiplicities c. A constant
is treated as the coefficient of the empty monomial, which is the unit of
multiplication 1. *)
module Sum = struct
let empty = empty_mset
let empty = empty_qset
let add coeff exp sum =
assert (not (Z.is_zero coeff)) ;
assert (not (Q.equal Q.zero coeff)) ;
match exp with
| Integer {data} when Z.is_zero data -> sum
| Integer {data} when Z.equal Z.zero data -> sum
| Integer {data; typ} ->
Mset.add sum (integer Z.one typ) Z.(coeff * data)
| _ -> Mset.add sum exp coeff
Qset.add sum (integer Z.one typ) Q.(coeff * of_z data)
| _ -> Qset.add sum exp coeff
let singleton ?(coeff = Z.one) exp = add coeff exp empty
let singleton ?(coeff = Q.one) exp = add coeff exp empty
let map sum ~f =
Mset.fold sum ~init:empty ~f:(fun e c sum -> add c (f e) sum)
Qset.fold sum ~init:empty ~f:(fun e c sum -> add c (f e) sum)
let mul_const const sum =
assert (not (Z.is_zero const)) ;
if Z.is_one const then sum
else Mset.map_counts ~f:(fun _ -> Z.mul const) sum
assert (not (Q.equal Q.zero const)) ;
if Q.equal Q.one const then sum
else Qset.map_counts ~f:(fun _ -> Q.mul const) sum
let to_exp typ sum =
match Mset.length sum with
match Qset.length sum with
| 0 -> integer Z.zero typ
| 1 -> (
match Mset.min_elt sum with
| Some (Integer _, z) -> integer z typ
| Some (arg, z) when Z.is_one z -> arg
match Qset.min_elt sum with
| Some (Integer _, q) -> rational q typ
| Some (arg, q) when Q.equal Q.one q -> arg
| _ -> Add {typ; args= sum} )
| _ -> Add {typ; args= sum}
end
@ -780,12 +776,12 @@ let rec simp_add_ typ es poly =
let f exp coeff poly =
match (exp, poly) with
(* (0 × e) + s ==> 0 (optim) *)
| _ when Z.is_zero coeff -> poly
| _ when Q.equal Q.zero coeff -> poly
(* (c × 0) + s ==> s (optim) *)
| Integer {data}, _ when Z.is_zero data -> poly
| Integer {data}, _ when Z.equal Z.zero data -> poly
(* (c × cᵢ) + cⱼ ==> c×cᵢ+cⱼ *)
| Integer {data= i}, Integer {data= j} ->
integer (Z.badd ~bits:(bits_of_int exp) Z.(coeff * i) j) typ
rational Q.((coeff * of_z i) + of_z j) typ
(* (c × ∑ᵢ cᵢ × Xᵢ) + s ==> (∑ᵢ (c × cᵢ) × Xᵢ) + s *)
| Add {args}, _ -> simp_add_ typ (Sum.mul_const coeff args) poly
(* (c₀ × X₀) + (∑ᵢ₌₁ⁿ cᵢ × Xᵢ) ==> ∑ᵢ₌₀ⁿ
@ -794,7 +790,7 @@ let rec simp_add_ typ es poly =
(* (c₁ × X₁) + X₂ ==> ∑ᵢ₌₁² cᵢ × Xᵢ for c₂ = 1 *)
| _ -> Sum.to_exp typ (Sum.add coeff exp (Sum.singleton poly))
in
Mset.fold ~f es ~init:poly
Qset.fold ~f es ~init:poly
let simp_add typ es = simp_add_ typ es (integer Z.zero typ)
let simp_add2 typ e f = simp_add_ typ (Sum.singleton e) f
@ -803,10 +799,10 @@ let simp_add2 typ e f = simp_add_ typ (Sum.singleton e) f
x^n of indeterminates x is represented by a multiset where the
elements are x and the multiplicities are the exponents n. *)
module Prod = struct
let empty = empty_mset
let add exp prod = Mset.add prod exp Z.one
let empty = empty_qset
let add exp prod = Qset.add prod exp Q.one
let singleton exp = add exp empty
let union = Mset.union
let union = Qset.union
end
(* map over each monomial of a polynomial *)
@ -825,16 +821,16 @@ let rec simp_mul2 typ e f =
| Integer {data= i}, Integer {data= j} ->
integer (Z.bmul ~bits:(bits_of_int e) i j) typ
(* 0 × f ==> 0 *)
| Integer {data}, _ when Z.is_zero data -> e
| Integer {data}, _ when Z.equal Z.zero data -> e
(* e × 0 ==> 0 *)
| _, Integer {data} when Z.is_zero data -> f
| _, Integer {data} when Z.equal Z.zero data -> f
(* c × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ c × cᵤ × ∏ⱼ
y *)
| Integer {data}, Add {args} | Add {args}, Integer {data} ->
Sum.to_exp typ (Sum.mul_const data args)
Sum.to_exp typ (Sum.mul_const (Q.of_z data) args)
(* c₁ × x₁ ==> ∑ᵢ₌₁ cᵢ × xᵢ *)
| Integer {data= c}, x | x, Integer {data= c} ->
Sum.to_exp typ (Sum.singleton ~coeff:c x)
Sum.to_exp typ (Sum.singleton ~coeff:(Q.of_z c) x)
(* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==>
x *)
| Mul {typ; args= xs1}, Mul {args= xs2} ->
@ -856,13 +852,13 @@ let rec simp_mul2 typ e f =
let simp_mul typ es =
(* (bas ^ pwr) × exp *)
let rec mul_pwr bas pwr exp =
if Z.is_zero pwr then exp
else mul_pwr bas (Z.pred pwr) (simp_mul2 typ bas exp)
if Q.equal Q.zero pwr then exp
else mul_pwr bas Q.(pwr - one) (simp_mul2 typ bas exp)
in
let one = integer Z.one typ in
Mset.fold es ~init:one ~f:(fun bas pwr exp ->
if Z.sign pwr >= 0 then mul_pwr bas pwr exp
else simp_div exp (mul_pwr bas (Z.neg pwr) one) )
Qset.fold es ~init:one ~f:(fun bas pwr exp ->
if Q.sign pwr >= 0 then mul_pwr bas pwr exp
else simp_div exp (mul_pwr bas (Q.neg pwr) one) )
let simp_negate typ x = simp_mul2 typ (integer Z.minus_one typ) x
@ -1010,7 +1006,7 @@ let simp_shl x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.bshift_left ~bits i (Z.to_int j)) typ
(* e shl 0 ==> e *)
| e, Integer {data} when Z.is_zero data -> e
| e, Integer {data} when Z.equal Z.zero data -> e
| _ -> App {op= App {op= Shl; arg= x}; arg= y}
let simp_lshr x y =
@ -1020,7 +1016,7 @@ let simp_lshr x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.bshift_right_trunc ~bits i (Z.to_int j)) typ
(* e lshr 0 ==> e *)
| e, Integer {data} when Z.is_zero data -> e
| e, Integer {data} when Z.equal Z.zero data -> e
| _ -> App {op= App {op= Lshr; arg= x}; arg= y}
let simp_ashr x y =
@ -1030,7 +1026,7 @@ let simp_ashr x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.bshift_right ~bits i (Z.to_int j)) typ
(* e ashr 0 ==> e *)
| e, Integer {data} when Z.is_zero data -> e
| e, Integer {data} when Z.equal Z.zero data -> e
| _ -> App {op= App {op= Ashr; arg= x}; arg= y}
(** Access *)
@ -1038,7 +1034,7 @@ let simp_ashr x y =
let iter e ~f =
match e with
| App {op; arg} -> f op ; f arg
| Add {args} | Mul {args} -> Mset.iter ~f:(fun arg _ -> f arg) args
| Add {args} | Mul {args} -> Qset.iter ~f:(fun arg _ -> f arg) args
| _ -> ()
let fold e ~init:s ~f =
@ -1048,7 +1044,7 @@ let fold e ~init:s ~f =
let s = f s arg in
s
| Add {args} | Mul {args} ->
let s = Mset.fold ~f:(fun e _ s -> f s e) args ~init:s in
let s = Qset.fold ~f:(fun e _ s -> f s e) args ~init:s in
s
| _ -> s
@ -1173,8 +1169,8 @@ let convert ?(signed = false) ~dst ~src exp =
(** Transform *)
let map e ~f =
let map_mset mk typ ~f args =
let args' = Mset.map ~f:(fun arg z -> (f arg, z)) args in
let map_qset mk typ ~f args =
let args' = Qset.map ~f:(fun arg q -> (f arg, q)) args in
if args' == args then e else mk typ args'
in
match e with
@ -1182,16 +1178,16 @@ let map e ~f =
let op' = f op in
let arg' = f arg in
if op' == op && arg' == arg then e else app1 ~partial:true op' arg'
| Add {args; typ} -> map_mset addN typ ~f args
| Mul {args; typ} -> map_mset mulN typ ~f args
| Add {args; typ} -> map_qset addN typ ~f args
| Mul {args; typ} -> map_qset mulN typ ~f args
| _ -> e
let fold_map e ~init:s ~f =
let fold_map_mset mk typ ~f ~init args =
let fold_map_qset mk typ ~f ~init args =
let args', s =
Mset.fold_map args ~init ~f:(fun x z s ->
Qset.fold_map args ~init ~f:(fun x q s ->
let s, x' = f s x in
(x', z, s) )
(x', q, s) )
in
if args' == args then (s, e) else (s, mk typ args')
in
@ -1201,8 +1197,8 @@ let fold_map e ~init:s ~f =
let s, arg' = f s arg in
if op' == op && arg' == arg then (s, e)
else (s, app1 ~partial:true op' arg')
| Add {args; typ} -> fold_map_mset addN typ ~f args ~init:s
| Mul {args; typ} -> fold_map_mset mulN typ ~f args ~init:s
| Add {args; typ} -> fold_map_qset addN typ ~f args ~init:s
| Mul {args; typ} -> fold_map_qset mulN typ ~f args ~init:s
| _ -> (s, e)
let rename e sub =
@ -1218,22 +1214,22 @@ let rename e sub =
let offset e =
( match e with
| Add {typ; args} ->
let offset = Mset.count args (integer Z.one typ) in
if Z.is_zero offset then None else Some (offset, typ)
let offset = Qset.count args (integer Z.one typ) in
if Q.equal Q.zero offset then None else Some (offset, typ)
| _ -> None )
|> check (function
| Some (k, _) -> assert (not (Z.is_zero k))
| Some (k, _) -> assert (not (Q.equal Q.zero k))
| None -> () )
let base e =
( match e with
| Add {typ; args} -> (
let args = Mset.remove args (integer Z.one typ) in
match Mset.length args with
let args = Qset.remove args (integer Z.one typ) in
match Qset.length args with
| 0 -> integer Z.zero typ
| 1 -> (
match Mset.min_elt args with
| Some (arg, z) when Z.is_one z -> arg
match Qset.min_elt args with
| Some (arg, q) when Q.equal Q.one q -> arg
| _ -> Add {typ; args} )
| _ -> Add {typ; args} )
| _ -> e )
@ -1242,14 +1238,14 @@ let base e =
let base_offset e =
( match e with
| Add {typ; args} -> (
match Mset.count_and_remove args (integer Z.one typ) with
match Qset.count_and_remove args (integer Z.one typ) with
| Some (offset, args) ->
let base =
match Mset.length args with
match Qset.length args with
| 0 -> integer Z.zero typ
| 1 -> (
match Mset.min_elt args with
| Some (arg, z) when Z.is_one z -> arg
match Qset.min_elt args with
| Some (arg, q) when Q.equal Q.one q -> arg
| _ -> Add {typ; args} )
| _ -> Add {typ; args}
in
@ -1259,7 +1255,7 @@ let base_offset e =
|> check (function
| Some (b, k, _) ->
invariant b ;
assert (not (Z.is_zero k))
assert (not (Q.equal Q.zero k))
| None -> () )
(** Query *)
@ -1278,5 +1274,5 @@ let rec is_constant = function
| Var _ | Nondet _ -> false
| App {op; arg} -> is_constant arg && is_constant op
| Add {args} | Mul {args} ->
Mset.for_all ~f:(fun arg _ -> is_constant arg) args
Qset.for_all ~f:(fun arg _ -> is_constant arg) args
| _ -> true

@ -23,13 +23,13 @@
type comparator_witness
type mset = (t, comparator_witness) Mset.t
type qset = (t, comparator_witness) Qset.t
and t = private
| App of {op: t; arg: t}
(** Application of function symbol to argument, curried *)
| Add of {args: mset; typ: Typ.t} (** Addition *)
| Mul of {args: mset; typ: Typ.t} (** Multiplication *)
| Add of {args: qset; typ: Typ.t} (** Addition *)
| Mul of {args: qset; typ: Typ.t} (** Multiplication *)
| Var of {id: int; name: string} (** Local variable / virtual register *)
| Nondet of {msg: string}
(** Anonymous local variable with arbitrary value, representing
@ -141,6 +141,7 @@ val memory : siz:t -> arr:t -> t
val concat : t -> t -> t
val bool : bool -> t
val integer : Z.t -> Typ.t -> t
val rational : Q.t -> Typ.t -> t
val float : string -> t
val eq : t -> t -> t
val dq : t -> t -> t
@ -189,13 +190,13 @@ val convert : ?signed:bool -> dst:Typ.t -> src:Typ.t -> t -> t
(** Destruct *)
val base_offset : t -> (t * Z.t * Typ.t) option
val base_offset : t -> (t * Q.t * Typ.t) option
(** Decompose an addition of a constant "offset" to a "base" exp. *)
val base : t -> t
(** Like [base_offset] but does not construct the "offset" exp. *)
val offset : t -> (Z.t * Typ.t) option
val offset : t -> (Q.t * Typ.t) option
(** Like [base_offset] but does not construct the "base" exp. *)
(** Access *)

@ -238,13 +238,13 @@ let pp_diff fs (r, s) =
(** solve a+i = b for a, yielding a = b-i *)
let solve_for_base ai b =
match Exp.base_offset ai with
| Some (a, i, typ) -> (a, Exp.sub typ b (Exp.integer i typ))
| Some (a, i, typ) -> (a, Exp.sub typ b (Exp.rational i typ))
| None -> (ai, b)
(** subtract offset from both sides of equation a+i = b, yielding b-i *)
let subtract_offset ai b =
match Exp.offset ai with
| Some (i, typ) -> Exp.sub typ b (Exp.integer i typ)
| Some (i, typ) -> Exp.sub typ b (Exp.rational i typ)
| None -> b
(** [map_base ~f a+i] is [f(a) + i] and [map_base ~f a] is [f(a)] *)
@ -252,7 +252,7 @@ let map_base ai ~f =
match Exp.base_offset ai with
| Some (a, i, typ) ->
let a' = f a in
if a' == a then ai else Exp.add typ a' (Exp.integer i typ)
if a' == a then ai else Exp.add typ a' (Exp.rational i typ)
| None -> f ai
(** [norm_base r a] is [a'+k] where [r] implies [a = a'+k] and [a'] is a
@ -538,7 +538,7 @@ let rec norm_extend r ek =
Map.find_or_add r.rep e
~if_found:(fun e' ->
match Exp.offset ek with
| Some (k, typ) -> (r, Exp.add typ e' (Exp.integer k typ))
| Some (k, typ) -> (r, Exp.add typ e' (Exp.rational k typ))
| None -> (r, e') )
~default:e
~if_added:(fun rep ->

Loading…
Cancel
Save