[sledge] Change polynomial coefficients and powers to rationals

Summary:
Rational coefficients enable solving equations via Gaussian
elimination, and rational powers enable some normalization of
multiplication and division.

Reviewed By: ngorogiannis

Differential Revision: D14233213

fbshipit-source-id: acecb7f16
master
Josh Berdine 6 years ago committed by Facebook Github Bot
parent d01de4b0dd
commit 0ecee6a848

@ -265,19 +265,36 @@ module Set = struct
let to_tree = Using_comparator.to_tree let to_tree = Using_comparator.to_tree
end end
module Mset = struct module Qset = struct
include Mset include Qset
let pp sep pp_elt fs s = List.pp sep pp_elt fs (to_list s) let pp sep pp_elt fs s = List.pp sep pp_elt fs (to_list s)
end end
module Z = struct module Q = struct
include Z let pp = Q.pp_print
let hash = Hashtbl.hash
let hash_fold_t s q = Int.hash_fold_t s (hash q)
let sexp_of_t q = Sexp.Atom (Q.to_string q)
let t_of_sexp = function
| Sexp.Atom s -> Q.of_string s
| _ -> assert false
let of_z = Q.of_bigint
let hash_fold_t s z = Int.hash_fold_t s (Z.hash z) include Q
end
module Z = struct
let pp = Z.pp_print
let hash = [%hash: Z.t]
let hash_fold_t s z = Int.hash_fold_t s (hash z)
let sexp_of_t z = Sexp.Atom (Z.to_string z) let sexp_of_t z = Sexp.Atom (Z.to_string z)
let t_of_sexp = function let t_of_sexp = function
| Sexp.Atom s -> Z.of_string s | Sexp.Atom s -> Z.of_string s
| _ -> assert false | _ -> assert false
include Z
end end

@ -212,17 +212,31 @@ module Set : sig
val to_tree : ('e, 'c) t -> ('e, 'c) tree val to_tree : ('e, 'c) t -> ('e, 'c) tree
end end
module Mset : sig module Qset : sig
include module type of Mset include module type of Qset
val pp : (unit, unit) fmt -> ('a * Z.t) pp -> ('a, _) t pp val pp : (unit, unit) fmt -> ('a * Q.t) pp -> ('a, _) t pp
(** Pretty-print a multiset. *) end
module Q : sig
include module type of struct include Q end
val of_z : Z.t -> t
val compare : t -> t -> int
val hash : t -> int
val hash_fold_t : t Hash.folder
val t_of_sexp : Sexp.t -> t
val sexp_of_t : t -> Sexp.t
val pp : t pp
end end
module Z : sig module Z : sig
include module type of struct include Z end include module type of struct include Z end
val compare : t -> t -> int
val hash : t -> int
val hash_fold_t : t Hash.folder val hash_fold_t : t Hash.folder
val t_of_sexp : Sexp.t -> t val t_of_sexp : Sexp.t -> t
val sexp_of_t : t -> Sexp.t val sexp_of_t : t -> Sexp.t
val pp : t pp
end end

@ -5,12 +5,11 @@
* LICENSE file in the root directory of this source tree. * LICENSE file in the root directory of this source tree.
*) *)
(** Mset - Set with integer (positive, negative, or zero) multiplicity for (** Qset - Set with (signed) rational multiplicity for each element *)
each element *)
open Base open Base
type ('elt, 'cmp) t = ('elt, Z.t, 'cmp) Map.t type ('elt, 'cmp) t = ('elt, Q.t, 'cmp) Map.t
module M (Elt : sig module M (Elt : sig
type t type t
@ -33,23 +32,23 @@ end
module type Compare_m = sig end module type Compare_m = sig end
module type Hash_fold_m = Hasher.S module type Hash_fold_m = Hasher.S
let sexp_of_z z = Sexp.Atom (Z.to_string z) let sexp_of_q q = Sexp.Atom (Q.to_string q)
let z_of_sexp = function Sexp.Atom s -> Z.of_string s | _ -> assert false let q_of_sexp = function Sexp.Atom s -> Q.of_string s | _ -> assert false
let hash_fold_z state z = Hash.fold_int state (Z.hash z) let hash_fold_q state q = Hash.fold_int state (Hashtbl.hash q)
let sexp_of_m__t (type elt) (module Elt : Sexp_of_m with type t = elt) t = let sexp_of_m__t (type elt) (module Elt : Sexp_of_m with type t = elt) t =
Map.sexp_of_m__t (module Elt) sexp_of_z t Map.sexp_of_m__t (module Elt) sexp_of_q t
let m__t_of_sexp (type elt cmp) let m__t_of_sexp (type elt cmp)
(module Elt : M_of_sexp (module Elt : M_of_sexp
with type t = elt and type comparator_witness = cmp) sexp = with type t = elt and type comparator_witness = cmp) sexp =
Map.m__t_of_sexp (module Elt) z_of_sexp sexp Map.m__t_of_sexp (module Elt) q_of_sexp sexp
let compare_m__t (module Elt : Compare_m) = Map.compare_direct Z.compare let compare_m__t (module Elt : Compare_m) = Map.compare_direct Q.compare
let hash_fold_m__t (type elt) (module Elt : Hash_fold_m with type t = elt) let hash_fold_m__t (type elt) (module Elt : Hash_fold_m with type t = elt)
state = state =
Map.hash_fold_m__t (module Elt) hash_fold_z state Map.hash_fold_m__t (module Elt) hash_fold_q state
let hash_m__t (type elt) (module Elt : Hash_fold_m with type t = elt) = let hash_m__t (type elt) (module Elt : Hash_fold_m with type t = elt) =
Hash.of_fold (hash_fold_m__t (module Elt)) Hash.of_fold (hash_fold_m__t (module Elt))
@ -58,22 +57,22 @@ type ('elt, 'cmp) comparator =
(module Comparator.S with type t = 'elt and type comparator_witness = 'cmp) (module Comparator.S with type t = 'elt and type comparator_witness = 'cmp)
let empty cmp = Map.empty cmp let empty cmp = Map.empty cmp
let if_nz z = if Z.equal Z.zero z then None else Some z let if_nz q = if Q.equal Q.zero q then None else Some q
let add m x i = let add m x i =
Map.change m x ~f:(function Some j -> if_nz Z.(i + j) | None -> if_nz i) Map.change m x ~f:(function Some j -> if_nz Q.(i + j) | None -> if_nz i)
let remove m x = Map.remove m x let remove m x = Map.remove m x
let union m n = let union m n =
Map.merge m n ~f:(fun ~key:_ -> function Map.merge m n ~f:(fun ~key:_ -> function
| `Both (i, j) -> if_nz Z.(i + j) | `Left i | `Right i -> Some i ) | `Both (i, j) -> if_nz Q.(i + j) | `Left i | `Right i -> Some i )
let length m = Map.length m let length m = Map.length m
let count m x = match Map.find m x with Some z -> z | None -> Z.zero let count m x = match Map.find m x with Some q -> q | None -> Q.zero
let count_and_remove m x = let count_and_remove m x =
let found = ref Z.zero in let found = ref Q.zero in
let m = let m =
Map.change m x ~f:(function Map.change m x ~f:(function
| None -> None | None -> None
@ -81,7 +80,7 @@ let count_and_remove m x =
found := i ; found := i ;
None ) None )
in in
if Z.equal !found Z.zero then None else Some (!found, m) if Q.equal !found Q.zero then None else Some (!found, m)
let min_elt = Map.min_elt let min_elt = Map.min_elt
let fold m ~f ~init = Map.fold m ~f:(fun ~key ~data s -> f key data s) ~init let fold m ~f ~init = Map.fold m ~f:(fun ~key ~data s -> f key data s) ~init
@ -90,14 +89,14 @@ let map m ~f =
fold m ~init:m ~f:(fun x i m -> fold m ~init:m ~f:(fun x i m ->
let x', i' = f x i in let x', i' = f x i in
if phys_equal x' x then if phys_equal x' x then
if Z.equal i' i then m else Map.set m ~key:x ~data:i' if Q.equal i' i then m else Map.set m ~key:x ~data:i'
else add (Map.remove m x) x' i' ) else add (Map.remove m x) x' i' )
let fold_map m ~f ~init:s = let fold_map m ~f ~init:s =
fold m ~init:(m, s) ~f:(fun x i (m, s) -> fold m ~init:(m, s) ~f:(fun x i (m, s) ->
let x', i', s = f x i s in let x', i', s = f x i s in
if phys_equal x' x then if phys_equal x' x then
if Z.equal i' i then (m, s) else (Map.set m ~key:x ~data:i', s) if Q.equal i' i then (m, s) else (Map.set m ~key:x ~data:i', s)
else (add (Map.remove m x) x' i', s) ) else (add (Map.remove m x) x' i', s) )
let for_all m ~f = Map.for_alli m ~f:(fun ~key ~data -> f key data) let for_all m ~f = Map.for_alli m ~f:(fun ~key ~data -> f key data)

@ -5,8 +5,7 @@
* LICENSE file in the root directory of this source tree. * LICENSE file in the root directory of this source tree.
*) *)
(** Mset - Set with integer (positive, negative, or zero) multiplicity for (** Qset - Set with (signed) rational multiplicity for each element *)
each element *)
open Base open Base
@ -58,7 +57,7 @@ val hash_m__t :
val empty : ('elt, 'cmp) comparator -> ('elt, 'cmp) t val empty : ('elt, 'cmp) comparator -> ('elt, 'cmp) t
(** The empty multiset over the provided order. *) (** The empty multiset over the provided order. *)
val add : ('a, 'c) t -> 'a -> Z.t -> ('a, 'c) t val add : ('a, 'c) t -> 'a -> Q.t -> ('a, 'c) t
(** Add to multiplicity of single element. [O(log n)] *) (** Add to multiplicity of single element. [O(log n)] *)
val remove : ('a, 'c) t -> 'a -> ('a, 'c) t val remove : ('a, 'c) t -> 'a -> ('a, 'c) t
@ -70,39 +69,39 @@ val union : ('a, 'c) t -> ('a, 'c) t -> ('a, 'c) t
val length : _ t -> int val length : _ t -> int
(** Number of elements with non-zero multiplicity. [O(1)]. *) (** Number of elements with non-zero multiplicity. [O(1)]. *)
val count : ('a, _) t -> 'a -> Z.t val count : ('a, _) t -> 'a -> Q.t
(** Multiplicity of an element. [O(log n)]. *) (** Multiplicity of an element. [O(log n)]. *)
val count_and_remove : ('a, 'c) t -> 'a -> (Z.t * ('a, 'c) t) option val count_and_remove : ('a, 'c) t -> 'a -> (Q.t * ('a, 'c) t) option
(** Multiplicity of an element, and remove it. [O(log n)]. *) (** Multiplicity of an element, and remove it. [O(log n)]. *)
val map : ('a, 'c) t -> f:('a -> Z.t -> 'a * Z.t) -> ('a, 'c) t val map : ('a, 'c) t -> f:('a -> Q.t -> 'a * Q.t) -> ('a, 'c) t
(** Map over the elements in ascending order. Preserves physical equality if (** Map over the elements in ascending order. Preserves physical equality if
[f] does. *) [f] does. *)
val map_counts : ('a, 'c) t -> f:('a -> Z.t -> Z.t) -> ('a, 'c) t val map_counts : ('a, 'c) t -> f:('a -> Q.t -> Q.t) -> ('a, 'c) t
(** Map over the multiplicities of the elements in ascending order. *) (** Map over the multiplicities of the elements in ascending order. *)
val fold : ('a, _) t -> f:('a -> Z.t -> 's -> 's) -> init:'s -> 's val fold : ('a, _) t -> f:('a -> Q.t -> 's -> 's) -> init:'s -> 's
(** Fold over the elements in ascending order. *) (** Fold over the elements in ascending order. *)
val fold_map : val fold_map :
('a, 'c) t ('a, 'c) t
-> f:('a -> Z.t -> 's -> 'a * Z.t * 's) -> f:('a -> Q.t -> 's -> 'a * Q.t * 's)
-> init:'s -> init:'s
-> ('a, 'c) t * 's -> ('a, 'c) t * 's
(** Folding map over the elements in ascending order. Preserves physical (** Folding map over the elements in ascending order. Preserves physical
equality if [f] does. *) equality if [f] does. *)
val for_all : ('a, _) t -> f:('a -> Z.t -> bool) -> bool val for_all : ('a, _) t -> f:('a -> Q.t -> bool) -> bool
(** Universal property test. [O(n)] but returns as soon as a violation is (** Universal property test. [O(n)] but returns as soon as a violation is
found, in ascending order. *) found, in ascending order. *)
val iter : ('a, _) t -> f:('a -> Z.t -> unit) -> unit val iter : ('a, _) t -> f:('a -> Q.t -> unit) -> unit
(** Iterate over the elements in ascending order. *) (** Iterate over the elements in ascending order. *)
val min_elt : ('a, _) t -> ('a * Z.t) option val min_elt : ('a, _) t -> ('a * Q.t) option
(** Minimum element. *) (** Minimum element. *)
val to_list : ('a, _) t -> ('a * Z.t) list val to_list : ('a, _) t -> ('a * Q.t) list
(** Convert to a list of elements in ascending order. *) (** Convert to a list of elements in ascending order. *)

@ -9,13 +9,7 @@
(** Z wrapped to treat bounded and unsigned operations *) (** Z wrapped to treat bounded and unsigned operations *)
module Z = struct module Z = struct
type t = Z.t [@@deriving compare, hash, sexp] include Z
include (Z : module type of Z with type t := t)
let pp = Z.pp_print
let is_zero = Z.equal zero
let is_one = Z.equal one
(* the signed 1-bit integers are -1 and 0 *) (* the signed 1-bit integers are -1 and 0 *)
let true_ = Z.minus_one let true_ = Z.minus_one
@ -43,7 +37,6 @@ module Z = struct
let bugeq ~bits x y = clamp_cmp ~signed:false bits Z.geq x y let bugeq ~bits x y = clamp_cmp ~signed:false bits Z.geq x y
let bult ~bits x y = clamp_cmp ~signed:false bits Z.lt x y let bult ~bits x y = clamp_cmp ~signed:false bits Z.lt x y
let bugt ~bits x y = clamp_cmp ~signed:false bits Z.gt x y let bugt ~bits x y = clamp_cmp ~signed:false bits Z.gt x y
let badd ~bits x y = clamp_bop ~signed:true bits Z.add x y
let bsub ~bits x y = clamp_bop ~signed:true bits Z.sub x y let bsub ~bits x y = clamp_bop ~signed:true bits Z.sub x y
let bmul ~bits x y = clamp_bop ~signed:true bits Z.mul x y let bmul ~bits x y = clamp_bop ~signed:true bits Z.mul x y
let bdiv ~bits x y = clamp_bop ~signed:true bits Z.div x y let bdiv ~bits x y = clamp_bop ~signed:true bits Z.div x y
@ -61,13 +54,13 @@ module Z = struct
end end
module rec T : sig module rec T : sig
type mset = Mset.M(T).t [@@deriving compare, hash, sexp] type qset = Qset.M(T).t [@@deriving compare, hash, sexp]
type t = type t =
| App of {op: t; arg: t} | App of {op: t; arg: t}
(* nary: arithmetic, numeric and pointer *) (* nary: arithmetic, numeric and pointer *)
| Add of {args: mset; typ: Typ.t} | Add of {args: qset; typ: Typ.t}
| Mul of {args: mset; typ: Typ.t} | Mul of {args: qset; typ: Typ.t}
| Var of {id: int; name: string} | Var of {id: int; name: string}
| Nondet of {msg: string} | Nondet of {msg: string}
| Label of {parent: string; name: string} | Label of {parent: string; name: string}
@ -124,12 +117,12 @@ end
(* auxiliary definition for safe recursive module initialization *) (* auxiliary definition for safe recursive module initialization *)
and T0 : sig and T0 : sig
type mset = Mset.M(T).t [@@deriving compare, hash, sexp] type qset = Qset.M(T).t [@@deriving compare, hash, sexp]
type t = type t =
| App of {op: t; arg: t} | App of {op: t; arg: t}
| Add of {args: mset; typ: Typ.t} | Add of {args: qset; typ: Typ.t}
| Mul of {args: mset; typ: Typ.t} | Mul of {args: qset; typ: Typ.t}
| Var of {id: int; name: string} | Var of {id: int; name: string}
| Nondet of {msg: string} | Nondet of {msg: string}
| Label of {parent: string; name: string} | Label of {parent: string; name: string}
@ -168,12 +161,12 @@ and T0 : sig
| Convert of {signed: bool; dst: Typ.t; src: Typ.t} | Convert of {signed: bool; dst: Typ.t; src: Typ.t}
[@@deriving compare, hash, sexp] [@@deriving compare, hash, sexp]
end = struct end = struct
type mset = Mset.M(T).t [@@deriving compare, hash, sexp] type qset = Qset.M(T).t [@@deriving compare, hash, sexp]
type t = type t =
| App of {op: t; arg: t} | App of {op: t; arg: t}
| Add of {args: mset; typ: Typ.t} | Add of {args: qset; typ: Typ.t}
| Mul of {args: mset; typ: Typ.t} | Mul of {args: qset; typ: Typ.t}
| Var of {id: int; name: string} | Var of {id: int; name: string}
| Nondet of {msg: string} | Nondet of {msg: string}
| Label of {parent: string; name: string} | Label of {parent: string; name: string}
@ -218,7 +211,7 @@ type _t = T0.t
include T include T
let empty_mset = Mset.empty (module T) let empty_qset = Qset.empty (module T)
let equal = [%compare.equal: t] let equal = [%compare.equal: t]
let sorted e f = compare e f <= 0 let sorted e f = compare e f <= 0
let sort e f = if sorted e f then (e, f) else (f, e) let sort e f = if sorted e f then (e, f) else (f, e)
@ -257,7 +250,7 @@ let rec pp fs exp =
| Var {name; id} -> pf "%%%s_%d" name id | Var {name; id} -> pf "%%%s_%d" name id
| Nondet {msg} -> pf "nondet \"%s\"" msg | Nondet {msg} -> pf "nondet \"%s\"" msg
| Label {name} -> pf "%s" name | Label {name} -> pf "%s" name
| Integer {data; typ= Pointer _} when Z.is_zero data -> pf "null" | Integer {data; typ= Pointer _} when Z.equal Z.zero data -> pf "null"
| Splat -> pf "^" | Splat -> pf "^"
| Memory -> pf "⟨_,_⟩" | Memory -> pf "⟨_,_⟩"
| App {op= Memory; arg= siz} -> pf "@<1>⟨%a,_@<1>⟩" pp siz | App {op= Memory; arg= siz} -> pf "@<1>⟨%a,_@<1>⟩" pp siz
@ -281,18 +274,18 @@ let rec pp fs exp =
| Add {args} -> | Add {args} ->
let pp_poly_term fs (monomial, coefficient) = let pp_poly_term fs (monomial, coefficient) =
match monomial with match monomial with
| Integer {data} when Z.is_one data -> Z.pp fs coefficient | Integer {data} when Z.equal Z.one data -> Q.pp fs coefficient
| _ when Z.is_one coefficient -> pp fs monomial | _ when Q.equal Q.one coefficient -> pp fs monomial
| _ -> | _ ->
Format.fprintf fs "%a @<1>× %a" Z.pp coefficient pp monomial Format.fprintf fs "%a @<1>× %a" Q.pp coefficient pp monomial
in in
pf "(%a)" (Mset.pp "@ + " pp_poly_term) args pf "(%a)" (Qset.pp "@ + " pp_poly_term) args
| Mul {args} -> | Mul {args} ->
let pp_mono_term fs (factor, exponent) = let pp_mono_term fs (factor, exponent) =
if Z.is_one exponent then pp fs factor if Q.equal Q.one exponent then pp fs factor
else Format.fprintf fs "%a^%a" pp factor Z.pp exponent else Format.fprintf fs "%a^%a" pp factor Q.pp exponent
in in
pf "(%a)" (Mset.pp "@ @<2>× " pp_mono_term) args pf "(%a)" (Qset.pp "@ @<2>× " pp_mono_term) args
| Div -> pf "/" | Div -> pf "/"
| Udiv -> pf "udiv" | Udiv -> pf "udiv"
| Rem -> pf "rem" | Rem -> pf "rem"
@ -389,8 +382,8 @@ let assert_monomial add_typ mono =
| Mul {typ; args} -> | Mul {typ; args} ->
assert (Typ.castable add_typ typ) ; assert (Typ.castable add_typ typ) ;
assert (Option.exists ~f:(fun n -> 1 < n) (Typ.prim_bit_size_of typ)) ; assert (Option.exists ~f:(fun n -> 1 < n) (Typ.prim_bit_size_of typ)) ;
Mset.iter args ~f:(fun factor exponent -> Qset.iter args ~f:(fun factor exponent ->
assert (Z.sign exponent > 0) ; assert (Q.sign exponent > 0) ;
assert_indeterminate factor |> Fn.id ) assert_indeterminate factor |> Fn.id )
| _ -> assert_indeterminate mono |> Fn.id | _ -> assert_indeterminate mono |> Fn.id
@ -398,12 +391,12 @@ let assert_monomial add_typ mono =
* c × x * c × x
*) *)
let assert_poly_term add_typ mono coeff = let assert_poly_term add_typ mono coeff =
assert (not (Z.is_zero coeff)) ; assert (not (Q.equal Q.zero coeff)) ;
match mono with match mono with
| Integer {data} -> assert (Z.is_one data) | Integer {data} -> assert (Z.equal Z.one data)
| Mul {args} -> | Mul {args} ->
if Z.is_one coeff then assert (Mset.length args > 1) if Q.equal Q.one coeff then assert (Qset.length args > 1)
else assert (Mset.length args > 0) ; else assert (Qset.length args > 0) ;
assert_monomial add_typ mono |> Fn.id assert_monomial add_typ mono |> Fn.id
| _ -> assert_monomial add_typ mono |> Fn.id | _ -> assert_monomial add_typ mono |> Fn.id
@ -415,15 +408,15 @@ let assert_poly_term add_typ mono coeff =
let assert_polynomial poly = let assert_polynomial poly =
match poly with match poly with
| Add {typ; args} -> | Add {typ; args} ->
( match Mset.length args with ( match Qset.length args with
| 0 -> assert false | 0 -> assert false
| 1 -> ( | 1 -> (
match Mset.min_elt args with match Qset.min_elt args with
| Some (Integer _, _) -> assert false | Some (Integer _, _) -> assert false
| Some (_, k) -> assert (not (Z.is_one k)) | Some (_, k) -> assert (not (Q.equal Q.one k))
| _ -> () ) | _ -> () )
| _ -> () ) ; | _ -> () ) ;
Mset.iter args ~f:(fun m c -> assert_poly_term typ m c |> Fn.id) Qset.iter args ~f:(fun m c -> assert_poly_term typ m c |> Fn.id)
| _ -> assert false | _ -> assert false
let invariant ?(partial = false) e = let invariant ?(partial = false) e =
@ -616,7 +609,7 @@ let fold_exps e ~init ~f =
match e with match e with
| App {op; arg} -> fold_exps_ op (fold_exps_ arg z) | App {op; arg} -> fold_exps_ op (fold_exps_ arg z)
| Add {args} | Mul {args} -> | Add {args} | Mul {args} ->
Mset.fold args ~init:z ~f:(fun arg _ z -> fold_exps_ arg z) Qset.fold args ~init:z ~f:(fun arg _ z -> fold_exps_ arg z)
| Struct_rec {elts} -> | Struct_rec {elts} ->
Vector.fold elts ~init:z ~f:(fun z elt -> fold_exps_ elt z) Vector.fold elts ~init:z ~f:(fun z elt -> fold_exps_ elt z)
| _ -> z | _ -> z
@ -705,7 +698,7 @@ let simp_div x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.bdiv ~bits i j) typ integer (Z.bdiv ~bits i j) typ
(* e / 1 ==> e *) (* e / 1 ==> e *)
| e, Integer {data} when Z.is_one data -> e | e, Integer {data} when Z.equal Z.one data -> e
| _ -> App {op= App {op= Div; arg= x}; arg= y} | _ -> App {op= App {op= Div; arg= x}; arg= y}
let simp_udiv x y = let simp_udiv x y =
@ -715,7 +708,7 @@ let simp_udiv x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.budiv ~bits i j) typ integer (Z.budiv ~bits i j) typ
(* e u/ 1 ==> e *) (* e u/ 1 ==> e *)
| e, Integer {data} when Z.is_one data -> e | e, Integer {data} when Z.equal Z.one data -> e
| _ -> App {op= App {op= Udiv; arg= x}; arg= y} | _ -> App {op= App {op= Udiv; arg= x}; arg= y}
let simp_rem x y = let simp_rem x y =
@ -725,7 +718,7 @@ let simp_rem x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.brem ~bits i j) typ integer (Z.brem ~bits i j) typ
(* e % 1 ==> 0 *) (* e % 1 ==> 0 *)
| _, Integer {data; typ} when Z.is_one data -> integer Z.zero typ | _, Integer {data; typ} when Z.equal Z.one data -> integer Z.zero typ
| _ -> App {op= App {op= Rem; arg= x}; arg= y} | _ -> App {op= App {op= Rem; arg= x}; arg= y}
let simp_urem x y = let simp_urem x y =
@ -735,42 +728,45 @@ let simp_urem x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.burem ~bits i j) typ integer (Z.burem ~bits i j) typ
(* e u% 1 ==> 0 *) (* e u% 1 ==> 0 *)
| _, Integer {data; typ} when Z.is_one data -> integer Z.zero typ | _, Integer {data; typ} when Z.equal Z.one data -> integer Z.zero typ
| _ -> App {op= App {op= Urem; arg= x}; arg= y} | _ -> App {op= App {op= Urem; arg= x}; arg= y}
let rational Q.({num; den}) typ =
simp_div (integer num typ) (integer den typ)
(* Sums of polynomial terms represented by multisets. A sum ∑ᵢ cᵢ × (* Sums of polynomial terms represented by multisets. A sum ∑ᵢ cᵢ ×
X of monomials X with coefficients c is represented by a X of monomials X with coefficients c is represented by a
multiset where the elements are X with multiplicities c. A constant multiset where the elements are X with multiplicities c. A constant
is treated as the coefficient of the empty monomial, which is the unit of is treated as the coefficient of the empty monomial, which is the unit of
multiplication 1. *) multiplication 1. *)
module Sum = struct module Sum = struct
let empty = empty_mset let empty = empty_qset
let add coeff exp sum = let add coeff exp sum =
assert (not (Z.is_zero coeff)) ; assert (not (Q.equal Q.zero coeff)) ;
match exp with match exp with
| Integer {data} when Z.is_zero data -> sum | Integer {data} when Z.equal Z.zero data -> sum
| Integer {data; typ} -> | Integer {data; typ} ->
Mset.add sum (integer Z.one typ) Z.(coeff * data) Qset.add sum (integer Z.one typ) Q.(coeff * of_z data)
| _ -> Mset.add sum exp coeff | _ -> Qset.add sum exp coeff
let singleton ?(coeff = Z.one) exp = add coeff exp empty let singleton ?(coeff = Q.one) exp = add coeff exp empty
let map sum ~f = let map sum ~f =
Mset.fold sum ~init:empty ~f:(fun e c sum -> add c (f e) sum) Qset.fold sum ~init:empty ~f:(fun e c sum -> add c (f e) sum)
let mul_const const sum = let mul_const const sum =
assert (not (Z.is_zero const)) ; assert (not (Q.equal Q.zero const)) ;
if Z.is_one const then sum if Q.equal Q.one const then sum
else Mset.map_counts ~f:(fun _ -> Z.mul const) sum else Qset.map_counts ~f:(fun _ -> Q.mul const) sum
let to_exp typ sum = let to_exp typ sum =
match Mset.length sum with match Qset.length sum with
| 0 -> integer Z.zero typ | 0 -> integer Z.zero typ
| 1 -> ( | 1 -> (
match Mset.min_elt sum with match Qset.min_elt sum with
| Some (Integer _, z) -> integer z typ | Some (Integer _, q) -> rational q typ
| Some (arg, z) when Z.is_one z -> arg | Some (arg, q) when Q.equal Q.one q -> arg
| _ -> Add {typ; args= sum} ) | _ -> Add {typ; args= sum} )
| _ -> Add {typ; args= sum} | _ -> Add {typ; args= sum}
end end
@ -780,12 +776,12 @@ let rec simp_add_ typ es poly =
let f exp coeff poly = let f exp coeff poly =
match (exp, poly) with match (exp, poly) with
(* (0 × e) + s ==> 0 (optim) *) (* (0 × e) + s ==> 0 (optim) *)
| _ when Z.is_zero coeff -> poly | _ when Q.equal Q.zero coeff -> poly
(* (c × 0) + s ==> s (optim) *) (* (c × 0) + s ==> s (optim) *)
| Integer {data}, _ when Z.is_zero data -> poly | Integer {data}, _ when Z.equal Z.zero data -> poly
(* (c × cᵢ) + cⱼ ==> c×cᵢ+cⱼ *) (* (c × cᵢ) + cⱼ ==> c×cᵢ+cⱼ *)
| Integer {data= i}, Integer {data= j} -> | Integer {data= i}, Integer {data= j} ->
integer (Z.badd ~bits:(bits_of_int exp) Z.(coeff * i) j) typ rational Q.((coeff * of_z i) + of_z j) typ
(* (c × ∑ᵢ cᵢ × Xᵢ) + s ==> (∑ᵢ (c × cᵢ) × Xᵢ) + s *) (* (c × ∑ᵢ cᵢ × Xᵢ) + s ==> (∑ᵢ (c × cᵢ) × Xᵢ) + s *)
| Add {args}, _ -> simp_add_ typ (Sum.mul_const coeff args) poly | Add {args}, _ -> simp_add_ typ (Sum.mul_const coeff args) poly
(* (c₀ × X₀) + (∑ᵢ₌₁ⁿ cᵢ × Xᵢ) ==> ∑ᵢ₌₀ⁿ (* (c₀ × X₀) + (∑ᵢ₌₁ⁿ cᵢ × Xᵢ) ==> ∑ᵢ₌₀ⁿ
@ -794,7 +790,7 @@ let rec simp_add_ typ es poly =
(* (c₁ × X₁) + X₂ ==> ∑ᵢ₌₁² cᵢ × Xᵢ for c₂ = 1 *) (* (c₁ × X₁) + X₂ ==> ∑ᵢ₌₁² cᵢ × Xᵢ for c₂ = 1 *)
| _ -> Sum.to_exp typ (Sum.add coeff exp (Sum.singleton poly)) | _ -> Sum.to_exp typ (Sum.add coeff exp (Sum.singleton poly))
in in
Mset.fold ~f es ~init:poly Qset.fold ~f es ~init:poly
let simp_add typ es = simp_add_ typ es (integer Z.zero typ) let simp_add typ es = simp_add_ typ es (integer Z.zero typ)
let simp_add2 typ e f = simp_add_ typ (Sum.singleton e) f let simp_add2 typ e f = simp_add_ typ (Sum.singleton e) f
@ -803,10 +799,10 @@ let simp_add2 typ e f = simp_add_ typ (Sum.singleton e) f
x^n of indeterminates x is represented by a multiset where the x^n of indeterminates x is represented by a multiset where the
elements are x and the multiplicities are the exponents n. *) elements are x and the multiplicities are the exponents n. *)
module Prod = struct module Prod = struct
let empty = empty_mset let empty = empty_qset
let add exp prod = Mset.add prod exp Z.one let add exp prod = Qset.add prod exp Q.one
let singleton exp = add exp empty let singleton exp = add exp empty
let union = Mset.union let union = Qset.union
end end
(* map over each monomial of a polynomial *) (* map over each monomial of a polynomial *)
@ -825,16 +821,16 @@ let rec simp_mul2 typ e f =
| Integer {data= i}, Integer {data= j} -> | Integer {data= i}, Integer {data= j} ->
integer (Z.bmul ~bits:(bits_of_int e) i j) typ integer (Z.bmul ~bits:(bits_of_int e) i j) typ
(* 0 × f ==> 0 *) (* 0 × f ==> 0 *)
| Integer {data}, _ when Z.is_zero data -> e | Integer {data}, _ when Z.equal Z.zero data -> e
(* e × 0 ==> 0 *) (* e × 0 ==> 0 *)
| _, Integer {data} when Z.is_zero data -> f | _, Integer {data} when Z.equal Z.zero data -> f
(* c × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ c × cᵤ × ∏ⱼ (* c × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ c × cᵤ × ∏ⱼ
y *) y *)
| Integer {data}, Add {args} | Add {args}, Integer {data} -> | Integer {data}, Add {args} | Add {args}, Integer {data} ->
Sum.to_exp typ (Sum.mul_const data args) Sum.to_exp typ (Sum.mul_const (Q.of_z data) args)
(* c₁ × x₁ ==> ∑ᵢ₌₁ cᵢ × xᵢ *) (* c₁ × x₁ ==> ∑ᵢ₌₁ cᵢ × xᵢ *)
| Integer {data= c}, x | x, Integer {data= c} -> | Integer {data= c}, x | x, Integer {data= c} ->
Sum.to_exp typ (Sum.singleton ~coeff:c x) Sum.to_exp typ (Sum.singleton ~coeff:(Q.of_z c) x)
(* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> (* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==>
x *) x *)
| Mul {typ; args= xs1}, Mul {args= xs2} -> | Mul {typ; args= xs1}, Mul {args= xs2} ->
@ -856,13 +852,13 @@ let rec simp_mul2 typ e f =
let simp_mul typ es = let simp_mul typ es =
(* (bas ^ pwr) × exp *) (* (bas ^ pwr) × exp *)
let rec mul_pwr bas pwr exp = let rec mul_pwr bas pwr exp =
if Z.is_zero pwr then exp if Q.equal Q.zero pwr then exp
else mul_pwr bas (Z.pred pwr) (simp_mul2 typ bas exp) else mul_pwr bas Q.(pwr - one) (simp_mul2 typ bas exp)
in in
let one = integer Z.one typ in let one = integer Z.one typ in
Mset.fold es ~init:one ~f:(fun bas pwr exp -> Qset.fold es ~init:one ~f:(fun bas pwr exp ->
if Z.sign pwr >= 0 then mul_pwr bas pwr exp if Q.sign pwr >= 0 then mul_pwr bas pwr exp
else simp_div exp (mul_pwr bas (Z.neg pwr) one) ) else simp_div exp (mul_pwr bas (Q.neg pwr) one) )
let simp_negate typ x = simp_mul2 typ (integer Z.minus_one typ) x let simp_negate typ x = simp_mul2 typ (integer Z.minus_one typ) x
@ -1010,7 +1006,7 @@ let simp_shl x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.bshift_left ~bits i (Z.to_int j)) typ integer (Z.bshift_left ~bits i (Z.to_int j)) typ
(* e shl 0 ==> e *) (* e shl 0 ==> e *)
| e, Integer {data} when Z.is_zero data -> e | e, Integer {data} when Z.equal Z.zero data -> e
| _ -> App {op= App {op= Shl; arg= x}; arg= y} | _ -> App {op= App {op= Shl; arg= x}; arg= y}
let simp_lshr x y = let simp_lshr x y =
@ -1020,7 +1016,7 @@ let simp_lshr x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.bshift_right_trunc ~bits i (Z.to_int j)) typ integer (Z.bshift_right_trunc ~bits i (Z.to_int j)) typ
(* e lshr 0 ==> e *) (* e lshr 0 ==> e *)
| e, Integer {data} when Z.is_zero data -> e | e, Integer {data} when Z.equal Z.zero data -> e
| _ -> App {op= App {op= Lshr; arg= x}; arg= y} | _ -> App {op= App {op= Lshr; arg= x}; arg= y}
let simp_ashr x y = let simp_ashr x y =
@ -1030,7 +1026,7 @@ let simp_ashr x y =
let bits = Option.value_exn (Typ.prim_bit_size_of typ) in let bits = Option.value_exn (Typ.prim_bit_size_of typ) in
integer (Z.bshift_right ~bits i (Z.to_int j)) typ integer (Z.bshift_right ~bits i (Z.to_int j)) typ
(* e ashr 0 ==> e *) (* e ashr 0 ==> e *)
| e, Integer {data} when Z.is_zero data -> e | e, Integer {data} when Z.equal Z.zero data -> e
| _ -> App {op= App {op= Ashr; arg= x}; arg= y} | _ -> App {op= App {op= Ashr; arg= x}; arg= y}
(** Access *) (** Access *)
@ -1038,7 +1034,7 @@ let simp_ashr x y =
let iter e ~f = let iter e ~f =
match e with match e with
| App {op; arg} -> f op ; f arg | App {op; arg} -> f op ; f arg
| Add {args} | Mul {args} -> Mset.iter ~f:(fun arg _ -> f arg) args | Add {args} | Mul {args} -> Qset.iter ~f:(fun arg _ -> f arg) args
| _ -> () | _ -> ()
let fold e ~init:s ~f = let fold e ~init:s ~f =
@ -1048,7 +1044,7 @@ let fold e ~init:s ~f =
let s = f s arg in let s = f s arg in
s s
| Add {args} | Mul {args} -> | Add {args} | Mul {args} ->
let s = Mset.fold ~f:(fun e _ s -> f s e) args ~init:s in let s = Qset.fold ~f:(fun e _ s -> f s e) args ~init:s in
s s
| _ -> s | _ -> s
@ -1173,8 +1169,8 @@ let convert ?(signed = false) ~dst ~src exp =
(** Transform *) (** Transform *)
let map e ~f = let map e ~f =
let map_mset mk typ ~f args = let map_qset mk typ ~f args =
let args' = Mset.map ~f:(fun arg z -> (f arg, z)) args in let args' = Qset.map ~f:(fun arg q -> (f arg, q)) args in
if args' == args then e else mk typ args' if args' == args then e else mk typ args'
in in
match e with match e with
@ -1182,16 +1178,16 @@ let map e ~f =
let op' = f op in let op' = f op in
let arg' = f arg in let arg' = f arg in
if op' == op && arg' == arg then e else app1 ~partial:true op' arg' if op' == op && arg' == arg then e else app1 ~partial:true op' arg'
| Add {args; typ} -> map_mset addN typ ~f args | Add {args; typ} -> map_qset addN typ ~f args
| Mul {args; typ} -> map_mset mulN typ ~f args | Mul {args; typ} -> map_qset mulN typ ~f args
| _ -> e | _ -> e
let fold_map e ~init:s ~f = let fold_map e ~init:s ~f =
let fold_map_mset mk typ ~f ~init args = let fold_map_qset mk typ ~f ~init args =
let args', s = let args', s =
Mset.fold_map args ~init ~f:(fun x z s -> Qset.fold_map args ~init ~f:(fun x q s ->
let s, x' = f s x in let s, x' = f s x in
(x', z, s) ) (x', q, s) )
in in
if args' == args then (s, e) else (s, mk typ args') if args' == args then (s, e) else (s, mk typ args')
in in
@ -1201,8 +1197,8 @@ let fold_map e ~init:s ~f =
let s, arg' = f s arg in let s, arg' = f s arg in
if op' == op && arg' == arg then (s, e) if op' == op && arg' == arg then (s, e)
else (s, app1 ~partial:true op' arg') else (s, app1 ~partial:true op' arg')
| Add {args; typ} -> fold_map_mset addN typ ~f args ~init:s | Add {args; typ} -> fold_map_qset addN typ ~f args ~init:s
| Mul {args; typ} -> fold_map_mset mulN typ ~f args ~init:s | Mul {args; typ} -> fold_map_qset mulN typ ~f args ~init:s
| _ -> (s, e) | _ -> (s, e)
let rename e sub = let rename e sub =
@ -1218,22 +1214,22 @@ let rename e sub =
let offset e = let offset e =
( match e with ( match e with
| Add {typ; args} -> | Add {typ; args} ->
let offset = Mset.count args (integer Z.one typ) in let offset = Qset.count args (integer Z.one typ) in
if Z.is_zero offset then None else Some (offset, typ) if Q.equal Q.zero offset then None else Some (offset, typ)
| _ -> None ) | _ -> None )
|> check (function |> check (function
| Some (k, _) -> assert (not (Z.is_zero k)) | Some (k, _) -> assert (not (Q.equal Q.zero k))
| None -> () ) | None -> () )
let base e = let base e =
( match e with ( match e with
| Add {typ; args} -> ( | Add {typ; args} -> (
let args = Mset.remove args (integer Z.one typ) in let args = Qset.remove args (integer Z.one typ) in
match Mset.length args with match Qset.length args with
| 0 -> integer Z.zero typ | 0 -> integer Z.zero typ
| 1 -> ( | 1 -> (
match Mset.min_elt args with match Qset.min_elt args with
| Some (arg, z) when Z.is_one z -> arg | Some (arg, q) when Q.equal Q.one q -> arg
| _ -> Add {typ; args} ) | _ -> Add {typ; args} )
| _ -> Add {typ; args} ) | _ -> Add {typ; args} )
| _ -> e ) | _ -> e )
@ -1242,14 +1238,14 @@ let base e =
let base_offset e = let base_offset e =
( match e with ( match e with
| Add {typ; args} -> ( | Add {typ; args} -> (
match Mset.count_and_remove args (integer Z.one typ) with match Qset.count_and_remove args (integer Z.one typ) with
| Some (offset, args) -> | Some (offset, args) ->
let base = let base =
match Mset.length args with match Qset.length args with
| 0 -> integer Z.zero typ | 0 -> integer Z.zero typ
| 1 -> ( | 1 -> (
match Mset.min_elt args with match Qset.min_elt args with
| Some (arg, z) when Z.is_one z -> arg | Some (arg, q) when Q.equal Q.one q -> arg
| _ -> Add {typ; args} ) | _ -> Add {typ; args} )
| _ -> Add {typ; args} | _ -> Add {typ; args}
in in
@ -1259,7 +1255,7 @@ let base_offset e =
|> check (function |> check (function
| Some (b, k, _) -> | Some (b, k, _) ->
invariant b ; invariant b ;
assert (not (Z.is_zero k)) assert (not (Q.equal Q.zero k))
| None -> () ) | None -> () )
(** Query *) (** Query *)
@ -1278,5 +1274,5 @@ let rec is_constant = function
| Var _ | Nondet _ -> false | Var _ | Nondet _ -> false
| App {op; arg} -> is_constant arg && is_constant op | App {op; arg} -> is_constant arg && is_constant op
| Add {args} | Mul {args} -> | Add {args} | Mul {args} ->
Mset.for_all ~f:(fun arg _ -> is_constant arg) args Qset.for_all ~f:(fun arg _ -> is_constant arg) args
| _ -> true | _ -> true

@ -23,13 +23,13 @@
type comparator_witness type comparator_witness
type mset = (t, comparator_witness) Mset.t type qset = (t, comparator_witness) Qset.t
and t = private and t = private
| App of {op: t; arg: t} | App of {op: t; arg: t}
(** Application of function symbol to argument, curried *) (** Application of function symbol to argument, curried *)
| Add of {args: mset; typ: Typ.t} (** Addition *) | Add of {args: qset; typ: Typ.t} (** Addition *)
| Mul of {args: mset; typ: Typ.t} (** Multiplication *) | Mul of {args: qset; typ: Typ.t} (** Multiplication *)
| Var of {id: int; name: string} (** Local variable / virtual register *) | Var of {id: int; name: string} (** Local variable / virtual register *)
| Nondet of {msg: string} | Nondet of {msg: string}
(** Anonymous local variable with arbitrary value, representing (** Anonymous local variable with arbitrary value, representing
@ -141,6 +141,7 @@ val memory : siz:t -> arr:t -> t
val concat : t -> t -> t val concat : t -> t -> t
val bool : bool -> t val bool : bool -> t
val integer : Z.t -> Typ.t -> t val integer : Z.t -> Typ.t -> t
val rational : Q.t -> Typ.t -> t
val float : string -> t val float : string -> t
val eq : t -> t -> t val eq : t -> t -> t
val dq : t -> t -> t val dq : t -> t -> t
@ -189,13 +190,13 @@ val convert : ?signed:bool -> dst:Typ.t -> src:Typ.t -> t -> t
(** Destruct *) (** Destruct *)
val base_offset : t -> (t * Z.t * Typ.t) option val base_offset : t -> (t * Q.t * Typ.t) option
(** Decompose an addition of a constant "offset" to a "base" exp. *) (** Decompose an addition of a constant "offset" to a "base" exp. *)
val base : t -> t val base : t -> t
(** Like [base_offset] but does not construct the "offset" exp. *) (** Like [base_offset] but does not construct the "offset" exp. *)
val offset : t -> (Z.t * Typ.t) option val offset : t -> (Q.t * Typ.t) option
(** Like [base_offset] but does not construct the "base" exp. *) (** Like [base_offset] but does not construct the "base" exp. *)
(** Access *) (** Access *)

@ -238,13 +238,13 @@ let pp_diff fs (r, s) =
(** solve a+i = b for a, yielding a = b-i *) (** solve a+i = b for a, yielding a = b-i *)
let solve_for_base ai b = let solve_for_base ai b =
match Exp.base_offset ai with match Exp.base_offset ai with
| Some (a, i, typ) -> (a, Exp.sub typ b (Exp.integer i typ)) | Some (a, i, typ) -> (a, Exp.sub typ b (Exp.rational i typ))
| None -> (ai, b) | None -> (ai, b)
(** subtract offset from both sides of equation a+i = b, yielding b-i *) (** subtract offset from both sides of equation a+i = b, yielding b-i *)
let subtract_offset ai b = let subtract_offset ai b =
match Exp.offset ai with match Exp.offset ai with
| Some (i, typ) -> Exp.sub typ b (Exp.integer i typ) | Some (i, typ) -> Exp.sub typ b (Exp.rational i typ)
| None -> b | None -> b
(** [map_base ~f a+i] is [f(a) + i] and [map_base ~f a] is [f(a)] *) (** [map_base ~f a+i] is [f(a) + i] and [map_base ~f a] is [f(a)] *)
@ -252,7 +252,7 @@ let map_base ai ~f =
match Exp.base_offset ai with match Exp.base_offset ai with
| Some (a, i, typ) -> | Some (a, i, typ) ->
let a' = f a in let a' = f a in
if a' == a then ai else Exp.add typ a' (Exp.integer i typ) if a' == a then ai else Exp.add typ a' (Exp.rational i typ)
| None -> f ai | None -> f ai
(** [norm_base r a] is [a'+k] where [r] implies [a = a'+k] and [a'] is a (** [norm_base r a] is [a'+k] where [r] implies [a = a'+k] and [a'] is a
@ -538,7 +538,7 @@ let rec norm_extend r ek =
Map.find_or_add r.rep e Map.find_or_add r.rep e
~if_found:(fun e' -> ~if_found:(fun e' ->
match Exp.offset ek with match Exp.offset ek with
| Some (k, typ) -> (r, Exp.add typ e' (Exp.integer k typ)) | Some (k, typ) -> (r, Exp.add typ e' (Exp.rational k typ))
| None -> (r, e') ) | None -> (r, e') )
~default:e ~default:e
~if_added:(fun rep -> ~if_added:(fun rep ->

Loading…
Cancel
Save