From 22578089c344bd94eeb2892ad428d2179f8dd018 Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Mon, 25 Feb 2019 07:08:22 -0800 Subject: [PATCH] [sledge] Reimplement arithmetic and congruence closure Summary: - Add nary expressions implemented using a form of multisets which support any integer multiplicity - Reimplement polynomials using new nary expressions - Move the decomposition of exps into "base plus offset" form into Exp, to enforce simplification invariants - Revise expression simplification to cooperate with congruence closure (mainly: simplification should not invent new subexpressions) - Reimplement congruence closure plus integer offsets to + cope with new representation of polynomials using nary expression forms + be diligent about maintaining which expressions are in the relation + add lots of invariant checking for the correlations between the componnents of the congruence closure data structures Reviewed By: jvillard Differential Revision: D14075512 fbshipit-source-id: 2dbaf3d11 --- sledge/TODO.org | 23 +- sledge/src/import/import.ml | 6 + sledge/src/import/import.mli | 9 +- sledge/src/import/mset.ml | 106 ++ sledge/src/import/mset.mli | 108 +++ sledge/src/import/vector.ml | 1 + sledge/src/import/vector.mli | 2 +- sledge/src/llair/exp.ml | 1223 +++++++++++++----------- sledge/src/llair/exp.mli | 33 +- sledge/src/llair/exp_test.ml | 78 +- sledge/src/symbheap/congruence.ml | 685 ++++++++----- sledge/src/symbheap/congruence.mli | 3 + sledge/src/symbheap/congruence_test.ml | 269 +++++- 13 files changed, 1679 insertions(+), 867 deletions(-) create mode 100644 sledge/src/import/mset.ml create mode 100644 sledge/src/import/mset.mli diff --git a/sledge/TODO.org b/sledge/TODO.org index 5f59b045d..2a07a3306 100644 --- a/sledge/TODO.org +++ b/sledge/TODO.org @@ -1,3 +1,5 @@ +* overall +** rename accumulators from [z] to [s] for "state" * llvm * import ** consider adding set ops that operate on a set and the domain of a map @@ -16,7 +18,16 @@ rather than nearest enclosing ** revise spec of strlen to account for non-max length strings ** convert strlen inst into a primitive to return the end of the block containing a pointer, and model strlen in code * llair -** simplify "greater-than" exps to "less-than" in reverse order +** divide Exp into two: one for code and one for formulas +- Exp simplification does not preserve order of operations, which is wrong wrt overflow +- code Exps don't need polynomial simplification +- code Exps could be given strong types in order to check the frontend, while letting formula Exps have weaker types as dictated by the logic +- treat formula exps as unbounded, clamp to bounded range when conferting to a code exp +** check if simplification via simp_sub in simp_eq is still needed +- it leads to violations of the subexp assertion on app1 +** replace Option.value_exn (Typ.prim_bit_size_of typ) with bits_of_int +** simplify combinations of mul and div, e.g. x * (y / z) ==> (x * y) / z +** ? simplify "greater-than" exps to "less-than" in reverse order ** when Xor exps have types, simplify e xor e to 0 ** normalize polynomial equations by dividing coefficients by their gcd ** treat Typ.ptr as an integer of some particular size (i.e. ptr = intptr) @@ -215,6 +226,14 @@ it is not obvious whether it will be simpler to use free variables instead of No ** llvm bugs? - Why aren't shufflevector instructions with zeroinitializer masks eliminated by the scalarizer pass? * congruence +** should handle equality and disequality simplification +- equalities of equalities to integers currently handled by Sh.pure +- doing it in Exp leads to violations of the subexp assertion on app1 +** optimize: change Cls.t and Use.t from a list to an unbalanced tree data structure +- only need empty, add, union, map, fold, fold_map to be fast, so no need for balancing +- detecting duplicates probably not worth the time since if any occur, the only cost is adding a redundant equation to pnd which will be quickly processed +** optimize: when called from extend, norm_extend calls norm unnecessarily +** revise mli to two sections, one for a "relation" api (with merge, mem/check, etc) and one for a "formula" api (with and_, or_, etc.) ** ? assert exps in formulas are in the carrier us and xs, or just fv? ** strengthen invariant @@ -222,8 +241,6 @@ us and xs, or just fv? since they (could) have the same domain ** optimize: can identity mappings in lkp be removed? * symbolic heap -** Congruence should handle equalities of equalities to integers -currently handled by Sh.pure ** normalize exps in terms of reps - add operation to normalize by rewriting in terms of reps - check for unsat diff --git a/sledge/src/import/import.ml b/sledge/src/import/import.ml index fa0ea42f1..b8d32f48c 100644 --- a/sledge/src/import/import.ml +++ b/sledge/src/import/import.ml @@ -265,6 +265,12 @@ module Set = struct let to_tree = Using_comparator.to_tree end +module Mset = struct + include Mset + + let pp sep pp_elt fs s = List.pp sep pp_elt fs (to_list s) +end + module Z = struct include Z diff --git a/sledge/src/import/import.mli b/sledge/src/import/import.mli index 6c582073a..2fcc38c8b 100644 --- a/sledge/src/import/import.mli +++ b/sledge/src/import/import.mli @@ -212,8 +212,15 @@ module Set : sig val to_tree : ('e, 'c) t -> ('e, 'c) tree end +module Mset : sig + include module type of Mset + + val pp : (unit, unit) fmt -> ('a * Z.t) pp -> ('a, _) t pp + (** Pretty-print a multiset. *) +end + module Z : sig - include module type of Z + include module type of struct include Z end val hash_fold_t : t Hash.folder val t_of_sexp : Sexp.t -> t diff --git a/sledge/src/import/mset.ml b/sledge/src/import/mset.ml new file mode 100644 index 000000000..22252ed1a --- /dev/null +++ b/sledge/src/import/mset.ml @@ -0,0 +1,106 @@ +(* + * Copyright (c) 2018-present, Facebook, Inc. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + *) + +(** Mset - Set with integer (positive, negative, or zero) multiplicity for + each element *) + +open Base + +type ('elt, 'cmp) t = ('elt, Z.t, 'cmp) Map.t + +module M (Elt : sig + type t + type comparator_witness +end) = +struct + type nonrec t = (Elt.t, Elt.comparator_witness) t +end + +module type Sexp_of_m = sig + type t [@@deriving sexp_of] +end + +module type M_of_sexp = sig + type t [@@deriving of_sexp] + + include Comparator.S with type t := t +end + +module type Compare_m = sig end +module type Hash_fold_m = Hasher.S + +let sexp_of_z z = Sexp.Atom (Z.to_string z) +let z_of_sexp = function Sexp.Atom s -> Z.of_string s | _ -> assert false +let hash_fold_z state z = Hash.fold_int state (Z.hash z) + +let sexp_of_m__t (type elt) (module Elt : Sexp_of_m with type t = elt) t = + Map.sexp_of_m__t (module Elt) sexp_of_z t + +let m__t_of_sexp (type elt cmp) + (module Elt : M_of_sexp + with type t = elt and type comparator_witness = cmp) sexp = + Map.m__t_of_sexp (module Elt) z_of_sexp sexp + +let compare_m__t (module Elt : Compare_m) = Map.compare_direct Z.compare + +let hash_fold_m__t (type elt) (module Elt : Hash_fold_m with type t = elt) + state = + Map.hash_fold_m__t (module Elt) hash_fold_z state + +let hash_m__t (type elt) (module Elt : Hash_fold_m with type t = elt) = + Hash.of_fold (hash_fold_m__t (module Elt)) + +type ('elt, 'cmp) comparator = + (module Comparator.S with type t = 'elt and type comparator_witness = 'cmp) + +let empty cmp = Map.empty cmp +let if_nz z = if Z.equal Z.zero z then None else Some z + +let add m x i = + Map.change m x ~f:(function Some j -> if_nz Z.(i + j) | None -> if_nz i) + +let remove m x = Map.remove m x + +let union m n = + Map.merge m n ~f:(fun ~key:_ -> function + | `Both (i, j) -> if_nz Z.(i + j) | `Left i | `Right i -> Some i ) + +let length m = Map.length m +let count m x = match Map.find m x with Some z -> z | None -> Z.zero + +let count_and_remove m x = + let found = ref Z.zero in + let m = + Map.change m x ~f:(function + | None -> None + | Some i -> + found := i ; + None ) + in + if Z.equal !found Z.zero then None else Some (!found, m) + +let min_elt = Map.min_elt +let fold m ~f ~init = Map.fold m ~f:(fun ~key ~data s -> f key data s) ~init + +let map m ~f = + fold m ~init:m ~f:(fun x i m -> + let x', i' = f x i in + if phys_equal x' x then + if Z.equal i' i then m else Map.set m ~key:x ~data:i' + else add (Map.remove m x) x' i' ) + +let fold_map m ~f ~init:s = + fold m ~init:(m, s) ~f:(fun x i (m, s) -> + let x', i', s = f x i s in + if phys_equal x' x then + if Z.equal i' i then (m, s) else (Map.set m ~key:x ~data:i', s) + else (add (Map.remove m x) x' i', s) ) + +let for_all m ~f = Map.for_alli m ~f:(fun ~key ~data -> f key data) +let map_counts m ~f = Map.mapi m ~f:(fun ~key ~data -> f key data) +let iter m ~f = Map.iteri m ~f:(fun ~key ~data -> f key data) +let to_list m = Map.to_alist m diff --git a/sledge/src/import/mset.mli b/sledge/src/import/mset.mli new file mode 100644 index 000000000..7e48caeb2 --- /dev/null +++ b/sledge/src/import/mset.mli @@ -0,0 +1,108 @@ +(* + * Copyright (c) 2018-present, Facebook, Inc. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + *) + +(** Mset - Set with integer (positive, negative, or zero) multiplicity for + each element *) + +open Base + +type ('elt, 'cmp) t + +type ('elt, 'cmp) comparator = + (module Comparator.S with type t = 'elt and type comparator_witness = 'cmp) + +module M (Elt : sig + type t + type comparator_witness +end) : sig + type nonrec t = (Elt.t, Elt.comparator_witness) t +end + +module type Sexp_of_m = sig + type t [@@deriving sexp_of] +end + +module type M_of_sexp = sig + type t [@@deriving of_sexp] + + include Comparator.S with type t := t +end + +module type Compare_m = sig end +module type Hash_fold_m = Hasher.S + +val sexp_of_m__t : + (module Sexp_of_m with type t = 'elt) -> ('elt, 'cmp) t -> Sexp.t + +val m__t_of_sexp : + (module M_of_sexp with type t = 'elt and type comparator_witness = 'cmp) + -> Sexp.t + -> ('elt, 'cmp) t + +val compare_m__t : + (module Compare_m) -> ('elt, 'cmp) t -> ('elt, 'cmp) t -> int + +val hash_fold_m__t : + (module Hash_fold_m with type t = 'elt) + -> Hash.state + -> ('elt, _) t + -> Hash.state + +val hash_m__t : + (module Hash_fold_m with type t = 'elt) -> ('elt, _) t -> Hash.hash_value + +val empty : ('elt, 'cmp) comparator -> ('elt, 'cmp) t +(** The empty multiset over the provided order. *) + +val add : ('a, 'c) t -> 'a -> Z.t -> ('a, 'c) t +(** Add to multiplicity of single element. [O(log n)] *) + +val remove : ('a, 'c) t -> 'a -> ('a, 'c) t +(** Set the multiplicity of an element to zero. [O(log n)] *) + +val union : ('a, 'c) t -> ('a, 'c) t -> ('a, 'c) t +(** Sum multiplicities pointwise. [O(n + m)] *) + +val length : _ t -> int +(** Number of elements with non-zero multiplicity. [O(1)]. *) + +val count : ('a, _) t -> 'a -> Z.t +(** Multiplicity of an element. [O(log n)]. *) + +val count_and_remove : ('a, 'c) t -> 'a -> (Z.t * ('a, 'c) t) option +(** Multiplicity of an element, and remove it. [O(log n)]. *) + +val map : ('a, 'c) t -> f:('a -> Z.t -> 'a * Z.t) -> ('a, 'c) t +(** Map over the elements in ascending order. Preserves physical equality if + [f] does. *) + +val map_counts : ('a, 'c) t -> f:('a -> Z.t -> Z.t) -> ('a, 'c) t +(** Map over the multiplicities of the elements in ascending order. *) + +val fold : ('a, _) t -> f:('a -> Z.t -> 's -> 's) -> init:'s -> 's +(** Fold over the elements in ascending order. *) + +val fold_map : + ('a, 'c) t + -> f:('a -> Z.t -> 's -> 'a * Z.t * 's) + -> init:'s + -> ('a, 'c) t * 's +(** Folding map over the elements in ascending order. Preserves physical + equality if [f] does. *) + +val for_all : ('a, _) t -> f:('a -> Z.t -> bool) -> bool +(** Universal property test. [O(n)] but returns as soon as a violation is + found, in ascending order. *) + +val iter : ('a, _) t -> f:('a -> Z.t -> unit) -> unit +(** Iterate over the elements in ascending order. *) + +val min_elt : ('a, _) t -> ('a * Z.t) option +(** Minimum element. *) + +val to_list : ('a, _) t -> ('a * Z.t) list +(** Convert to a list of elements in ascending order. *) diff --git a/sledge/src/import/vector.ml b/sledge/src/import/vector.ml index 36b6cf31d..1c3258e1b 100644 --- a/sledge/src/import/vector.ml +++ b/sledge/src/import/vector.ml @@ -74,5 +74,6 @@ let of_array = v let of_list x = v (Array.of_list x) let of_list_rev x = v (Array.of_list_rev x) let of_option x = v (Option.to_array x) +let reduce_exn x ~f = Array.reduce_exn (a x) ~f let to_list x = Array.to_list (a x) let to_array = a diff --git a/sledge/src/import/vector.mli b/sledge/src/import/vector.mli index d494bef96..1a0bf5d01 100644 --- a/sledge/src/import/vector.mli +++ b/sledge/src/import/vector.mli @@ -167,7 +167,7 @@ val find_exn : 'a t -> f:('a -> bool) -> 'a val contains_dup : compare:('a -> 'a -> int) -> 'a t -> bool (* val reduce : 'a t -> f:('a -> 'a -> 'a) -> 'a option *) -(* val reduce_exn : 'a t -> f:('a -> 'a -> 'a) -> 'a *) +val reduce_exn : 'a t -> f:('a -> 'a -> 'a) -> 'a (* val random_element : * ?random_state:Base.Random.State.t -> 'a t -> 'a option *) diff --git a/sledge/src/llair/exp.ml b/sledge/src/llair/exp.ml index 717962027..0f239057a 100644 --- a/sledge/src/llair/exp.ml +++ b/sledge/src/llair/exp.ml @@ -11,15 +11,11 @@ module Z = struct type t = Z.t [@@deriving compare, hash, sexp] - let equal = Z.equal + include (Z : module type of Z with type t := t) + let pp = Z.pp_print - let zero = Z.zero - let one = Z.one - let minus_one = Z.minus_one - let sign = Z.sign - let to_int = Z.to_int - let numbits = Z.numbits - let fits_int = Z.fits_int + let is_zero = Z.equal zero + let is_one = Z.equal one (* the signed 1-bit integers are -1 and 0 *) let true_ = Z.minus_one @@ -38,35 +34,38 @@ module Z = struct let clamp_bop ~signed bits op x y = clamp ~signed bits (op (clamp ~signed bits x) (clamp ~signed bits y)) - let eq ~bits x y = clamp_cmp ~signed:true bits Z.equal x y - let leq ~bits x y = clamp_cmp ~signed:true bits Z.leq x y - let geq ~bits x y = clamp_cmp ~signed:true bits Z.geq x y - let lt ~bits x y = clamp_cmp ~signed:true bits Z.lt x y - let gt ~bits x y = clamp_cmp ~signed:true bits Z.gt x y - let uleq ~bits x y = clamp_cmp ~signed:false bits Z.leq x y - let ugeq ~bits x y = clamp_cmp ~signed:false bits Z.geq x y - let ult ~bits x y = clamp_cmp ~signed:false bits Z.lt x y - let ugt ~bits x y = clamp_cmp ~signed:false bits Z.gt x y - let add ~bits x y = clamp_bop ~signed:true bits Z.add x y - let sub ~bits x y = clamp_bop ~signed:true bits Z.sub x y - let mul ~bits x y = clamp_bop ~signed:true bits Z.mul x y - let div ~bits x y = clamp_bop ~signed:true bits Z.div x y - let rem ~bits x y = clamp_bop ~signed:true bits Z.rem x y - let udiv ~bits x y = clamp_bop ~signed:false bits Z.div x y - let urem ~bits x y = clamp_bop ~signed:false bits Z.rem x y - let logand ~bits x y = clamp_bop ~signed:true bits Z.logand x y - let logor ~bits x y = clamp_bop ~signed:true bits Z.logor x y - let logxor ~bits x y = clamp_bop ~signed:true bits Z.logxor x y - let shift_left ~bits z i = Z.shift_left (clamp bits ~signed:true z) i - let shift_right ~bits z i = Z.shift_right (clamp bits ~signed:true z) i - - let shift_right_trunc ~bits z i = + let beq ~bits x y = clamp_cmp ~signed:true bits Z.equal x y + let bleq ~bits x y = clamp_cmp ~signed:true bits Z.leq x y + let bgeq ~bits x y = clamp_cmp ~signed:true bits Z.geq x y + let blt ~bits x y = clamp_cmp ~signed:true bits Z.lt x y + let bgt ~bits x y = clamp_cmp ~signed:true bits Z.gt x y + let buleq ~bits x y = clamp_cmp ~signed:false bits Z.leq x y + let bugeq ~bits x y = clamp_cmp ~signed:false bits Z.geq x y + let bult ~bits x y = clamp_cmp ~signed:false bits Z.lt x y + let bugt ~bits x y = clamp_cmp ~signed:false bits Z.gt x y + let badd ~bits x y = clamp_bop ~signed:true bits Z.add x y + let bsub ~bits x y = clamp_bop ~signed:true bits Z.sub x y + let bmul ~bits x y = clamp_bop ~signed:true bits Z.mul x y + let bdiv ~bits x y = clamp_bop ~signed:true bits Z.div x y + let brem ~bits x y = clamp_bop ~signed:true bits Z.rem x y + let budiv ~bits x y = clamp_bop ~signed:false bits Z.div x y + let burem ~bits x y = clamp_bop ~signed:false bits Z.rem x y + let blogand ~bits x y = clamp_bop ~signed:true bits Z.logand x y + let blogor ~bits x y = clamp_bop ~signed:true bits Z.logor x y + let blogxor ~bits x y = clamp_bop ~signed:true bits Z.logxor x y + let bshift_left ~bits z i = Z.shift_left (clamp bits ~signed:true z) i + let bshift_right ~bits z i = Z.shift_right (clamp bits ~signed:true z) i + + let bshift_right_trunc ~bits z i = Z.shift_right_trunc (clamp bits ~signed:true z) i end -module T0 = struct +module rec T : sig + type mset = Mset.M(T).t [@@deriving compare, hash, sexp] + type t = | App of {op: t; arg: t} + | AppN of {op: t; args: mset} | Var of {id: int; name: string} | Nondet of {msg: string} | Label of {parent: string; name: string} @@ -90,9 +89,10 @@ module T0 = struct | Ule | Ord | Uno - (* binary: arithmetic, numeric and pointer *) + (* nary: arithmetic, numeric and pointer *) | Add of {typ: Typ.t} | Mul of {typ: Typ.t} + (* binary: arithmetic, numeric and pointer *) | Div | Udiv | Rem @@ -115,176 +115,262 @@ module T0 = struct | Convert of {signed: bool; dst: Typ.t; src: Typ.t} [@@deriving compare, hash, sexp] - let equal = [%compare.equal: t] - let sorted e f = compare e f <= 0 - let sort e f = if sorted e f then (e, f) else (f, e) -end + type comparator_witness -module T = struct + val comparator : (t, comparator_witness) Comparator.t +end = struct include T0 include Comparator.Make (T0) +end - let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a = - let rec fix_f seen e = - match e with - | Struct_rec _ -> - if List.mem ~equal:( == ) seen e then f bot e - else f (fix_f (e :: seen)) e - | _ -> f (fix_f seen) e - in - let rec fix_f_seen_nil e = - match e with - | Struct_rec _ -> f (fix_f [e]) e - | _ -> f fix_f_seen_nil e - in - fix_f_seen_nil e - - let fix_flip (f : ('z -> t -> 'a as 'f) -> 'f) (bot : 'f) (z : 'z) (e : t) - = - fix (fun f' e z -> f (fun z e -> f' e z) z e) (fun e z -> bot z e) e z +(* auxiliary definition for safe recursive module initialization *) +and T0 : sig + type mset = Mset.M(T).t [@@deriving compare, hash, sexp] - let uncurry = - let rec uncurry_ args = function - | App {op; arg} -> uncurry_ (arg :: args) op - | op -> (op, args) - in - uncurry_ [] + type t = + | App of {op: t; arg: t} + | AppN of {op: t; args: mset} + | Var of {id: int; name: string} + | Nondet of {msg: string} + | Label of {parent: string; name: string} + | Splat + | Memory + | Concat + | Integer of {data: Z.t; typ: Typ.t} + | Float of {data: string} + | Eq + | Dq + | Gt + | Ge + | Lt + | Le + | Ugt + | Uge + | Ult + | Ule + | Ord + | Uno + | Add of {typ: Typ.t} + | Mul of {typ: Typ.t} + | Div + | Udiv + | Rem + | Urem + | And + | Or + | Xor + | Shl + | Lshr + | Ashr + | Conditional + | Record + | Select + | Update + | Struct_rec of {elts: t vector} + | Convert of {signed: bool; dst: Typ.t; src: Typ.t} + [@@deriving compare, hash, sexp] +end = struct + type mset = Mset.M(T).t [@@deriving compare, hash, sexp] - let rec pp fs exp = - let pp_ pp fs exp = - let pf fmt = - Format.pp_open_box fs 2 ; - Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt - in - match exp with - | Var {name; id= 0} -> pf "%%%s" name - | Var {name; id} -> pf "%%%s_%d" name id - | Nondet {msg} -> pf "nondet \"%s\"" msg - | Label {name} -> pf "%s" name - | Integer {data; typ= Pointer _} when Z.equal Z.zero data -> pf "null" - | Splat -> pf "^" - | Memory -> pf "⟨_,_⟩" - | App {op= Memory; arg= siz} -> pf "@<1>⟨%a,_@<1>⟩" pp siz - | App {op= App {op= Memory; arg= siz}; arg= bytes} -> - pf "@<1>⟨%a,%a@<1>⟩" pp siz pp bytes - | Concat -> pf "^" - | Integer {data} -> pf "%a" Z.pp data - | Float {data} -> pf "%s" data - | Eq -> pf "=" - | Dq -> pf "@<1>≠" - | Gt -> pf ">" - | Ge -> pf ">=" - | Lt -> pf "<" - | Le -> pf "<=" - | Ugt -> pf "u>" - | Uge -> pf "u>=" - | Ult -> pf "u<" - | Ule -> pf "u<=" - | Ord -> pf "ord" - | Uno -> pf "uno" - | Add _ -> pf "+" - | Mul _ -> pf "@<1>×" - | App {op= App {op= Add _ | Mul _}} -> pp_poly fs exp - | Div -> pf "/" - | Udiv -> pf "udiv" - | Rem -> pf "rem" - | Urem -> pf "urem" - | And -> pf "&&" - | Or -> pf "||" - | Xor -> pf "xor" - | App - { op= App {op= Xor; arg} - ; arg= Integer {data; typ= Integer {bits= 1}} } - when Z.is_true data -> - pf "¬%a" pp arg - | App - { op= App {op= Xor; arg= Integer {data; typ= Integer {bits= 1}}} - ; arg } - when Z.is_true data -> - pf "¬%a" pp arg - | Shl -> pf "shl" - | Lshr -> pf "lshr" - | Ashr -> pf "ashr" - | Conditional -> pf "(_?_:_)" - | App {op= Conditional; arg= cnd} -> pf "(%a@ ? _:_)" pp cnd - | App {op= App {op= Conditional; arg= cnd}; arg= thn} -> - pf "(%a@ ? %a@ : _)" pp cnd pp thn - | App - {op= App {op= App {op= Conditional; arg= cnd}; arg= thn}; arg= els} - -> - pf "(%a@ ? %a@ : %a)" pp cnd pp thn pp els - | Select -> pf "_[_]" - | App {op= Select; arg= rcd} -> pf "%a[_]" pp rcd - | App {op= App {op= Select; arg= rcd}; arg= idx} -> - pf "%a[%a]" pp rcd pp idx - | Update -> pf "[_|_→_]" - | App {op= Update; arg= rcd} -> pf "[%a@ @[| _→_@]]" pp rcd - | App {op= App {op= Update; arg= rcd}; arg= elt} -> - pf "[%a@ @[| _→ %a@]]" pp rcd pp elt - | App {op= App {op= App {op= Update; arg= rcd}; arg= elt}; arg= idx} - -> - pf "[%a@ @[| %a → %a@]]" pp rcd pp idx pp elt - | Record -> pf "{_}" - | App {op; arg} -> ( - match uncurry exp with - | Record, elts -> pf "{%a}" pp_record elts - | op, [x; y] -> pf "(%a@ %a %a)" pp x pp op pp y - | _ -> pf "(%a@ %a)" pp op pp arg ) - | Struct_rec {elts} -> pf "{|%a|}" (Vector.pp ",@ " pp) elts - | Convert {dst; src} -> pf "(%a)(%a)" Typ.pp dst Typ.pp src - in - fix_flip pp_ (fun _ _ -> ()) fs exp - - and pp_mono fs m = - let rec pp_mono fs m = - match m with - | App {op= App {op= Mul _; arg= m}; arg= f} -> - Format.fprintf fs "%a@ @<2>× %a" pp_mono m pp f - | e -> pp fs e - in - match m with - | App {op= App {op= Mul _; arg= m}; arg= Integer _ as c} -> - Format.fprintf fs "(%a@ @<2>× %a)" pp c pp_mono m - | App {op= App {op= Mul _}} -> Format.fprintf fs "(%a)" pp_mono m - | _ -> pp_mono fs m - - and pp_poly fs p = - let rec pp_poly fs p = - match p with - | App {op= App {op= Add _; arg= p}; arg= m} -> - Format.fprintf fs "%a@ + @[%a@]" pp_poly p pp_mono m - | m -> Format.fprintf fs "@[%a@]" pp_mono m - in - match p with - | App {op= App {op= Add _}} -> - Format.fprintf fs "@[(%a)@]" pp_poly p - | _ -> pp_poly fs p - - and pp_record fs elts = - [%Trace.fprintf - fs "%a" - (fun fs elts -> - let elta = Array.of_list elts in - match - String.init (Array.length elta) ~f:(fun i -> - match elta.(i) with - | Integer {data} -> Char.of_int_exn (Z.to_int data) - | _ -> raise (Invalid_argument "not a string") ) - with - | s -> Format.fprintf fs "@[%s@]" (String.escaped s) - | exception _ -> - Format.fprintf fs "@[%a@]" (List.pp ",@ " pp) elts ) - elts] + type t = + | App of {op: t; arg: t} + | AppN of {op: t; args: mset} + | Var of {id: int; name: string} + | Nondet of {msg: string} + | Label of {parent: string; name: string} + | Splat + | Memory + | Concat + | Integer of {data: Z.t; typ: Typ.t} + | Float of {data: string} + | Eq + | Dq + | Gt + | Ge + | Lt + | Le + | Ugt + | Uge + | Ult + | Ule + | Ord + | Uno + | Add of {typ: Typ.t} + | Mul of {typ: Typ.t} + | Div + | Udiv + | Rem + | Urem + | And + | Or + | Xor + | Shl + | Lshr + | Ashr + | Conditional + | Record + | Select + | Update + | Struct_rec of {elts: t vector} + | Convert of {signed: bool; dst: Typ.t; src: Typ.t} + [@@deriving compare, hash, sexp] end +(* suppress spurious "Warning 60: unused module T0." *) +type _t = T0.t + include T +let empty_mset = Mset.empty (module T) +let equal = [%compare.equal: t] +let sorted e f = compare e f <= 0 +let sort e f = if sorted e f then (e, f) else (f, e) + +let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a = + let rec fix_f seen e = + match e with + | Struct_rec _ -> + if List.mem ~equal:( == ) seen e then f bot e + else f (fix_f (e :: seen)) e + | _ -> f (fix_f seen) e + in + let rec fix_f_seen_nil e = + match e with Struct_rec _ -> f (fix_f [e]) e | _ -> f fix_f_seen_nil e + in + fix_f_seen_nil e + +let fix_flip (f : ('z -> t -> 'a as 'f) -> 'f) (bot : 'f) (z : 'z) (e : t) = + fix (fun f' e z -> f (fun z e -> f' e z) z e) (fun e z -> bot z e) e z + +let uncurry = + let rec uncurry_ acc_args = function + | App {op; arg} -> uncurry_ (arg :: acc_args) op + | op -> (op, acc_args) + in + uncurry_ [] + +let rec pp fs exp = + let pp_ pp fs exp = + let pf fmt = + Format.pp_open_box fs 2 ; + Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt + in + match exp with + | Var {name; id= 0} -> pf "%%%s" name + | Var {name; id} -> pf "%%%s_%d" name id + | Nondet {msg} -> pf "nondet \"%s\"" msg + | Label {name} -> pf "%s" name + | Integer {data; typ= Pointer _} when Z.is_zero data -> pf "null" + | Splat -> pf "^" + | Memory -> pf "⟨_,_⟩" + | App {op= Memory; arg= siz} -> pf "@<1>⟨%a,_@<1>⟩" pp siz + | App {op= App {op= Memory; arg= siz}; arg= bytes} -> + pf "@<1>⟨%a,%a@<1>⟩" pp siz pp bytes + | Concat -> pf "^" + | Integer {data} -> pf "%a" Z.pp data + | Float {data} -> pf "%s" data + | Eq -> pf "=" + | Dq -> pf "@<1>≠" + | Gt -> pf ">" + | Ge -> pf ">=" + | Lt -> pf "<" + | Le -> pf "<=" + | Ugt -> pf "u>" + | Uge -> pf "u>=" + | Ult -> pf "u<" + | Ule -> pf "u<=" + | Ord -> pf "ord" + | Uno -> pf "uno" + | Add _ -> pf "+" + | AppN {op= Add _; args} -> + let pp_poly_term fs (monomial, coefficient) = + match monomial with + | Integer {data} when Z.is_one data -> Z.pp fs coefficient + | _ when Z.is_one coefficient -> pp fs monomial + | _ -> + Format.fprintf fs "%a @<1>× %a" Z.pp coefficient pp monomial + in + pf "(%a)" (Mset.pp "@ + " pp_poly_term) args + | Mul _ -> pf "@<1>×" + | AppN {op= Mul _; args} -> + let pp_mono_term fs (factor, exponent) = + if Z.is_one exponent then pp fs factor + else Format.fprintf fs "%a^%a" pp factor Z.pp exponent + in + pf "(%a)" (Mset.pp "@ @<2>× " pp_mono_term) args + | Div -> pf "/" + | Udiv -> pf "udiv" + | Rem -> pf "rem" + | Urem -> pf "urem" + | And -> pf "&&" + | Or -> pf "||" + | Xor -> pf "xor" + | App + {op= App {op= Xor; arg}; arg= Integer {data; typ= Integer {bits= 1}}} + when Z.is_true data -> + pf "¬%a" pp arg + | App + {op= App {op= Xor; arg= Integer {data; typ= Integer {bits= 1}}}; arg} + when Z.is_true data -> + pf "¬%a" pp arg + | Shl -> pf "shl" + | Lshr -> pf "lshr" + | Ashr -> pf "ashr" + | Conditional -> pf "(_?_:_)" + | App {op= Conditional; arg= cnd} -> pf "(%a@ ? _:_)" pp cnd + | App {op= App {op= Conditional; arg= cnd}; arg= thn} -> + pf "(%a@ ? %a@ : _)" pp cnd pp thn + | App {op= App {op= App {op= Conditional; arg= cnd}; arg= thn}; arg= els} + -> + pf "(%a@ ? %a@ : %a)" pp cnd pp thn pp els + | Select -> pf "_[_]" + | App {op= Select; arg= rcd} -> pf "%a[_]" pp rcd + | App {op= App {op= Select; arg= rcd}; arg= idx} -> + pf "%a[%a]" pp rcd pp idx + | Update -> pf "[_|_→_]" + | App {op= Update; arg= rcd} -> pf "[%a@ @[| _→_@]]" pp rcd + | App {op= App {op= Update; arg= rcd}; arg= elt} -> + pf "[%a@ @[| _→ %a@]]" pp rcd pp elt + | App {op= App {op= App {op= Update; arg= rcd}; arg= elt}; arg= idx} -> + pf "[%a@ @[| %a → %a@]]" pp rcd pp idx pp elt + | Record -> pf "{_}" + | App {op; arg} -> ( + match uncurry exp with + | Record, elts -> pf "{%a}" pp_record elts + | op, [x; y] -> pf "(%a@ %a %a)" pp x pp op pp y + | _ -> pf "(%a@ %a)" pp op pp arg ) + | AppN {op; args} -> + let pp_elt fs (e, z) = Format.fprintf fs "%a %a" pp e Z.pp z in + pf "(%a@ %a)" pp op (Mset.pp "@ " pp_elt) args + | Struct_rec {elts} -> pf "{|%a|}" (Vector.pp ",@ " pp) elts + | Convert {dst; src} -> pf "(%a)(%a)" Typ.pp dst Typ.pp src + in + fix_flip pp_ (fun _ _ -> ()) fs exp + +and pp_record fs elts = + [%Trace.fprintf + fs "%a" + (fun fs elts -> + let elta = Array.of_list elts in + match + String.init (Array.length elta) ~f:(fun i -> + match elta.(i) with + | Integer {data} -> Char.of_int_exn (Z.to_int data) + | _ -> raise (Invalid_argument "not a string") ) + with + | s -> Format.fprintf fs "@[%s@]" (String.escaped s) + | exception _ -> + Format.fprintf fs "@[%a@]" (List.pp ",@ " pp) elts ) + elts] + type exp = t +let pp_t = pp + (** Invariant *) let typ_of = function - | App {op= App {op= Add {typ} | Mul {typ}}} + | AppN {op= Add {typ} | Mul {typ}} |Integer {typ} |App {op= Convert {dst= typ}} -> Some typ @@ -298,92 +384,57 @@ let type_check typ e = let type_check2 typ e f = type_check typ e ; type_check typ f -let coefficient = function - | App {op= App {op= Mul _}; arg= Integer {data}} -> data - | _ -> Z.one - -let indeterminate = function - | App {op= App {op= Mul _; arg}; arg= Integer _} -> arg - | x -> x - -(* a polynomial is a sum of monomials, e.g. - * (…(c₀ × x₀ + c₁ × x₁) + …) + cᵤ × xᵤ - * for constants cᵢ and indeterminants xᵢ - * with at most one constant - * which is not 0 - * where no non-constant monomial has coefficient 0 or 1 - * where monomials are unique and sorted modulo their coefficients - * represented as left-associated Add expressions - * so the constant is the right-most and shallowest Add arg - * a monomial is a product of factors, e.g. - * ((…(x₀ × x₁) × …) × xᵤ) × c - * for constant c and indeterminants xᵢ - * with at most one constant (aka the coefficient) - * which, if 0 or 1, is the only factor - * represented as left-associated Mul expressions - * where factors are sorted in increasing order wrt compare - * so the constant coefficient is the right-most and shallowest Mul arg - * a factor is either - * a constant (Integer) - * or an indeterminate (non-Add/Mul/Integer) - *) +(* an indeterminate (factor of a monomial) is any non-Add/Mul/Integer exp *) +let rec assert_indeterminate = function + | App {op} | AppN {op} -> assert_indeterminate op + | Integer _ | Add _ | Mul _ -> assert false + | _ -> assert true -let assert_polynomial ?(partial = false) p = - let rec assert_poly ?typ:typ0 ?bound p = - let assert_sorted mono = - Option.iter bound ~f:(fun bound -> - assert (compare (indeterminate mono) (indeterminate bound) < 0) ) - in - let assert_monomial ?typ:typ0 m = - let rec assert_mono ?typ:typ0 ?bound m = - let assert_sorted fact = - Option.iter bound ~f:(fun bound -> assert (compare fact bound <= 0) - ) - in - let assert_factor = function - | Integer _ | App {op= App {op= Add _ | Mul _}} -> assert false - | Add _ | Mul _ | App {op= Add _ | Mul _} -> assert partial - | _ -> assert true - in - match m with - | App {op= App {op= Mul _}; arg= Integer _} -> assert false - | App {op= App {op= Mul {typ}; arg= mono}; arg= fact} -> - assert (Option.for_all ~f:(Typ.castable typ) typ0) ; - assert_factor fact ; - assert_sorted fact ; - assert_mono ~typ ~bound:fact mono - | fact -> assert_factor fact ; assert_sorted fact - in - match m with - | App - { op= App {op= Mul {typ}; arg= mono} - ; arg= Integer {data; typ= typ'} } -> - assert (Option.for_all ~f:(Typ.castable typ) typ0) ; - assert (Typ.castable typ typ') ; - assert (Option.exists ~f:(( < ) 1) (Typ.prim_bit_size_of typ)) ; - assert (not (Z.equal Z.zero data)) ; - assert (not (Z.equal Z.one data)) ; - assert_mono ~typ mono - | mono -> assert_mono mono - in - match p with - | App {op= App {op= Add _}; arg= Integer _} -> assert false - | App {op= App {op= Add {typ}; arg= poly}; arg= mono} -> - assert (Option.for_all ~f:(Typ.castable typ) typ0) ; - assert_monomial ~typ mono ; - assert_sorted mono ; - assert_poly ~bound:mono ~typ poly - | mono -> - assert_monomial ?typ:typ0 mono ; - assert_sorted mono - in - match p with - | App {op= App {op= Add {typ}; arg= poly}; arg= Integer {data; typ= typ'}} - -> - assert (Typ.castable typ typ') ; - assert (not (Z.equal Z.zero data)) ; - assert_poly ~typ poly - | poly -> assert_poly poly +(* a monomial is a power product of factors, e.g. + * ∏ᵢ xᵢ^nᵢ + * for (non-constant) indeterminants xᵢ and positive integer exponents nᵢ + *) +let assert_monomial add_typ mono = + match mono with + | AppN {op= Mul {typ}; args} -> + assert (Typ.castable add_typ typ) ; + assert (Option.exists ~f:(fun n -> 1 < n) (Typ.prim_bit_size_of typ)) ; + Mset.iter args ~f:(fun factor exponent -> + assert (Z.sign exponent > 0) ; + assert_indeterminate factor |> Fn.id ) + | _ -> assert_indeterminate mono |> Fn.id + +(* a polynomial term is a monomial multiplied by a non-zero coefficient + * c × ∏ᵢ xᵢ + *) +let assert_poly_term add_typ mono coeff = + assert (not (Z.is_zero coeff)) ; + match mono with + | Integer {data} -> assert (Z.is_one data) + | AppN {op= Mul _; args} -> + if Z.is_one coeff then assert (Mset.length args > 1) + else assert (Mset.length args > 0) ; + assert_monomial add_typ mono |> Fn.id + | _ -> assert_monomial add_typ mono |> Fn.id + +(* a polynomial is a linear combination of monomials, e.g. + * ∑ᵢ cᵢ × ∏ⱼ xᵢⱼ + * for non-zero constant coefficients cᵢ + * and monomials ∏ⱼ xᵢⱼ, one of which may be the empty product 1 + *) +let assert_polynomial poly = + match poly with + | AppN {op= Add {typ}; args} -> + ( match Mset.length args with + | 0 -> assert false + | 1 -> ( + match Mset.min_elt args with + | Some (Integer _, _) -> assert false + | Some (_, k) -> assert (not (Z.is_one k)) + | _ -> () ) + | _ -> () ) ; + Mset.iter args ~f:(fun m c -> assert_poly_term typ m c |> Fn.id) + | _ -> assert false let invariant ?(partial = false) e = Invariant.invariant [%here] e [%sexp_of: t] @@ -394,26 +445,35 @@ let invariant ?(partial = false) e = assert (nargs = arity || (partial && nargs < arity)) in match op with - | Integer {data; typ= Integer {bits}} -> - assert_arity 0 ; - assert (Z.numbits data <= bits) - | Var _ | Nondet _ | Label _ | Integer _ | Float _ -> assert_arity 0 + | App _ -> fail "uncurry cannot return App or AppN" () + | Integer {data; typ= (Integer _ | Pointer _) as typ} -> ( + match Typ.prim_bit_size_of typ with + | None -> assert false + | Some bits -> + assert_arity 0 ; + assert (Z.numbits data <= bits) ) + | Integer _ -> assert false + | Var _ | Nondet _ | Label _ | Float _ -> assert_arity 0 | Convert {dst; src} -> ( match args with | [Integer {typ}] -> assert (Typ.equal src typ) | _ -> assert_arity 1 ) ; assert (Typ.convertible src dst) - | Add _ | Mul _ -> assert_polynomial ~partial e + | AppN {op= Add _} -> + assert_arity 0 ; + assert_polynomial e |> Fn.id + | AppN {op= Mul {typ}} -> + assert_arity 0 ; + assert_monomial typ e |> Fn.id + | Add _ | Mul _ -> assert (partial || fail "Add and Mul are not curried") + | AppN _ -> fail "Add and Mul are the only nary operators" () | Eq | Dq | Gt | Ge | Lt | Le | Ugt | Uge | Ult | Ule | Div | Udiv | Rem |Urem | And | Or | Xor | Shl | Lshr | Ashr -> ( match args with | [x; y] -> ( - ( match op with - | Eq | Dq | And | Or | Xor -> assert (sorted x y) - | _ -> () ) ; - match (typ_of x, typ_of y) with - | Some typ, Some typ' -> assert (Typ.castable typ typ') - | _ -> assert true ) + match (typ_of x, typ_of y) with + | Some typ, Some typ' -> assert (Typ.castable typ typ') + | _ -> assert true ) | _ -> assert_arity 2 ) | Splat | Memory | Concat | Ord | Uno | Select -> assert_arity 2 | Conditional | Update -> assert_arity 3 @@ -421,12 +481,22 @@ let invariant ?(partial = false) e = | Struct_rec {elts} -> assert (not (Vector.is_empty elts)) ; assert_arity 0 - | App _ -> fail "uncurry cannot return App" () + +let bits_of_int exp = + match exp with + | Integer {typ} -> ( + match Typ.prim_bit_size_of typ with + | Some bits -> bits + | None -> violates invariant exp ) + | _ -> fail "bits_of_int" () (** Variables are the expressions constructed by [Var] *) module Var = struct include T + let equal = equal + let pp = pp + type var = t module Set = struct @@ -436,7 +506,7 @@ module Var = struct type t = Set.M(T).t [@@deriving compare, sexp] - let pp vs = Set.pp T.pp vs + let pp vs = Set.pp pp_t vs let empty = Set.empty (module T) let of_vector = Set.of_vector (module T) end @@ -502,7 +572,7 @@ module Var = struct let pp fs s = Format.fprintf fs "@[<1>[%a]@]" (List.pp ",@ " (fun fs (k, v) -> - Format.fprintf fs "@[[%a ↦ %a]@]" T.pp k T.pp v )) + Format.fprintf fs "@[[%a ↦ %a]@]" pp_t k pp_t v )) (Map.to_alist s) let empty = Map.empty (module T) @@ -561,6 +631,9 @@ let fold_exps e ~init ~f = let z = match e with | App {op; arg} -> fold_exps_ op (fold_exps_ arg z) + | AppN {op; args} -> + fold_exps_ op + (Mset.fold args ~init:z ~f:(fun arg _ z -> fold_exps_ arg z)) | Struct_rec {elts} -> Vector.fold elts ~init:z ~f:(fun z elt -> fold_exps_ elt z) | _ -> z @@ -594,177 +667,235 @@ let simp_convert signed (dst : Typ.t) src arg = let simp_gt x y = match (x, y) with | Integer {data= i}, Integer {data= j; typ= Integer {bits}} -> - bool (Z.gt ~bits i j) + bool (Z.bgt ~bits i j) | _ -> App {op= App {op= Gt; arg= x}; arg= y} let simp_ugt x y = match (x, y) with | Integer {data= i}, Integer {data= j; typ= Integer {bits}} -> - bool (Z.ugt ~bits i j) + bool (Z.bugt ~bits i j) | _ -> App {op= App {op= Ugt; arg= x}; arg= y} let simp_ge x y = match (x, y) with | Integer {data= i}, Integer {data= j; typ= Integer {bits}} -> - bool (Z.geq ~bits i j) + bool (Z.bgeq ~bits i j) | _ -> App {op= App {op= Ge; arg= x}; arg= y} let simp_uge x y = match (x, y) with | Integer {data= i}, Integer {data= j; typ= Integer {bits}} -> - bool (Z.ugeq ~bits i j) + bool (Z.bugeq ~bits i j) | _ -> App {op= App {op= Uge; arg= x}; arg= y} let simp_lt x y = match (x, y) with | Integer {data= i}, Integer {data= j; typ= Integer {bits}} -> - bool (Z.lt ~bits i j) + bool (Z.blt ~bits i j) | _ -> App {op= App {op= Lt; arg= x}; arg= y} let simp_ult x y = match (x, y) with | Integer {data= i}, Integer {data= j; typ= Integer {bits}} -> - bool (Z.ult ~bits i j) + bool (Z.bult ~bits i j) | _ -> App {op= App {op= Ult; arg= x}; arg= y} let simp_le x y = match (x, y) with | Integer {data= i}, Integer {data= j; typ= Integer {bits}} -> - bool (Z.leq ~bits i j) + bool (Z.bleq ~bits i j) | _ -> App {op= App {op= Le; arg= x}; arg= y} let simp_ule x y = match (x, y) with | Integer {data= i}, Integer {data= j; typ= Integer {bits}} -> - bool (Z.uleq ~bits i j) + bool (Z.buleq ~bits i j) | _ -> App {op= App {op= Ule; arg= x}; arg= y} let simp_ord x y = App {op= App {op= Ord; arg= x}; arg= y} let simp_uno x y = App {op= App {op= Uno; arg= x}; arg= y} -(* see assert_polynomial for representation invariants of polynomials - * suffixes are used on function names to indicate their type, e.g. - * mul_mf multiplies a monomial by a factor, and - * mul_pm multiplies a polynomial by a monomial - *) -let rec simp_mul typ axs bys = - let bits = - match Typ.prim_bit_size_of typ with - | Some bits -> bits - | None -> fail "multiplication not defined at type %a" Typ.pp typ () - in - let one = integer Z.one typ in - let mul_mf x y = - match (x, y) with - | Integer {data; typ= Integer {bits= 1}}, _ when Z.is_false data -> x - | _, Integer {data; typ= Integer {bits= 1}} when Z.is_false data -> y - | Integer {typ= Integer {bits= 1}}, y -> y - | x, Integer {typ= Integer {bits= 1}} -> x - | Integer {data}, y when Z.equal Z.one data -> y - | x, Integer {data} when Z.equal Z.one data -> x - | _, Integer {data} when Z.equal Z.zero data -> y - | Integer {data}, _ when Z.equal Z.zero data -> x - | Integer {data= i; typ}, Integer {data= j} -> - let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.mul ~bits i j) typ - | _ -> App {op= App {op= Mul {typ}; arg= x}; arg= y} - in - let rec mul_mm axs bys = - match (axs, bys) with - | Integer {data}, by0N when Z.equal Z.one data -> by0N - | ax0J, Integer {data} when Z.equal Z.one data -> ax0J - | ax0J, by0N -> ( - match - match (ax0J, by0N) with - | ( App {op= App {op= Mul _; arg= ax0I}; arg= axJ} - , App {op= App {op= Mul _; arg= by0M}; arg= byN} ) -> - (ax0I, axJ, by0M, byN) - | App {op= App {op= Mul _; arg= ax0I}; arg= axJ}, byN -> - (ax0I, axJ, one, byN) - | axJ, App {op= App {op= Mul _; arg= by0M}; arg= byN} -> - (one, axJ, by0M, byN) - | axJ, byN -> (one, axJ, one, byN) - with - | ax0I, Integer {data= i}, by0M, Integer {data= j} -> - mul_mf (mul_mm ax0I by0M) (integer (Z.mul ~bits i j) typ) - | ax0I, axJ, by0M, byN -> - if compare axJ byN <= 0 then mul_mf (mul_mm ax0J by0M) byN - else mul_mf (mul_mm ax0I by0N) axJ ) - in - let rec mul_pm ax0J by = - match ax0J with - | App {op= App {op= Add _; arg= ax0I}; arg= axJ} -> - simp_add typ (mul_pm ax0I by) (mul_mm axJ by) - | _ -> mul_mm ax0J by - in - let rec mul_pp axs by0N = - match by0N with - | App {op= App {op= Add _; arg= by0M}; arg= byN} -> - simp_add typ (mul_pp axs by0M) (mul_pm axs byN) - | _ -> mul_pm axs by0N - in - mul_pp axs bys +let simp_div x y = + match (x, y) with + (* i / j *) + | Integer {data= i; typ}, Integer {data= j} -> + let bits = Option.value_exn (Typ.prim_bit_size_of typ) in + integer (Z.bdiv ~bits i j) typ + (* e / 1 ==> e *) + | e, Integer {data} when Z.is_one data -> e + | _ -> App {op= App {op= Div; arg= x}; arg= y} -and simp_add typ axs bys = - let bits = - match Typ.prim_bit_size_of typ with - | Some bits -> bits - | None -> fail "addition not defined at type %a" Typ.pp typ () - in - let zero = integer Z.zero typ in - let add_pm x y = - match (x, y) with - | Integer {data}, y when Z.equal Z.zero data -> y - | x, Integer {data} when Z.equal Z.zero data -> x - | Integer {data= i; typ}, Integer {data= j} -> - let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.add ~bits i j) typ - | _ -> App {op= App {op= Add {typ}; arg= x}; arg= y} +let simp_udiv x y = + match (x, y) with + (* i u/ j *) + | Integer {data= i; typ}, Integer {data= j} -> + let bits = Option.value_exn (Typ.prim_bit_size_of typ) in + integer (Z.budiv ~bits i j) typ + (* e u/ 1 ==> e *) + | e, Integer {data} when Z.is_one data -> e + | _ -> App {op= App {op= Udiv; arg= x}; arg= y} + +let simp_rem x y = + match (x, y) with + (* i % j *) + | Integer {data= i; typ}, Integer {data= j} -> + let bits = Option.value_exn (Typ.prim_bit_size_of typ) in + integer (Z.brem ~bits i j) typ + (* e % 1 ==> 0 *) + | _, Integer {data; typ} when Z.is_one data -> integer Z.zero typ + | _ -> App {op= App {op= Rem; arg= x}; arg= y} + +let simp_urem x y = + match (x, y) with + (* i u% j *) + | Integer {data= i; typ}, Integer {data= j} -> + let bits = Option.value_exn (Typ.prim_bit_size_of typ) in + integer (Z.burem ~bits i j) typ + (* e u% 1 ==> 0 *) + | _, Integer {data; typ} when Z.is_one data -> integer Z.zero typ + | _ -> App {op= App {op= Urem; arg= x}; arg= y} + +(* Sums of polynomial terms represented by multisets. A sum ∑ᵢ cᵢ × + Xᵢ of monomials Xᵢ with coefficients cᵢ is represented by a + multiset where the elements are Xᵢ with multiplicities cᵢ. A constant + is treated as the coefficient of the empty monomial, which is the unit of + multiplication 1. *) +module Sum = struct + let empty = empty_mset + + let add coeff exp sum = + assert (not (Z.is_zero coeff)) ; + match exp with + | Integer {data} when Z.is_zero data -> sum + | Integer {data; typ} -> + Mset.add sum (integer Z.one typ) Z.(coeff * data) + | _ -> Mset.add sum exp coeff + + let singleton ?(coeff = Z.one) exp = add coeff exp empty + + let map sum ~f = + Mset.fold sum ~init:empty ~f:(fun e c sum -> add c (f e) sum) + + let mul_const const sum = + assert (not (Z.is_zero const)) ; + if Z.is_one const then sum + else Mset.map_counts ~f:(fun _ -> Z.mul const) sum + + let to_exp typ sum = + match Mset.length sum with + | 0 -> integer Z.zero typ + | 1 -> ( + match Mset.min_elt sum with + | Some (Integer _, z) -> integer z typ + | Some (arg, z) when Z.is_one z -> arg + | _ -> AppN {op= Add {typ}; args= sum} ) + | _ -> AppN {op= Add {typ}; args= sum} +end + +let rec simp_add_ typ es poly = + (* (coeff × exp) + poly *) + let f exp coeff poly = + match (exp, poly) with + (* (0 × e) + s ==> 0 (optim) *) + | _ when Z.is_zero coeff -> poly + (* (c × 0) + s ==> s (optim) *) + | Integer {data}, _ when Z.is_zero data -> poly + (* (c × cᵢ) + cⱼ ==> c×cᵢ+cⱼ *) + | Integer {data= i}, Integer {data= j} -> + integer (Z.badd ~bits:(bits_of_int exp) Z.(coeff * i) j) typ + (* (c × ∑ᵢ cᵢ × Xᵢ) + s ==> (∑ᵢ (c × cᵢ) × Xᵢ) + s *) + | AppN {op= Add _; args}, _ -> + simp_add_ typ (Sum.mul_const coeff args) poly + (* (c₀ × X₀) + (∑ᵢ₌₁ⁿ cᵢ × Xᵢ) ==> ∑ᵢ₌₀ⁿ + cᵢ × Xᵢ *) + | _, AppN {op= Add _; args} -> Sum.to_exp typ (Sum.add coeff exp args) + (* (c₁ × X₁) + X₂ ==> ∑ᵢ₌₁² cᵢ × Xᵢ for c₂ = 1 *) + | _ -> Sum.to_exp typ (Sum.add coeff exp (Sum.singleton poly)) in - let rec add_pp axs bys = - match (axs, bys) with - | Integer {data}, by0N when Z.equal Z.zero data -> by0N - | ax0J, Integer {data} when Z.equal Z.zero data -> ax0J - | ax0J, by0N -> ( - match - match (ax0J, by0N) with - | ( App {op= App {op= Add _; arg= ax0I}; arg= axJ} - , App {op= App {op= Add _; arg= by0M}; arg= byN} ) -> - (ax0I, axJ, by0M, byN) - | App {op= App {op= Add _; arg= ax0I}; arg= axJ}, byN -> - (ax0I, axJ, zero, byN) - | axJ, App {op= App {op= Add _; arg= by0M}; arg= byN} -> - (zero, axJ, by0M, byN) - | axJ, byN -> (zero, axJ, zero, byN) - with - | ax0I, Integer {data= i}, by0M, Integer {data= j} -> - add_pm (add_pp ax0I by0M) (integer (Z.add ~bits i j) typ) - | ax0I, axJ, by0M, byN -> - let xJ = indeterminate axJ in - let yN = indeterminate byN in - let ord = compare xJ yN in - if ord < 0 then add_pm (add_pp ax0J by0M) byN - else if ord > 0 then add_pm (add_pp ax0I by0N) axJ - else - let aJ = coefficient axJ in - let bN = coefficient byN in - let c = Z.add ~bits aJ bN in - if Z.equal Z.zero c then add_pp ax0I by0M - else add_pm (add_pp ax0I by0M) (simp_mul typ (integer c typ) xJ) - ) + Mset.fold ~f es ~init:poly + +let simp_add typ es = simp_add_ typ es (integer Z.zero typ) +let simp_add2 typ e f = simp_add_ typ (Sum.singleton e) f + +(* Products of indeterminants represented by multisets. A product ∏ᵢ + xᵢ^nᵢ of indeterminates xᵢ is represented by a multiset where the + elements are xᵢ and the multiplicities are the exponents nᵢ. *) +module Prod = struct + let empty = empty_mset + let add exp prod = Mset.add prod exp Z.one + let singleton exp = add exp empty + let union = Mset.union +end + +(* map over each monomial of a polynomial *) +let poly_map_monos poly ~f = + match poly with + | AppN {op= Add {typ}; args= sum} -> + Sum.to_exp typ + (Sum.map sum ~f:(function + | AppN {op= Mul _ as mul; args= prod} -> + AppN {op= mul; args= f prod} + | _ -> violates invariant poly )) + | _ -> fail "poly_map_monos" () + +let rec simp_mul2 typ e f = + match (e, f) with + (* c₁ × c₂ ==> c₁×c₂ *) + | Integer {data= i}, Integer {data= j} -> + integer (Z.bmul ~bits:(bits_of_int e) i j) typ + (* 0 × f ==> 0 *) + | Integer {data}, _ when Z.is_zero data -> e + (* e × 0 ==> 0 *) + | _, Integer {data} when Z.is_zero data -> f + (* c × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ c × cᵤ × ∏ⱼ + yᵤⱼ *) + | Integer {data}, AppN {op= Add _; args} + |AppN {op= Add _; args}, Integer {data} -> + Sum.to_exp typ (Sum.mul_const data args) + (* c₁ × x₁ ==> ∑ᵢ₌₁ cᵢ × xᵢ *) + | Integer {data= c}, x | x, Integer {data= c} -> + Sum.to_exp typ (Sum.singleton ~coeff:c x) + (* (∏ᵤ₌₀ⁱ xᵤ) × (∏ᵥ₌ᵢ₊₁ⁿ xᵥ) ==> + ∏ⱼ₌₀ⁿ xⱼ *) + | AppN {op= Mul _ as mul; args= xs1}, AppN {op= Mul _; args= xs2} -> + AppN {op= mul; args= Prod.union xs1 xs2} + (* (∏ᵢ xᵢ) × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ cᵤ × + ∏ᵢ xᵢ × ∏ⱼ yᵤⱼ *) + | AppN {op= Mul _; args= prod}, (AppN {op= Add _} as poly) + |(AppN {op= Add _} as poly), AppN {op= Mul _; args= prod} -> + poly_map_monos ~f:(Prod.union prod) poly + (* x₀ × (∏ᵢ₌₁ⁿ xᵢ) ==> ∏ᵢ₌₀ⁿ xᵢ *) + | AppN {op= Mul _ as mul; args= xs1}, x + |x, AppN {op= Mul _ as mul; args= xs1} -> + AppN {op= mul; args= Prod.add x xs1} + (* e × (∑ᵤ cᵤ × ∏ⱼ yᵤⱼ) ==> ∑ᵤ e × cᵤ × ∏ⱼ + yᵤⱼ *) + | AppN {op= Add _; args}, e | e, AppN {op= Add _; args} -> + simp_add typ (Sum.map ~f:(fun m -> simp_mul2 typ e m) args) + (* x₁ × x₂ ==> ∏ᵢ₌₁² xᵢ *) + | _ -> AppN {op= Mul {typ}; args= Prod.add e (Prod.singleton f)} + +let simp_mul typ es = + (* (bas ^ pwr) × exp *) + let rec mul_pwr bas pwr exp = + if Z.is_zero pwr then exp + else mul_pwr bas (Z.pred pwr) (simp_mul2 typ bas exp) in - add_pp axs bys + let one = integer Z.one typ in + Mset.fold es ~init:one ~f:(fun bas pwr exp -> + if Z.sign pwr >= 0 then mul_pwr bas pwr exp + else simp_div exp (mul_pwr bas (Z.neg pwr) one) ) -let simp_negate typ x = simp_mul typ (integer Z.minus_one typ) x +let simp_negate typ x = simp_mul2 typ (integer Z.minus_one typ) x let simp_sub typ x y = match (x, y) with (* i - j *) | Integer {data= i; typ}, Integer {data= j} -> let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.sub ~bits i j) typ + integer (Z.bsub ~bits i j) typ (* x - y ==> x + (-1 * y) *) - | _ -> simp_add typ x (simp_negate typ y) + | _ -> simp_add2 typ x (simp_negate typ y) let simp_cond cnd thn els = match cnd with @@ -780,7 +911,7 @@ let simp_and x y = (* i && j *) | Integer {data= i; typ}, Integer {data= j} -> let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.logand ~bits i j) typ + integer (Z.blogand ~bits i j) typ (* e && true ==> e *) | Integer {data; typ= Integer {bits= 1}}, e |e, Integer {data; typ= Integer {bits= 1}} @@ -791,19 +922,16 @@ let simp_and x y = |_, (Integer {data; typ= Integer {bits= 1}} as f) when Z.is_false data -> f - | _ -> - let ord = compare x y in - (* e && e ==> e *) - if ord = 0 then x - else if ord < 0 then App {op= App {op= And; arg= x}; arg= y} - else App {op= App {op= And; arg= y}; arg= x} + (* e && e ==> e *) + | _ when equal x y -> x + | _ -> App {op= App {op= And; arg= x}; arg= y} let simp_or x y = match (x, y) with (* i || j *) | Integer {data= i; typ}, Integer {data= j} -> let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.logor ~bits i j) typ + integer (Z.blogor ~bits i j) typ (* e || true ==> true *) | (Integer {data; typ= Integer {bits= 1}} as t), _ |_, (Integer {data; typ= Integer {bits= 1}} as t) @@ -814,12 +942,9 @@ let simp_or x y = |e, Integer {data; typ= Integer {bits= 1}} when Z.is_false data -> e - | _ -> - let ord = compare x y in - (* e || e ==> e *) - if ord = 0 then x - else if ord < 0 then App {op= App {op= Or; arg= x}; arg= y} - else App {op= App {op= Or; arg= y}; arg= x} + (* e || e ==> e *) + | _ when equal x y -> x + | _ -> App {op= App {op= Or; arg= x}; arg= y} let rec simp_not (typ : Typ.t) exp = match (exp, typ) with @@ -866,44 +991,20 @@ let rec simp_not (typ : Typ.t) exp = App {op= App {op= Xor; arg= integer (Z.of_bool true) typ}; arg= e} and simp_eq x y = - let coeff_sign = function - | Integer {data} | App {op= App {op= Mul _}; arg= Integer {data}} -> - Z.sign data - | _ -> 1 - in - match - match (x, y) with - | (App {op= App {op= Add {typ} | Mul {typ}}} | Integer {typ}), _ - |_, (App {op= App {op= Add {typ} | Mul {typ}}} | Integer {typ}) -> ( - match simp_sub typ x y with - (* x = y ==> x' = -y' where x-y = x' + y' *) - | App {op= App {op= Add {typ}; arg= x'}; arg= y'} -> - if coeff_sign y' < 0 then (x', simp_negate typ y') - else (simp_negate typ x', y') - (* x = y ==> x-y = 0 *) - | x_y -> - let x_y = - if coeff_sign x_y < 0 then simp_negate typ x_y else x_y - in - (x_y, integer Z.zero typ) ) - | _ -> (x, y) - with + match (x, y) with | Integer {data= i; typ}, Integer {data= j} -> let bits = Option.value_exn (Typ.prim_bit_size_of typ) in (* i = j *) - bool (Z.eq ~bits i j) + bool (Z.beq ~bits i j) | b, Integer {data; typ= Integer {bits= 1}} |Integer {data; typ= Integer {bits= 1}}, b -> if Z.is_false data then (* b = false ==> ¬b *) simp_not Typ.bool b else (* b = true ==> b *) b - | x, y -> - let ord = compare x y in - (* e = e ==> true *) - if ord = 0 then bool true - else if ord < 0 then App {op= App {op= Eq; arg= x}; arg= y} - else App {op= App {op= Eq; arg= y}; arg= x} + (* e = e ==> true *) + | x, y when equal x y -> bool true + | x, y -> App {op= App {op= Eq; arg= x}; arg= y} and simp_dq x y = match simp_eq x y with @@ -916,25 +1017,22 @@ let simp_xor x y = (* i xor j *) | Integer {data= i; typ}, Integer {data= j} -> let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.logxor ~bits i j) typ + integer (Z.blogxor ~bits i j) typ (* true xor b ==> ¬b *) | Integer {data; typ= Integer {bits= 1}}, b |b, Integer {data; typ= Integer {bits= 1}} when Z.is_true data -> simp_not Typ.bool b - | _ -> - let ord = compare x y in - if ord <= 0 then App {op= App {op= Xor; arg= x}; arg= y} - else App {op= App {op= Xor; arg= y}; arg= x} + | _ -> App {op= App {op= Xor; arg= x}; arg= y} let simp_shl x y = match (x, y) with (* i shl j *) | Integer {data= i; typ}, Integer {data= j} when Z.fits_int j -> let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.shift_left ~bits i (Z.to_int j)) typ + integer (Z.bshift_left ~bits i (Z.to_int j)) typ (* e shl 0 ==> e *) - | e, Integer {data} when Z.equal Z.zero data -> e + | e, Integer {data} when Z.is_zero data -> e | _ -> App {op= App {op= Shl; arg= x}; arg= y} let simp_lshr x y = @@ -942,9 +1040,9 @@ let simp_lshr x y = (* i lshr j *) | Integer {data= i; typ}, Integer {data= j} when Z.fits_int j -> let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.shift_right_trunc ~bits i (Z.to_int j)) typ + integer (Z.bshift_right_trunc ~bits i (Z.to_int j)) typ (* e lshr 0 ==> e *) - | e, Integer {data} when Z.equal Z.zero data -> e + | e, Integer {data} when Z.is_zero data -> e | _ -> App {op= App {op= Lshr; arg= x}; arg= y} let simp_ashr x y = @@ -952,50 +1050,35 @@ let simp_ashr x y = (* i ashr j *) | Integer {data= i; typ}, Integer {data= j} when Z.fits_int j -> let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.shift_right ~bits i (Z.to_int j)) typ + integer (Z.bshift_right ~bits i (Z.to_int j)) typ (* e ashr 0 ==> e *) - | e, Integer {data} when Z.equal Z.zero data -> e + | e, Integer {data} when Z.is_zero data -> e | _ -> App {op= App {op= Ashr; arg= x}; arg= y} -let simp_div x y = - match (x, y) with - (* i / j *) - | Integer {data= i; typ}, Integer {data= j} -> - let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.div ~bits i j) typ - (* e / 1 ==> e *) - | e, Integer {data} when Z.equal Z.one data -> e - | _ -> App {op= App {op= Div; arg= x}; arg= y} - -let simp_udiv x y = - match (x, y) with - (* i u/ j *) - | Integer {data= i; typ}, Integer {data= j} -> - let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.udiv ~bits i j) typ - (* e u/ 1 ==> e *) - | Integer {data}, e when Z.equal Z.one data -> e - | _ -> App {op= App {op= Udiv; arg= x}; arg= y} +(** Access *) -let simp_rem x y = - match (x, y) with - (* i % j *) - | Integer {data= i; typ}, Integer {data= j} -> - let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.rem ~bits i j) typ - (* e % 1 ==> 0 *) - | _, Integer {data; typ} when Z.equal Z.one data -> integer Z.zero typ - | _ -> App {op= App {op= Rem; arg= x}; arg= y} +let iter e ~f = + match e with + | App {op; arg} -> f op ; f arg + | AppN {op; args} -> + f op ; + Mset.iter ~f:(fun arg _ -> f arg) args + | _ -> () -let simp_urem x y = - match (x, y) with - (* i u% j *) - | Integer {data= i; typ}, Integer {data= j} -> - let bits = Option.value_exn (Typ.prim_bit_size_of typ) in - integer (Z.urem ~bits i j) typ - (* e u% 1 ==> 0 *) - | _, Integer {data; typ} when Z.equal Z.one data -> integer Z.zero typ - | _ -> App {op= App {op= Urem; arg= x}; arg= y} +let fold e ~init:s ~f = + match e with + | App {op; arg} -> + let s = f s op in + let s = f s arg in + s + | AppN {op; args} -> + let s = f s op in + let s = Mset.fold ~f:(fun e _ s -> f s e) args ~init:s in + s + | _ -> s + +let for_all e ~f = fold ~f:(fun so_far a -> so_far && f a) ~init:true e +let exists e ~f = fold ~f:(fun found a -> found || f a) ~init:false e let app1 ?(partial = false) op arg = ( match (op, arg) with @@ -1011,8 +1094,6 @@ let app1 ?(partial = false) op arg = | App {op= Ule; arg= x}, y -> simp_ule x y | App {op= Ord; arg= x}, y -> simp_ord x y | App {op= Uno; arg= x}, y -> simp_uno x y - | App {op= Add {typ}; arg= x}, y -> simp_add typ x y - | App {op= Mul {typ}; arg= x}, y -> simp_mul typ x y | App {op= Div; arg= x}, y -> simp_div x y | App {op= Udiv; arg= x}, y -> simp_udiv x y | App {op= Rem; arg= x}, y -> simp_rem x y @@ -1027,10 +1108,38 @@ let app1 ?(partial = false) op arg = | Convert {signed; dst; src}, x -> simp_convert signed dst src x | _ -> App {op; arg} ) |> check (invariant ~partial) + |> check (fun e -> + (* every App subexp of output appears in input *) + match op with + | App {op= Eq | Dq} -> () + | _ -> + iter e ~f:(function + | App _ as a -> + assert ( + equal a op || equal a arg + || Trace.report + "simplifying %a %a yields %a with new subexp %a" + pp op pp arg pp e pp a ) + | _ -> () ) ) let app2 op x y = app1 (app1 ~partial:true op x) y let app3 op x y z = app1 (app1 ~partial:true (app1 ~partial:true op x) y) z -let appN op xs = List.fold xs ~init:op ~f:app1 + +let appN op args = + ( match op with + | Add {typ} -> simp_add typ args + | Mul {typ} -> simp_mul typ args + | _ -> AppN {op; args} ) + |> check invariant + +let check1 op typ x = + type_check typ x ; + op typ x |> check invariant + +let check2 op typ x y = + type_check2 typ x y ; + op typ x y |> check invariant + let splat ~byt ~siz = app2 Splat byt siz let memory ~siz ~arr = app2 Memory siz arr let concat = app2 Concat @@ -1046,17 +1155,10 @@ let ult = app2 Ult let ule = app2 Ule let ord = app2 Ord let uno = app2 Uno - -let add typ x y = - type_check2 typ x y ; - app2 (Add {typ}) x y - -let sub typ x y = type_check2 typ x y ; simp_sub typ x y - -let mul typ x y = - type_check2 typ x y ; - app2 (Mul {typ}) x y - +let neg = check1 simp_negate +let add = check2 simp_add2 +let sub = check2 simp_sub +let mul = check2 simp_mul2 let div = app2 Div let udiv = app2 Udiv let rem = app2 Rem @@ -1064,12 +1166,12 @@ let urem = app2 Urem let and_ = app2 And let or_ = app2 Or let xor = app2 Xor -let not_ typ x = type_check typ x ; simp_not typ x +let not_ = check1 simp_not let shl = app2 Shl let lshr = app2 Lshr let ashr = app2 Ashr let conditional ~cnd ~thn ~els = app3 Conditional cnd thn els -let record elts = appN Record elts |> check invariant +let record elts = List.fold ~f:app1 ~init:Record elts let select ~rcd ~idx = app2 Select rcd idx let update ~rcd ~elt ~idx = app3 Update rcd elt idx @@ -1098,24 +1200,7 @@ let struct_rec key = let convert ?(signed = false) ~dst ~src exp = app1 (Convert {signed; dst; src}) exp -(** Access *) - -let fold e ~init:z ~f = - match e with - | App {op; arg; _} -> - let z = f z op in - let z = f z arg in - z - | _ -> z - -let fold_map e ~init:z ~f = - match e with - | App {op; arg} -> - let z, op' = f z op in - let z, arg' = f z arg in - if op' == op && arg' == arg then (z, e) - else (z, app1 ~partial:true op' arg') - | _ -> (z, e) +(** Transform *) let map e ~f = match e with @@ -1123,9 +1208,28 @@ let map e ~f = let op' = f op in let arg' = f arg in if op' == op && arg' == arg then e else app1 ~partial:true op' arg' + | AppN {op; args} -> + let op' = f op in + let args' = Mset.map ~f:(fun arg z -> (f arg, z)) args in + if op' == op && args' == args then e else appN op' args' | _ -> e -(** Update *) +let fold_map e ~init:s ~f = + match e with + | App {op; arg} -> + let s, op' = f s op in + let s, arg' = f s arg in + if op' == op && arg' == arg then (s, e) + else (s, app1 ~partial:true op' arg') + | AppN {op; args} -> + let s, op' = f s op in + let args', s = + Mset.fold_map args ~init:s ~f:(fun x z s -> + let s, x' = f s x in + (x', z, s) ) + in + if op' == op && args' == args then (s, e) else (s, appN op' args') + | _ -> (s, e) let rename e sub = let rec rename_ e sub = @@ -1135,6 +1239,55 @@ let rename e sub = in rename_ e sub |> check (invariant ~partial:true) +(** Destruct *) + +let offset e = + ( match e with + | AppN {op= Add {typ}; args} -> + let offset = Mset.count args (integer Z.one typ) in + if Z.is_zero offset then None else Some (offset, typ) + | _ -> None ) + |> check (function + | Some (k, _) -> assert (not (Z.is_zero k)) + | None -> () ) + +let base e = + ( match e with + | AppN {op= Add {typ} as op; args} -> ( + let args = Mset.remove args (integer Z.one typ) in + match Mset.length args with + | 0 -> integer Z.zero typ + | 1 -> ( + match Mset.min_elt args with + | Some (arg, z) when Z.is_one z -> arg + | _ -> AppN {op; args} ) + | _ -> AppN {op; args} ) + | _ -> e ) + |> check (invariant ~partial:true) + +let base_offset e = + ( match e with + | AppN {op= Add {typ} as op; args} -> ( + match Mset.count_and_remove args (integer Z.one typ) with + | Some (offset, args) -> + let base = + match Mset.length args with + | 0 -> integer Z.zero typ + | 1 -> ( + match Mset.min_elt args with + | Some (arg, z) when Z.is_one z -> arg + | _ -> AppN {op; args} ) + | _ -> AppN {op; args} + in + Some (base, offset, typ) + | None -> None ) + | _ -> None ) + |> check (function + | Some (b, k, _) -> + invariant b ; + assert (not (Z.is_zero k)) + | None -> () ) + (** Query *) let is_true = function @@ -1145,7 +1298,11 @@ let is_false = function | Integer {data; typ= Integer {bits= 1}} -> Z.is_false data | _ -> false +let is_simple = function App _ | AppN _ -> false | _ -> true + let rec is_constant = function | Var _ | Nondet _ -> false | App {op; arg} -> is_constant arg && is_constant op + | AppN {op; args} -> + Mset.for_all ~f:(fun arg _ -> is_constant arg) args && is_constant op | _ -> true diff --git a/sledge/src/llair/exp.mli b/sledge/src/llair/exp.mli index 91757ef1f..75c2121ba 100644 --- a/sledge/src/llair/exp.mli +++ b/sledge/src/llair/exp.mli @@ -21,9 +21,14 @@ treated as atomic since, as they are recursive, doing otherwise would require inductive reasoning. *) -type t = private +type comparator_witness + +type mset = (t, comparator_witness) Mset.t + +and t = private | App of {op: t; arg: t} (** Application of function symbol to argument, curried *) + | AppN of {op: t; args: mset} | Var of {id: int; name: string} (** Local variable / virtual register *) | Nondet of {msg: string} (** Anonymous local variable with arbitrary value, representing @@ -72,9 +77,9 @@ type t = private (** Convert between specified types, possibly with loss of information *) [@@deriving compare, hash, sexp] -type exp = t +val comparator : (t, comparator_witness) Comparator.t -include Comparator.S with type t := t +type exp = t val equal : t -> t -> bool val sort : t -> t -> t * t @@ -150,6 +155,7 @@ val ult : t -> t -> t val ule : t -> t -> t val ord : t -> t -> t val uno : t -> t -> t +val neg : Typ.t -> t -> t val add : Typ.t -> t -> t -> t val sub : Typ.t -> t -> t -> t val mul : Typ.t -> t -> t -> t @@ -182,16 +188,30 @@ val struct_rec : val convert : ?signed:bool -> dst:Typ.t -> src:Typ.t -> t -> t +(** Destruct *) + +val base_offset : t -> (t * Z.t * Typ.t) option +(** Decompose an addition of a constant "offset" to a "base" exp. *) + +val base : t -> t +(** Like [base_offset] but does not construct the "offset" exp. *) + +val offset : t -> (Z.t * Typ.t) option +(** Like [base_offset] but does not construct the "base" exp. *) + (** Access *) +val iter : t -> f:(t -> unit) -> unit val fold_vars : t -> init:'a -> f:('a -> Var.t -> 'a) -> 'a val fold_exps : t -> init:'a -> f:('a -> t -> 'a) -> 'a val fold : t -> init:'a -> f:('a -> t -> 'a) -> 'a -val fold_map : t -> init:'a -> f:('a -> t -> 'a * t) -> 'a * t -val map : t -> f:(t -> t) -> t +val for_all : t -> f:(t -> bool) -> bool +val exists : t -> f:(t -> bool) -> bool -(** Update *) +(** Transform *) +val map : t -> f:(t -> t) -> t +val fold_map : t -> init:'a -> f:('a -> t -> 'a * t) -> 'a * t val rename : t -> Var.Subst.t -> t (** Query *) @@ -199,4 +219,5 @@ val rename : t -> Var.Subst.t -> t val fv : t -> Var.Set.t val is_true : t -> bool val is_false : t -> bool +val is_simple : t -> bool val is_constant : t -> bool diff --git a/sledge/src/llair/exp_test.ml b/sledge/src/llair/exp_test.ml index e84c32fce..e98f958dc 100644 --- a/sledge/src/llair/exp_test.ml +++ b/sledge/src/llair/exp_test.ml @@ -7,8 +7,9 @@ let%test_module _ = ( module struct + (* let () = Trace.init ~margin:68 ~config:all () *) let () = Trace.init ~margin:68 ~config:none () - let pp = Format.printf "%t%a@." (fun _ -> Trace.flush ()) Exp.pp + let pp = Format.printf "@\n%a@." Exp.pp let char = Typ.integer ~bits:8 let ( ! ) i = Exp.integer (Z.of_int i) char let ( + ) = Exp.add char @@ -63,16 +64,6 @@ let%test_module _ = pp (!(-128) || !127) ; [%expect {| -1 |}] - let%test "monomial coefficient must be toplevel" = - match !7 * z * (!2 * y) with - | App {op= App {op= Mul _}; arg= Integer _} -> true - | _ -> false - - let%test "polynomial constant must be toplevel" = - match (!13 * z) + !42 + (!3 * y) with - | App {op= App {op= Add _}; arg= Integer _} -> true - | _ -> false - let%expect_test _ = pp (z + !42 + !13) ; [%expect {| (%z_2 + 55) |}] @@ -87,11 +78,11 @@ let%test_module _ = let%expect_test _ = pp (y * z * y) ; - [%expect {| (%y_1 × %y_1 × %z_2) |}] + [%expect {| (%y_1^2 × %z_2) |}] let%expect_test _ = pp ((!2 * z * z) + (!3 * z) + !4) ; - [%expect {| ((2 × %z_2 × %z_2) + (3 × %z_2) + 4) |}] + [%expect {| (2 × (%z_2^2) + 3 × %z_2 + 4) |}] let%expect_test _ = pp @@ -104,9 +95,9 @@ let%test_module _ = + (!9 * z * z * z) ) ; [%expect {| - ((7 × %y_1 × %y_1 × %z_2) + (8 × %y_1 × %z_2 × %z_2) - + (9 × %z_2 × %z_2 × %z_2) + (5 × %y_1 × %y_1) + (6 × %y_1 × %z_2) - + (4 × %z_2 × %z_2) + (3 × %y_1) + (2 × %z_2) + 1) |}] + (6 × (%y_1 × %z_2) + 8 × (%y_1 × %z_2^2) + 5 × (%y_1^2) + + 7 × (%y_1^2 × %z_2) + 4 × (%z_2^2) + 9 × (%z_2^3) + 3 × %y_1 + + 2 × %z_2 + 1) |}] let%expect_test _ = pp (!0 * z * y) ; @@ -118,15 +109,15 @@ let%test_module _ = let%expect_test _ = pp (!7 * z * (!2 * y)) ; - [%expect {| (14 × %y_1 × %z_2) |}] + [%expect {| (14 × (%y_1 × %z_2)) |}] let%expect_test _ = pp (!13 + (!42 * z)) ; - [%expect {| ((42 × %z_2) + 13) |}] + [%expect {| (42 × %z_2 + 13) |}] let%expect_test _ = pp ((!13 * z) + !42) ; - [%expect {| ((13 × %z_2) + 42) |}] + [%expect {| (13 × %z_2 + 42) |}] let%expect_test _ = pp ((!2 * z) - !3 + ((!(-2) * z) + !3)) ; @@ -134,32 +125,31 @@ let%test_module _ = let%expect_test _ = pp ((!3 * y) + (!13 * z) + !42) ; - [%expect {| ((3 × %y_1) + (13 × %z_2) + 42) |}] + [%expect {| (3 × %y_1 + 13 × %z_2 + 42) |}] let%expect_test _ = pp ((!13 * z) + !42 + (!3 * y)) ; - [%expect {| ((3 × %y_1) + (13 × %z_2) + 42) |}] + [%expect {| (3 × %y_1 + 13 × %z_2 + 42) |}] let%expect_test _ = pp ((!13 * z) + !42 + (!3 * y) + (!2 * z)) ; - [%expect {| ((3 × %y_1) + (15 × %z_2) + 42) |}] + [%expect {| (3 × %y_1 + 15 × %z_2 + 42) |}] let%expect_test _ = pp ((!13 * z) + !42 + (!3 * y) + (!(-13) * z)) ; - [%expect {| ((3 × %y_1) + 42) |}] + [%expect {| (3 × %y_1 + 42) |}] let%expect_test _ = pp (z + !42 + ((!3 * y) + (!(-1) * z))) ; - [%expect {| ((3 × %y_1) + 42) |}] + [%expect {| (3 × %y_1 + 42) |}] let%expect_test _ = pp (!(-1) * (z + (!(-1) * y))) ; - [%expect {| (%y_1 + (-1 × %z_2)) |}] + [%expect {| (%y_1 + -1 × %z_2) |}] let%expect_test _ = pp (((!3 * y) + !2) * (!4 + (!5 * z))) ; - [%expect - {| ((15 × %y_1 × %z_2) + (12 × %y_1) + (10 × %z_2) + 8) |}] + [%expect {| (15 × (%y_1 × %z_2) + 12 × %y_1 + 10 × %z_2 + 8) |}] let%expect_test _ = pp (((!2 * z) - !3 + ((!(-2) * z) + !3)) * (!4 + (!5 * z))) ; @@ -167,11 +157,11 @@ let%test_module _ = let%expect_test _ = pp ((!13 * z) + !42 - ((!3 * y) + (!13 * z))) ; - [%expect {| ((-3 × %y_1) + 42) |}] + [%expect {| (-3 × %y_1 + 42) |}] let%expect_test _ = pp (z = y) ; - [%expect {| (%y_1 = %z_2) |}] + [%expect {| (%z_2 = %y_1) |}] let%expect_test _ = pp (z = z) ; @@ -203,55 +193,55 @@ let%test_module _ = let%expect_test _ = pp (y - (!(-3) * y) + !4) ; - [%expect {| ((4 × %y_1) + 4) |}] + [%expect {| (4 × %y_1 + 4) |}] let%expect_test _ = pp ((!(-3) * y) + !4 - y) ; - [%expect {| ((-4 × %y_1) + 4) |}] + [%expect {| (-4 × %y_1 + 4) |}] let%expect_test _ = pp (y = (!(-3) * y) + !4) ; - [%expect {| ((4 × %y_1) = 4) |}] + [%expect {| (%y_1 = (-3 × %y_1 + 4)) |}] let%expect_test _ = pp ((!(-3) * y) + !4 = y) ; - [%expect {| ((4 × %y_1) = 4) |}] + [%expect {| ((-3 × %y_1 + 4) = %y_1) |}] let%expect_test _ = pp (Exp.sub Typ.bool (Exp.bool true) (z = !4)) ; - [%expect {| ((%z_2 = 4) + -1) |}] + [%expect {| (-1 × (%z_2 = 4) + -1) |}] let%expect_test _ = pp (Exp.add Typ.bool (Exp.bool true) (z = !4) = (z = !4)) ; - [%expect {| 0 |}] + [%expect {| (((%z_2 = 4) + -1) = (%z_2 = 4)) |}] let%expect_test _ = pp ((!13 * z) + !42 = (!3 * y) + (!13 * z)) ; - [%expect {| ((3 × %y_1) = 42) |}] + [%expect {| ((13 × %z_2 + 42) = (3 × %y_1 + 13 × %z_2)) |}] let%expect_test _ = pp ((!13 * z) + !(-42) = (!3 * y) + (!13 * z)) ; - [%expect {| ((-3 × %y_1) = 42) |}] + [%expect {| ((13 × %z_2 + -42) = (3 × %y_1 + 13 × %z_2)) |}] let%expect_test _ = pp ((!13 * z) + !42 = (!(-3) * y) + (!13 * z)) ; - [%expect {| ((-3 × %y_1) = 42) |}] + [%expect {| ((13 × %z_2 + 42) = (-3 × %y_1 + 13 × %z_2)) |}] let%expect_test _ = pp ((!10 * z) + !42 = (!(-3) * y) + (!13 * z)) ; - [%expect {| (((-3 × %y_1) + (3 × %z_2)) = 42) |}] + [%expect {| ((10 × %z_2 + 42) = (-3 × %y_1 + 13 × %z_2)) |}] let%expect_test _ = pp ~~((!13 * z) + !(-42) != (!3 * y) + (!13 * z)) ; - [%expect {| ((-3 × %y_1) = 42) |}] + [%expect {| ((13 × %z_2 + -42) = (3 × %y_1 + 13 × %z_2)) |}] let%expect_test _ = pp ~~(y > !2 && z <= !3) ; - [%expect {| ((%z_2 > 3) || (%y_1 <= 2)) |}] + [%expect {| ((%y_1 <= 2) || (%z_2 > 3)) |}] let%expect_test _ = pp ~~(y >= !2 || z < !3) ; - [%expect {| ((%z_2 >= 3) && (%y_1 < 2)) |}] + [%expect {| ((%y_1 < 2) && (%z_2 >= 3)) |}] let%expect_test _ = pp Exp.(eq z null) ; @@ -261,7 +251,7 @@ let%test_module _ = {| (%z_2 = null) - (%z_2 = null) + (null = %z_2) - (%z_2 = null) |}] + (null = %z_2) |}] end ) diff --git a/sledge/src/symbheap/congruence.ml b/sledge/src/symbheap/congruence.ml index 170a0d268..c7e974b0f 100644 --- a/sledge/src/symbheap/congruence.ml +++ b/sledge/src/symbheap/congruence.ml @@ -7,16 +7,65 @@ (** Congruence closure with integer offsets *) -(* For background, see: - - Robert Nieuwenhuis, Albert Oliveras: Fast congruence closure and - extensions. Inf. Comput. 205(4): 557-580 (2007) - - and, for a more detailed correctness proof of the case without integer - offsets, see section 5 of: - - Aleksandar Nanevski, Viktor Vafeiadis, Josh Berdine: Structuring the - verification of heap-manipulating programs. POPL 2010: 261-274 *) +(** For background, see: + + Robert Nieuwenhuis, Albert Oliveras: Fast congruence closure and + extensions. Inf. Comput. 205(4): 557-580 (2007) + + and, for a more detailed correctness proof of the case without integer + offsets, see section 5 of: + + Aleksandar Nanevski, Viktor Vafeiadis, Josh Berdine: Structuring the + verification of heap-manipulating programs. POPL 2010: 261-274 *) + +(** Lazy flattening: + + The congruence closure data structure is used to lazily flatten + expressions. Flattening expressions gives each compound expression (e.g. + an application) a "name", which is treated as an atomic symbol. In the + background papers, fresh symbols are introduced to name compound + expressions in a pre-processing pass, but here we do not a priori know + the carrier (set of all expressions equations might relate). Instead, we + use the expression itself as its "name" and use the representative map + to record this naming. That is, if [f(a+i)] is in the domain of the + representative map, then "f(a+i)" is the name of the compound expression + [f(a+i)]. If [f(a+i)] is not in the domain of the representative map, + then it is not yet in the "carrier" of the relation. Adding it to the + carrier, which logically amounts to adding the equation [f(a+i) = + "f(a+i)"], extends the representative map, as well as the lookup map and + use lists, after which [f(a+i)] can be used as if it was a simple symbol + name for the compound expression. + + Note that merging a compound equation of the form [f(a+i)+j = b+k] + results in naming [f(a+i)] and then merging the simple equation + ["f(a+i)"+j = b+k] normalized to ["f(a+i)" = b+(k-j)]. In particular, + every equation is either of the form ["f(a+i)" = f(a+i)] or of the form + [a = b+i], but not of the form [f(a+i) = b+j]. + + By the same reasoning, the range of the lookup table does not need + offsets, as every exp in the range of the lookup table will be the name + of a compound exp. + + A consequence of lazy flattening is that the equations stored in the + lookup table, use lists, and pending equation list in the background + papers are here all of the form [f(a+i) = "f(a+i)"], and hence are + represented by the application expression itself. + + Sparse carrier: + + For symbols, that is expressions that are not compound [App]lications, + there are no cooperative invariants between components of the data + structure that need to be established. So adding a symbol to the carrier + would amount to adding an identity association to the representatives + map. Since we need to use a map instead of an array anyhow, this can be + represented sparsely by omitting identity associations in the + representative map. Note that identity associations for compound + expressions are still needed to record which compound expressions have + been added to the carrier. + + Notation: + + - often use identifiers such as [a'] for the representative of [a] *) (** set of exps representing congruence classes *) module Cls = struct @@ -33,8 +82,10 @@ module Cls = struct let remove_exn = List.remove_exn let union = List.rev_append let fold_map = List.fold_map + let iter = List.iter let is_empty = List.is_empty let length = List.length + let map = List.map_preserving_phys_equal let mem = List.mem ~equal:Exp.equal end @@ -44,70 +95,71 @@ module Use = struct let equal = [%compare.equal: t] - let pp fs use = - Format.fprintf fs "@[{@[%a@]}@]" (List.pp ",@ " Exp.pp) use + let pp fs uses = + Format.fprintf fs "@[{@[%a@]}@]" (List.pp ",@ " Exp.pp) uses let empty = [] - let singleton exp = [exp] - let add use exp = exp :: use + let singleton use = [use] + let add uses use = use :: uses let union = List.rev_append let fold = List.fold + let iter = List.iter + let exists = List.exists let is_empty = List.is_empty + let map = List.map_preserving_phys_equal end type 'a exp_map = 'a Map.M(Exp).t [@@deriving compare, sexp] -let empty_map = Map.empty (module Exp) - +(** see also [invariant] *) type t = - { sat: bool (** [false] if constraints are inconsistent *) + { sat: bool (** [false] only if constraints are inconsistent *) ; rep: Exp.t exp_map - (** map [a] to [a'+k], indicating that [a=a'+k] holds, and that [a'] - (without the offset [k]) is the 'rep(resentative)' of [a] *) + (** map [a] to [a'+k], indicating that [a = a'+k] holds, and that + [a'] (without the offset [k]) is the 'rep(resentative)' of [a] *) ; lkp: Exp.t exp_map - (** inverse of mapping rep over sub-expressions: map [f'(a'+i)] to - [f(a+j)+k], an (offsetted) app(lication expression) in the - relation which normalizes to one in the 'equivalence modulo - offset' class of [f'(a'+i)], indicating that [f'(a'+i) = - f(a+j)+k] holds, for some [k] where [rep f = f'] and [rep a = - a'+(i-j)] *) + (** map [f'(a'+i)] to [f(a+j)], indicating that [f'(a'+i) = f(a+j)] + holds, where [f(a+j)] is in the carrier *) ; cls: Cls.t exp_map (** inverse rep: map each rep [a'] to all the [a+k] in its class, - i.e., [cls a' = {a+k | rep a = a'+(-k)}] *) + i.e., [cls a' = {a+(-k) | rep a = a'+k}] *) ; use: Use.t exp_map (** super-expression relation for representatives: map each - representative [a'] of [a] to the application expressions in the - relation where [a] (possibly + an offset) appears as an - immediate sub-expression *) + representative [a'] of [a] to the compound expressions in the + carrier where [a] (possibly + an offset) appears as an immediate + sub-expression *) ; pnd: (Exp.t * Exp.t) list - (** equations to be added to the relation, to enable delaying adding - equations discovered while invariants are temporarily broken *) - } + (** equations of the form [a+i = b+j], where [a] and [b] are in the + carrier, to be added to the relation by merging the classes of + [a] and [b] *) } [@@deriving compare, sexp] -(** The expressions in the range of [lkp] and [use], as well as those in - [pnd], are 'in the relation' in the sense that there is some constraint - involving them, and in practice are expressions which have been passed - to [merge] as opposed to having been constructed internally. *) +(** Pretty-printing *) let pp_eq fs (e, f) = Format.fprintf fs "@[%a = %a@]" Exp.pp e Exp.pp f let pp fs {sat; rep; lkp; cls; use; pnd} = let pp_alist pp_k pp_v fs alist = let pp_assoc fs (k, v) = - Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_k k pp_v v + Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_k k pp_v (k, v) in Format.fprintf fs "[@[%a@]]" (List.pp ";@ " pp_assoc) alist in - let pp_pnd fs pnd = - if not (List.is_empty pnd) then - Format.fprintf fs ";@ pnd= [@[%a@]];" (List.pp ";@ " pp_eq) pnd - in + let pp_exp_v fs (k, v) = if not (Exp.equal k v) then Exp.pp fs v in + let pp_cls_v fs (_, v) = Cls.pp fs v in + let pp_use_v fs (_, v) = Use.pp fs v in Format.fprintf fs "@[{@[sat= %b;@ rep= %a;@ lkp= %a;@ cls= %a;@ use= %a%a@]}@]" sat - (pp_alist Exp.pp Exp.pp) (Map.to_alist rep) (pp_alist Exp.pp Exp.pp) - (Map.to_alist lkp) (pp_alist Exp.pp Cls.pp) (Map.to_alist cls) - (pp_alist Exp.pp Use.pp) (Map.to_alist use) pp_pnd pnd + (pp_alist Exp.pp pp_exp_v) + (Map.to_alist rep) + (pp_alist Exp.pp pp_exp_v) + (Map.to_alist lkp) + (pp_alist Exp.pp pp_cls_v) + (Map.to_alist cls) + (pp_alist Exp.pp pp_use_v) + (Map.to_alist use) + (List.pp ~pre:";@ pnd= [@[" ";@ " pp_eq ~suf:"@]];") + pnd let pp_classes fs {cls} = List.pp "@ @<2>∧ " @@ -138,32 +190,41 @@ let pp_diff fs (r, s) = let pp_sdiff_exps fs (c, d) = pp_sdiff_list "" Exp.pp Exp.compare fs (c, d) in - let pp_sdiff_elt pp_val pp_sdiff_val fs = function + let pp_sdiff_uses fs (c, d) = + pp_sdiff_list "" Exp.pp Exp.compare fs (c, d) + in + let pp_sdiff_elt pp_key pp_val pp_sdiff_val fs = function | k, `Left v -> - Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" Exp.pp k pp_val v + Format.fprintf fs "-- [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v | k, `Right v -> - Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" Exp.pp k pp_val v + Format.fprintf fs "++ [@[%a@ @<2>↦ %a@]]" pp_key k pp_val v | k, `Unequal vv -> - Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" Exp.pp k pp_sdiff_val vv + Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_key k pp_sdiff_val vv in let pp_sdiff_exp_map = let pp_sdiff_exp fs (u, v) = Format.fprintf fs "-- %a ++ %a" Exp.pp u Exp.pp v in - pp_sdiff_map (pp_sdiff_elt Exp.pp pp_sdiff_exp) Exp.equal + pp_sdiff_map (pp_sdiff_elt Exp.pp Exp.pp pp_sdiff_exp) Exp.equal + in + let pp_sdiff_app_map = + let pp_sdiff_app fs (u, v) = + Format.fprintf fs "-- %a ++ %a" Exp.pp u Exp.pp v + in + pp_sdiff_map (pp_sdiff_elt Exp.pp Exp.pp pp_sdiff_app) Exp.equal in let pp_sat fs = if not (Bool.equal r.sat s.sat) then Format.fprintf fs "sat= @[-- %b@ ++ %b@];@ " r.sat s.sat in let pp_rep fs = pp_sdiff_exp_map "rep" fs r.rep s.rep in - let pp_lkp fs = pp_sdiff_exp_map "lkp" fs r.lkp s.lkp in + let pp_lkp fs = pp_sdiff_app_map "lkp" fs r.lkp s.lkp in let pp_cls fs = - let pp_sdiff_cls = pp_sdiff_elt Cls.pp pp_sdiff_exps in + let pp_sdiff_cls = pp_sdiff_elt Exp.pp Cls.pp pp_sdiff_exps in pp_sdiff_map pp_sdiff_cls Cls.equal "cls" fs r.cls s.cls in let pp_use fs = - let pp_sdiff_use = pp_sdiff_elt Use.pp pp_sdiff_exps in + let pp_sdiff_use = pp_sdiff_elt Exp.pp Use.pp pp_sdiff_uses in pp_sdiff_map pp_sdiff_use Use.equal "use" fs r.use s.use in let pp_pnd fs = @@ -172,46 +233,177 @@ let pp_diff fs (r, s) = Format.fprintf fs "@[{@[%t%t%t%t%t%t@]}@]" pp_sat pp_rep pp_lkp pp_cls pp_use pp_pnd -let invariant r = - Invariant.invariant [%here] r [%sexp_of: t] - @@ fun () -> - Map.iteri r.rep ~f:(fun ~key:e ~data:e' -> assert (not (Exp.equal e e'))) ; - Map.iteri r.cls ~f:(fun ~key:e' ~data:cls -> assert (Cls.mem cls e')) ; - Map.iteri r.use ~f:(fun ~key:_ ~data:use -> assert (not (Use.is_empty use)) - ) +(** Auxiliary functions for manipulating "base plus offset" expressions *) -(* Auxiliary functions for manipulating "base plus offset" expressions *) - -let map_sum e ~f = - match e with - | Exp.App {op= App {op= Add {typ}; arg= a}; arg= i} -> +(** solve a+i = b for a, yielding a = b-i *) +let solve_for_base ai b = + match Exp.base_offset ai with + | Some (a, i, typ) -> (a, Exp.sub typ b (Exp.integer i typ)) + | None -> (ai, b) + +(** subtract offset from both sides of equation a+i = b, yielding b-i *) +let subtract_offset ai b = + match Exp.offset ai with + | Some (i, typ) -> Exp.sub typ b (Exp.integer i typ) + | None -> b + +(** [map_base ~f a+i] is [f(a) + i] and [map_base ~f a] is [f(a)] *) +let map_base ai ~f = + match Exp.base_offset ai with + | Some (a, i, typ) -> let a' = f a in - if a' == a then e else Exp.add typ a' i - | a -> f a + if a' == a then ai else Exp.add typ a' (Exp.integer i typ) + | None -> f ai -let fold_sum e ~init ~f = - match e with - | Exp.App {op= App {op= Add _; arg= a}; arg= Integer _} -> f init a - | a -> f init a +(** [norm_base r a] is [a'+k] where [r] implies [a = a'+k] and [a'] is a + rep, requires [a] to not have any offset and be in the carrier *) +let norm_base r e = + assert (Option.is_none (Exp.offset e)) ; + try Map.find_exn r.rep e with Caml.Not_found -> + assert (Exp.is_simple e) ; + e -let base_of = function - | Exp.App {op= App {op= Add _; arg= a}; arg= Integer _} -> a - | a -> a +(** [norm r a+i] is [a'+k] where [r] implies [a+i = a'+k] and [a'] is a rep, + requires [a] to be in the carrier *) +let norm r e = map_base ~f:(norm_base r) e -(** solve a+i = b for a, yielding a = b-i *) -let solve_for_base ai b = - match ai with - | Exp.App {op= App {op= Add {typ}; arg= a}; arg= i} -> (a, Exp.sub typ b i) - | _ -> (ai, b) +(** test membership in carrier, strictly in the sense that an exp with an + offset is not in the carrier even when its base is *) +let in_car r e = Exp.is_simple e || Map.mem r.rep e -(** [norm r a+i] = [a'+k] where [r] implies [a+i=a'+k] and [a'] is a rep *) -let norm r e = - map_sum e ~f:(fun a -> try Map.find_exn r.rep a with Caml.Not_found -> a) +(** test if an exp is a representative, requires exp to have no offset *) +let is_rep r e = Exp.equal e (norm_base r e) -(* Core closure operations *) +let pre_invariant r = + Invariant.invariant [%here] r [%sexp_of: t] + @@ fun () -> + Map.iteri r.rep ~f:(fun ~key:a ~data:a'k -> + (* carrier is stored without offsets *) + assert (Option.is_none (Exp.offset a)) ; + (* carrier is closed under sub-expressions *) + Exp.iter a ~f:(fun bj -> + assert ( + in_car r (Exp.base bj) + || Trace.report "@[subexp %a of %a not in carrier of@ %a@]" + Exp.pp bj Exp.pp a pp r ) ) ; + let a', a_k = solve_for_base a'k a in + (* carrier is closed under rep *) + assert (in_car r a') ; + if Exp.is_simple a' then + (* rep is sparse for symbols *) + assert ( + (not (Map.mem r.rep a')) + || Trace.report + "no symbol rep should be in rep domain: %a @<2>↦ %a@\n%a" + Exp.pp a Exp.pp a' pp r ) + else + (* rep is idempotent for applications *) + assert ( + is_rep r a' + || Trace.report + "every app rep should be its own rep: %a @<2>↦ %a" Exp.pp a + Exp.pp a' ) ; + match Map.find r.cls a' with + | None -> + (* every rep in dom of cls *) + assert ( + Trace.report "rep not in dom of cls: %a@\n%a" Exp.pp a' pp r ) + | Some a_cls -> + (* every exp is in class of its rep *) + assert ( + (* rep a = a'+k so expect a-k in cls a' *) + Cls.mem a_cls a_k + || Trace.report "%a = %a by rep but %a not in cls@\n%a" Exp.pp a + Exp.pp a'k Exp.pp a_k pp r ) ) ; + Map.iteri r.cls ~f:(fun ~key:a' ~data:a_cls -> + (* domain of cls are reps *) + assert (is_rep r a') ; + (* cls contained in inverse of rep *) + Cls.iter a_cls ~f:(fun ak -> + let a, a'_k = solve_for_base ak a' in + assert ( + in_car r a + || Trace.report "%a in cls of %a but not in carrier" Exp.pp a + Exp.pp a' ) ; + let a'' = norm_base r a in + assert ( + (* a' = a+k in cls so expect rep a = a'-k *) + Exp.equal a'' a'_k + || Trace.report "%a = %a by cls but @<2>≠ %a by rep" Exp.pp a' + Exp.pp ak Exp.pp a'' ) ) ) ; + Map.iteri r.use ~f:(fun ~key:a' ~data:a_use -> + assert ( + (not (Use.is_empty a_use)) + || Trace.report "empty use list should not have been added" ) ; + Use.iter a_use ~f:(fun u -> + (* uses are applications *) + assert (not (Exp.is_simple u)) ; + (* uses have no offsets *) + assert (Option.is_none (Exp.offset u)) ; + (* subexps of uses in carrier *) + Exp.iter u ~f:(fun bj -> assert (in_car r (Exp.base bj))) ; + (* every rep is a subexp-modulo-rep of each of its uses *) + assert ( + Exp.exists u ~f:(fun bj -> Exp.equal a' (Exp.base (norm r bj))) + || Trace.report + "rep %a has use %a, but is not the rep of any immediate \ + subexp of the use" + Exp.pp a' Exp.pp u ) ; + (* every use has a corresponding entry in lkp... *) + let v = + try Map.find_exn r.lkp (Exp.map ~f:(norm r) u) + with Caml.Not_found -> + fail "no lkp entry for use %a of %a" Exp.pp u Exp.pp a' + in + (* ...which is (eventually) provably equal *) + if List.is_empty r.pnd then + assert (Exp.equal (norm_base r u) (norm_base r v)) ) ) ; + Map.iteri r.lkp ~f:(fun ~key:a ~data:c -> + (* subexps of domain of lkp in carrier *) + Exp.iter a ~f:(fun bj -> assert (in_car r (Exp.base bj))) ; + (* range of lkp are applications in carrier *) + assert (in_car r c) ; + (* there may be stale entries in lkp whose subexps are no longer reps, + which will therefore never be used, and hence are unconstrained *) + if Exp.equal a (Exp.map ~f:(norm r) a) then ( + let c_' = Exp.map ~f:(norm r) c in + (* lkp contains equalities provable modulo normalizing sub-exps *) + assert ( + Exp.equal a c_' + || Trace.report "%a sub-normalizes to %a @<2>≠ %a" Exp.pp c + Exp.pp c_' Exp.pp a ) ; + let c' = norm_base r c in + Exp.iter a ~f:(fun bj -> + (* every subexp of an app in domain of lkp has an associated use *) + let b' = Exp.base (norm r bj) in + let b_use = + try Map.find_exn r.use b' with Caml.Not_found -> + fail "no use list for subexp %a of lkp key %a" Exp.pp bj + Exp.pp a + in + assert ( + Use.exists b_use ~f:(fun u -> + Exp.equal a (Exp.map ~f:(norm r) u) + && Exp.equal c' (norm_base r u) ) + || Trace.report + "no corresponding use for subexp %a of lkp key %a" Exp.pp + bj Exp.pp a ) ) ) ) ; + List.iter r.pnd ~f:(fun (ai, bj) -> + assert (in_car r (Exp.base ai)) ; + assert (in_car r (Exp.base bj)) ) + +let invariant r = + Invariant.invariant [%here] r [%sexp_of: t] + @@ fun () -> + pre_invariant r ; + assert (List.is_empty r.pnd) + +(** Core closure operations *) type prefer = Exp.t -> over:Exp.t -> int +let empty_map = Map.empty (module Exp) + let true_ = { sat= true ; rep= empty_map @@ -219,42 +411,9 @@ let true_ = ; cls= empty_map ; use= empty_map ; pnd= [] } + |> check invariant -let false_ = {true_ with sat= false} - -(** Add app exps (and sub-exps) to the relation. This populates the [lkp] - and [use] maps, treating an exp [e] of form [f(a)] as an equation - between the app [f(a)] and the 'symbol' [e]. This has the effect of - using [e] as a 'name' of the app [f(a)], rather than using an explicit - 'flattening' transformation introducing new symbols for each - application. *) -let rec extend r e = - fold_sum e ~init:r ~f:(fun r -> function - | App _ as fa -> - let r, fa' = - Exp.fold_map fa ~init:r ~f:(fun r b -> - let r, c = extend r b in - (r, norm r c) ) - in - Map.find_or_add r.lkp fa' - ~if_found:(fun d -> - let r = {r with pnd= (e, d) :: r.pnd} in - (r, d) ) - ~default:e - ~if_added:(fun lkp -> - let use = - Exp.fold fa' ~init:r.use ~f:(fun use b' -> - if Exp.is_constant b' then use - else - Map.update use b' ~f:(function - | Some b_use -> Use.add b_use fa - | None -> Use.singleton fa ) ) - in - let r = {r with lkp; use} in - (r, e) ) - | _ -> (r, e) ) - -exception Unsat +let false_ = {true_ with sat= false} |> check invariant (** Add an equation [b+j] = [a+i] using [a] as the new rep. This removes [b] from the [cls] and [use] maps, as it is no longer a rep. The [rep] map @@ -267,9 +426,12 @@ let add_directed_equation r0 ~exp:bj ~rep:ai = [%Trace.call fun {pf} -> pf "@[%a@ %a@]@ %a" Exp.pp bj Exp.pp ai pp r0] ; let r = r0 in - let a = base_of ai in (* b+j = a+i so b = a+i-j *) - let b, ai_j = solve_for_base bj ai in + let b, aij = solve_for_base bj ai in + assert ((not (in_car r b)) || is_rep r b) ; + (* compute a from aij in case ai is an int and j is a non-0 offset *) + let a = Exp.base aij in + assert ((not (in_car r a)) || is_rep r a) ; let b_cls, cls = try Map.find_and_remove_exn r.cls b with Caml.Not_found -> (Cls.singleton b, r.cls) @@ -278,32 +440,29 @@ let add_directed_equation r0 ~exp:bj ~rep:ai = try Map.find_and_remove_exn r.use b with Caml.Not_found -> (Use.empty, r.use) in - let rep, a_cls_delta = - Cls.fold_map b_cls ~init:r.rep ~f:(fun rep ck -> + let r, a_cls_delta = + Cls.fold_map b_cls ~init:r ~f:(fun r ck -> (* c+k = b = a+i-j so c = a+i-j-k *) - let c, ai_j_k = solve_for_base ck ai_j in - if Exp.is_false (Exp.eq c ai_j_k) then raise Unsat ; - let rep = Map.set rep ~key:c ~data:ai_j_k in - (* a+i-j = c+k so a = c+k+j-i *) - let _, ckj_i = solve_for_base ai_j ck in - (rep, ckj_i) ) - in - let cls = - Map.update cls a ~f:(function - | Some a_cls -> Cls.union a_cls_delta a_cls - | None -> Cls.add a_cls_delta a ) + let c, aijk = solve_for_base ck aij in + let r = + if Exp.equal a c || Exp.is_false (Exp.eq c aijk) then false_ + else if Exp.is_simple c && Exp.equal c aijk then r + else {r with rep= Map.set r.rep ~key:c ~data:aijk} + in + (* a+i-j-k = c so a = c-i+j+k *) + let cijk = subtract_offset aijk c in + (r, cijk) ) in - let r = {r with rep; cls; use} in let r, a_use_delta = Use.fold b_use ~init:(r, Use.empty) ~f:(fun (r, a_use_delta) u -> - let u' = Exp.map ~f:(norm r) u in - Map.find_or_add r.lkp u' + let u_' = Exp.map ~f:(norm r) u in + Map.find_or_add r.lkp u_' ~if_found:(fun v -> let r = {r with pnd= (u, v) :: r.pnd} in - (* no need to add u to use a since u is an app already in r - (since u' found in r.lkp) that is equal to v, so will be in - use of u's subexps, and u = v is added to pnd so propagate - will combine them later *) + (* no need to add u to use a since u is already in r (since u_' + found in r.lkp) that is equal to v, so will be in use of u_'s + subexps, and u = v is added to pnd so propagate will combine + them later *) (r, a_use_delta) ) ~default:u ~if_added:(fun lkp -> @@ -311,6 +470,11 @@ let add_directed_equation r0 ~exp:bj ~rep:ai = let a_use_delta = Use.add a_use_delta u in (r, a_use_delta) ) ) in + let cls = + Map.update cls a ~f:(function + | Some a_cls -> Cls.union a_cls_delta a_cls + | None -> Cls.add a_cls_delta a ) + in let use = if Use.is_empty a_use_delta then use else @@ -318,105 +482,136 @@ let add_directed_equation r0 ~exp:bj ~rep:ai = | Some a_use -> Use.union a_use_delta a_use | None -> a_use_delta ) in - let r = {r with use} in - r |> check invariant + let r = if not r.sat then false_ else {r with cls; use} in + r |> - [%Trace.retn fun {pf} r -> pf "%a" pp_diff (r0, r)] + [%Trace.retn fun {pf} r -> + pf "%a" pp_diff (r0, r) ; + pre_invariant r] -(** Close the relation with the pending equations. *) -let rec propagate_ ?prefer r = - [%Trace.call fun {pf} -> pf "%a" pp r] +let prefer_rep ?prefer r e ~over:d = + [%Trace.call fun {pf} -> pf "@[%a@ %a@]@ %a" Exp.pp d Exp.pp e pp r] ; - ( match r.pnd with - | [] -> r - | (d, e) :: pnd -> - let d' = norm r d in - let e' = norm r e in - let r = {r with pnd} in - let r = - if Exp.equal (base_of d') (base_of e') then - if Exp.equal d' e' then r else {r with sat= false; pnd= []} - else - let prefer_e_over_d = - match (Exp.is_constant d, Exp.is_constant e) with - | true, false -> -1 - | false, true -> 1 - | _ -> ( - match prefer with - | Some prefer -> prefer e' ~over:d' - | None -> 0 ) - in - match prefer_e_over_d with - | n when n < 0 -> add_directed_equation r ~exp:e' ~rep:d' - | p when p > 0 -> add_directed_equation r ~exp:d' ~rep:e' - | _ -> - let len_d = - try Cls.length (Map.find_exn r.cls d') - with Caml.Not_found -> 1 - in - let len_e = - try Cls.length (Map.find_exn r.cls e') - with Caml.Not_found -> 1 - in - if len_d > len_e then add_directed_equation r ~exp:e' ~rep:d' - else add_directed_equation r ~exp:d' ~rep:e' + let prefer_e_over_d = + match (Exp.is_constant d, Exp.is_constant e) with + | true, false -> -1 + | false, true -> 1 + | _ -> ( + match prefer with Some prefer -> prefer e ~over:d | None -> 0 ) + in + ( match prefer_e_over_d with + | n when n < 0 -> false + | p when p > 0 -> true + | _ -> + let len_e = + try Cls.length (Map.find_exn r.cls e) with Caml.Not_found -> 1 + in + let len_d = + try Cls.length (Map.find_exn r.cls d) with Caml.Not_found -> 1 in - propagate_ ?prefer r ) + len_e >= len_d ) |> - [%Trace.retn fun {pf} r' -> pf "%a" pp_diff (r, r')] + [%Trace.retn fun {pf} -> pf "%b"] + +let choose_preferred ?prefer r d e = + if prefer_rep ?prefer r e ~over:d then (d, e) else (e, d) + +let add_equation ?prefer r d e = + let d = norm r d in + let e = norm r e in + if (not r.sat) || Exp.equal d e then r + else ( + [%Trace.call fun {pf} -> pf "@[%a@ %a@]@ %a" Exp.pp d Exp.pp e pp r] + ; + let exp, rep = choose_preferred ?prefer r d e in + add_directed_equation r ~exp ~rep + |> + [%Trace.retn fun {pf} r' -> pf "%a" pp_diff (r, r')] ) + +(** normalize, and add base to carrier if needed *) +let rec norm_extend r ek = + [%Trace.call fun {pf} -> pf "%a@ %a" Exp.pp ek pp r] + ; + let e = Exp.base ek in + ( if Exp.is_simple e then (r, norm r ek) + else + Map.find_or_add r.rep e + ~if_found:(fun e' -> + match Exp.offset ek with + | Some (k, typ) -> (r, Exp.add typ e' (Exp.integer k typ)) + | None -> (r, e') ) + ~default:e + ~if_added:(fun rep -> + let cls = Map.set r.cls ~key:e ~data:(Cls.singleton e) in + let r = {r with rep; cls} in + let r, e_' = Exp.fold_map ~f:norm_extend ~init:r e in + Map.find_or_add r.lkp e_' + ~if_found:(fun d -> + let pnd = (e, d) :: r.pnd in + let d' = norm_base r d in + ({r with rep; pnd}, d') ) + ~default:e + ~if_added:(fun lkp -> + let use = + Exp.fold e_' ~init:r.use ~f:(fun use b'j -> + let b' = Exp.base b'j in + Map.update use b' ~f:(function + | Some b_use -> Use.add b_use e + | None -> Use.singleton e ) ) + in + ({r with lkp; use}, e) ) ) ) + |> + [%Trace.retn fun {pf} (r', e') -> pf "%a@ %a" Exp.pp e' pp_diff (r, r')] -let propagate ?prefer r = +let norm_extend r ek = norm_extend r ek |> check (fst >> pre_invariant) +let extend r ek = fst (norm_extend r ek) + +(** Close the relation with the pending equations. *) +let rec propagate ?prefer r = [%Trace.call fun {pf} -> pf "%a" pp r] ; - (try propagate_ ?prefer r with Unsat -> false_) + ( match r.pnd with + | (d, e) :: pnd -> + let r = {r with pnd} in + let r = add_equation ?prefer r d e in + propagate ?prefer r + | [] -> r ) |> - [%Trace.retn fun {pf} r' -> pf "%a" pp_diff (r, r')] + [%Trace.retn fun {pf} r' -> + pf "%a" pp_diff (r, r') ; + invariant r'] let merge ?prefer r d e = if not r.sat then r else - let r, a = extend r d in - let r = - if Exp.equal d e then r - else - let r, b = extend r e in - let r = {r with pnd= (a, b) :: r.pnd} in - r - in + let r = extend r d in + let r = extend r e in + let r = add_equation ?prefer r d e in propagate ?prefer r -let rec normalize_ r e = - [%Trace.call fun {pf} -> pf "%a" Exp.pp e] +let rec normalize r ek = + [%Trace.call fun {pf} -> pf "%a@ %a" Exp.pp ek pp r] ; - map_sum e ~f:(function - | App _ as fa -> ( - let fa' = Exp.map ~f:(normalize_ r) fa in - match Map.find_exn r.lkp fa' with - | exception _ -> fa' - | c -> norm r (Exp.map ~f:(norm r) c) ) - | c -> - let c' = norm r c in - if c' == c then c else normalize_ r c' ) + map_base ek ~f:(fun e -> + try Map.find_exn r.rep e with Caml.Not_found -> + Exp.map ~f:(normalize r) e ) |> [%Trace.retn fun {pf} -> pf "%a" Exp.pp] -let normalize r e = - [%Trace.call fun {pf} -> pf "%a" pp r] - ; - normalize_ r e - |> - [%Trace.retn fun {pf} -> pf "%a" Exp.pp] - -let mem_eq r e f = Exp.equal (normalize r e) (normalize r f) +let mem_eq r d e = Exp.equal (normalize r d) (normalize r e) (** Exposed interface *) +let extend r e = propagate (extend r e) let and_eq = merge let and_ ?prefer r s = if not r.sat then r else if not s.sat then s else + let s, r = + if Map.length s.rep <= Map.length r.rep then (s, r) else (r, s) + in Map.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> merge ?prefer r e e') let or_ ?prefer r s = @@ -435,33 +630,36 @@ let or_ ?prefer r s = (* assumes that f is injective and for any set of exps E, f[E] is disjoint from E *) let map_exps ({sat= _; rep; lkp; cls; use; pnd} as r) ~f = - [%Trace.call fun {pf} -> pf "%a@." pp r] + [%Trace.call fun {pf} -> + pf "%a" pp r ; + assert (List.is_empty pnd)] ; - assert (List.is_empty pnd) ; - let map ~equal_data ~f_data m = + let map ~equal_key ~equal_data ~f_key ~f_data m = Map.fold m ~init:m ~f:(fun ~key ~data m -> - let key' = f key in + let key' = f_key key in let data' = f_data data in - if Exp.equal key' key then + if equal_key key' key then if equal_data data' data then m else Map.set m ~key ~data:data' - else - Map.remove m key - |> Map.add_exn ~key:key' - ~data:(if data' == data then data else data') ) + else Map.remove m key |> Map.add_exn ~key:key' ~data:data' ) + in + let rep' = + map rep ~equal_key:Exp.equal ~equal_data:Exp.equal ~f_key:f ~f_data:f + in + let lkp' = + map lkp ~equal_key:Exp.equal ~equal_data:Exp.equal ~f_key:f ~f_data:f in - let exp_map = map ~equal_data:Exp.equal ~f_data:f in - let exp_list_map = - map ~equal_data:[%compare.equal: Exp.t list] - ~f_data:(List.map_preserving_phys_equal ~f) + let cls' = + map cls ~equal_key:Exp.equal ~equal_data:[%compare.equal: Exp.t list] + ~f_key:f ~f_data:(Cls.map ~f) + in + let use' = + map use ~equal_key:Exp.equal ~equal_data:[%compare.equal: Exp.t list] + ~f_key:f ~f_data:(Use.map ~f) in - let rep' = exp_map rep in - let lkp' = exp_map lkp in - let cls' = exp_list_map cls in - let use' = exp_list_map use in ( if rep' == rep && lkp' == lkp && cls' == cls && use' == use then r else {r with rep= rep'; lkp= lkp'; cls= cls'; use= use'} ) |> - [%Trace.retn fun {pf} -> pf "%a@." pp] + [%Trace.retn fun {pf} r -> pf "%a" pp r ; invariant r] let rename r sub = map_exps r ~f:(fun e -> Exp.rename e sub) @@ -484,30 +682,25 @@ let classes {cls} = cls let entails r s = Map.for_alli s.rep ~f:(fun ~key:e ~data:e' -> mem_eq r e e') -(* a - b = k if a = b+k *) let difference r a b = [%Trace.call fun {pf} -> pf "%a@ %a@ %a" Exp.pp a Exp.pp b pp r] ; - let r, a = extend r a in - let r, b = extend r b in - let r = propagate r in - let ci = normalize r a in - let dj = normalize r b in - (* a - b = (c+i) - (d+j) *) - ( match solve_for_base dj ci with - (* d = (c+i)-j = c+(i-j) & c = d *) - | d, App {op= App {op= Add _; arg= c}; arg= Integer {data= i_j}} - when Exp.equal d c -> - (* a - b = (c+i) - (d+j) = i-j *) - Some i_j - | Integer {data= j}, Integer {data= i} -> Some (Z.sub i j) - | d, ci_j when Exp.equal d ci_j -> Some Z.zero + let r, a = norm_extend r a in + let r, b = norm_extend r b in + ( match (a, b) with + | _ when Exp.equal a b -> Some Z.zero + | (AppN {op= Add {typ} | Mul {typ}} | Integer {typ}), _ + |_, (AppN {op= Add {typ} | Mul {typ}} | Integer {typ}) -> ( + let a_b = Exp.sub typ a b in + let r, a_b = norm_extend r a_b in + let r = propagate r in + match normalize r a_b with Integer {data} -> Some data | _ -> None ) | _ -> None ) |> [%Trace.retn fun {pf} -> - function - | Some d -> pf "%a" Z.pp_print d - | None -> pf "c+i: %a@ d+j: %a" Exp.pp ci Exp.pp dj] + function Some d -> pf "%a" Z.pp_print d | None -> ()] + +(** Tests *) let%test_module _ = ( module struct diff --git a/sledge/src/symbheap/congruence.mli b/sledge/src/symbheap/congruence.mli index 6e4aa759e..471ab3d6f 100644 --- a/sledge/src/symbheap/congruence.mli +++ b/sledge/src/symbheap/congruence.mli @@ -24,6 +24,9 @@ type prefer = Exp.t -> over:Exp.t -> int val true_ : t (** The diagonal relation, which only equates each exp with itself. *) +val extend : t -> Exp.t -> t +(** Extend the carrier of the relation. *) + val merge : ?prefer:prefer -> t -> Exp.t -> Exp.t -> t (** Merge the equivalence classes of exps together. If [prefer a ~over:b] is positive, then [b] will not be used as the representative of a class diff --git a/sledge/src/symbheap/congruence_test.ml b/sledge/src/symbheap/congruence_test.ml index 3e67d964d..9c2122a89 100644 --- a/sledge/src/symbheap/congruence_test.ml +++ b/sledge/src/symbheap/congruence_test.ml @@ -9,16 +9,22 @@ let%test_module _ = ( module struct open Congruence - let printf pp = Format.printf "@.%a@." pp + (* let () = Trace.init ~margin:160 ~config:all () *) + + let () = Trace.init ~margin:68 ~config:none () + let printf pp = Format.printf "@\n%a@." pp + let pp = printf pp + let pp_classes = printf pp_classes let of_eqs = List.fold ~init:true_ ~f:(fun r (a, b) -> and_eq r a b) - let mem_eq x y r = entails r (and_eq true_ x y) + let mem_eq x y r = Exp.equal (normalize r x) (normalize r y) let i8 = Typ.integer ~bits:8 let i64 = Typ.integer ~bits:64 let ( ! ) i = Exp.integer (Z.of_int i) Typ.siz let ( + ) = Exp.add Typ.siz let ( - ) = Exp.sub Typ.siz + let ( * ) = Exp.mul Typ.siz let f = Exp.convert ~dst:i64 ~src:i8 - let g = Exp.xor + let g = Exp.rem let wrt = Var.Set.empty let t_, wrt = Var.fresh "t" ~wrt let u_, wrt = Var.fresh "u" ~wrt @@ -37,6 +43,11 @@ let%test_module _ = let f1 = of_eqs [(!0, !1)] let%test _ = is_false f1 + + let%expect_test _ = + pp f1 ; + [%expect {| {sat= false; rep= []; lkp= []; cls= []; use= []} |}] + let%test _ = Map.is_empty (classes f1) let%test _ = is_false (merge f1 !1 !1) @@ -65,49 +76,73 @@ let%test_module _ = let r0 = true_ let%expect_test _ = - printf pp r0 ; + pp r0 ; [%expect {| {sat= true; rep= []; lkp= []; cls= []; use= []} |}] - let%expect_test _ = printf pp_classes r0 ; [%expect {| |}] + let%expect_test _ = pp_classes r0 ; [%expect {| |}] let%test _ = difference r0 (f x) (f x) |> Poly.equal (Some (Z.of_int 0)) let%test _ = difference r0 !4 !3 |> Poly.equal (Some (Z.of_int 1)) let r1 = of_eqs [(x, y)] - let%test _ = not (Map.is_empty (classes r1)) - let%expect_test _ = - printf pp r1 ; + pp_classes r1 ; + pp r1 ; [%expect {| + %y_6 = %x_5 + {sat= true; rep= [[%x_5 ↦ %y_6]]; lkp= []; cls= [[%y_6 ↦ {%x_5, %y_6}]]; use= []} |}] - let%expect_test _ = printf pp_classes r1 ; [%expect {| %y_6 = %x_5 |}] + let%test _ = not (Map.is_empty (classes r1)) let%test _ = mem_eq x y r1 let r2 = of_eqs [(x, y); (f x, y); (f y, z)] let%expect_test _ = - printf pp r2 ; + pp_classes r2 ; + pp r2 ; [%expect {| + %y_6 = ((i64)(i8) %x_5) = ((i64)(i8) %y_6) = %x_5 = %z_7 + {sat= true; rep= [[((i64)(i8) %x_5) ↦ %y_6]; [((i64)(i8) %y_6) ↦ %y_6]; [%x_5 ↦ %y_6]; [%z_7 ↦ %y_6]]; lkp= [[((i64)(i8) %y_6) ↦ ((i64)(i8) %x_5)]]; - cls= [[%y_6 ↦ {((i64)(i8) %x_5), ((i64)(i8) %y_6), %x_5, %y_6, %z_7}]]; - use= [[%y_6 ↦ {((i64)(i8) %x_5)}]]} |}] + cls= [[%y_6 + ↦ {((i64)(i8) %x_5), ((i64)(i8) %y_6), %x_5, %y_6, %z_7}]]; + use= [[%y_6 ↦ {((i64)(i8) %x_5)}]; + [(i64)(i8) ↦ {((i64)(i8) %x_5)}]]} |}] + + let r2' = extend r2 (f z) let%expect_test _ = - printf pp_classes r2 ; + pp_classes r2' ; + pp r2' ; [%expect - {| %y_6 = ((i64)(i8) %x_5) = ((i64)(i8) %y_6) = %x_5 = %z_7 |}] + {| + %y_6 = ((i64)(i8) %x_5) = ((i64)(i8) %y_6) = ((i64)(i8) %z_7) + = %x_5 = %z_7 + + {sat= true; + rep= [[((i64)(i8) %x_5) ↦ %y_6]; + [((i64)(i8) %y_6) ↦ %y_6]; + [((i64)(i8) %z_7) ↦ %y_6]; + [%x_5 ↦ %y_6]; + [%z_7 ↦ %y_6]]; + lkp= [[((i64)(i8) %y_6) ↦ ((i64)(i8) %x_5)]]; + cls= [[%y_6 + ↦ {((i64)(i8) %x_5), ((i64)(i8) %y_6), ((i64)(i8) %z_7), + %x_5, %y_6, %z_7}]]; + use= [[%y_6 ↦ {((i64)(i8) %x_5)}]; + [(i64)(i8) ↦ {((i64)(i8) %x_5)}]]} |}] let%test _ = mem_eq x z r2 let%test _ = mem_eq x y (or_ r1 r2) @@ -115,7 +150,9 @@ let%test_module _ = let%test _ = not (mem_eq x z (or_ r1 r2)) let%test _ = mem_eq x z (or_ f1 r2) let%test _ = mem_eq x z (or_ r2 f3) - let%test _ = mem_eq (f x) (f z) r2 + let%test _ = mem_eq (f y) y r2 + let%test _ = mem_eq (f x) (f z) r2' + let%test _ = not (mem_eq (f x) (f z) r2) let%test _ = mem_eq (g x y) (g z y) r2 let%test _ = @@ -137,28 +174,31 @@ let%test_module _ = let r3 = of_eqs [(g y z, w); (v, w); (g y w, t); (x, v); (x, u); (u, z)] let%expect_test _ = - printf pp r3 ; + pp_classes r3 ; + pp r3 ; [%expect {| + ∧ %w_4 = (%y_6 rem %w_4) = (%y_6 rem %z_7) = %t_1 = %u_2 = %v_3 + = %x_5 = %z_7 + {sat= true; - rep= [[(%w_4 xor %y_6) ↦ %w_4]; - [(%y_6 xor %z_7) ↦ %w_4]; + rep= [[(%y_6 rem %w_4) ↦ %w_4]; + [(%y_6 rem %z_7) ↦ %w_4]; + [(rem %y_6) ↦ ]; [%t_1 ↦ %w_4]; [%u_2 ↦ %w_4]; [%v_3 ↦ %w_4]; [%x_5 ↦ %w_4]; [%z_7 ↦ %w_4]]; - lkp= [[(%w_4 xor %y_6) ↦ (%w_4 xor %y_6)]; - [(%y_6 xor %z_7) ↦ (%y_6 xor %z_7)]; - [(xor %w_4) ↦ (xor %w_4)]; - [(xor %y_6) ↦ (xor %y_6)]]; - cls= [[%w_4 - ↦ {(%w_4 xor %y_6), (%y_6 xor %z_7), %t_1, %u_2, %v_3, %w_4, - %x_5, %z_7}]]; - use= [[(xor %w_4) ↦ {(%w_4 xor %y_6)}]; - [(xor %y_6) ↦ {(%y_6 xor %z_7)}]; - [%w_4 ↦ {(xor %w_4)}]; - [%y_6 ↦ {(%w_4 xor %y_6), (xor %y_6)}]]} |}] + lkp= [[(%y_6 rem %w_4) ↦ ]; [(%y_6 rem %z_7) ↦ ]; [(rem %y_6) ↦ ]]; + cls= [[(rem %y_6) ↦ {(rem %y_6)}]; + [%w_4 + ↦ {(%y_6 rem %w_4), (%y_6 rem %z_7), %t_1, %u_2, %v_3, + %w_4, %x_5, %z_7}]]; + use= [[(rem %y_6) ↦ {(%y_6 rem %w_4), (%y_6 rem %z_7)}]; + [%w_4 ↦ {(%y_6 rem %w_4)}]; + [%y_6 ↦ {(rem %y_6)}]; + [rem ↦ {(rem %y_6)}]]} |}] let%test _ = mem_eq t z r3 let%test _ = mem_eq x z r3 @@ -166,6 +206,21 @@ let%test_module _ = let r4 = of_eqs [(w + !2, x - !3); (x - !5, y + !7); (y, z - !4)] + let%expect_test _ = + pp_classes r4 ; + pp r4 ; + [%expect + {| + %y_6 = (%w_4 + -7) = (%x_5 + -12) = (%z_7 + -4) + + {sat= true; + rep= [[%w_4 ↦ (%y_6 + 7)]; + [%x_5 ↦ (%y_6 + 12)]; + [%z_7 ↦ (%y_6 + 4)]]; + lkp= []; + cls= [[%y_6 ↦ {(%w_4 + -7), (%x_5 + -12), (%z_7 + -4), %y_6}]]; + use= []} |}] + let%test _ = mem_eq x (w + !5) r4 let%test _ = difference r4 x w |> Poly.equal (Some (Z.of_int 5)) @@ -176,16 +231,164 @@ let%test_module _ = let r6 = of_eqs [(x, !1); (!1, y)] + let%expect_test _ = + pp_classes r6 ; + pp r6 ; + [%expect + {| + 1 = %x_5 = %y_6 + + {sat= true; + rep= [[%x_5 ↦ 1]; [%y_6 ↦ 1]]; + lkp= []; + cls= [[1 ↦ {%x_5, %y_6, 1}]]; + use= []} |}] + let%test _ = mem_eq x y r6 let r7 = of_eqs [(v, x); (w, z); (y, z)] + let%expect_test _ = + pp_classes r7 ; + pp r7 ; + pp (merge ~prefer:(fun e ~over:f -> Exp.compare e f) r7 x z) ; + pp (merge ~prefer:(fun e ~over:f -> Exp.compare f e) r7 x z) ; + [%expect + {| + %x_5 = %v_3 + ∧ %z_7 = %w_4 = %y_6 + + {sat= true; + rep= [[%v_3 ↦ %x_5]; [%w_4 ↦ %z_7]; [%y_6 ↦ %z_7]]; + lkp= []; + cls= [[%x_5 ↦ {%v_3, %x_5}]; [%z_7 ↦ {%w_4, %y_6, %z_7}]]; + use= []} + + {sat= true; + rep= [[%v_3 ↦ %z_7]; [%w_4 ↦ %z_7]; [%x_5 ↦ %z_7]; [%y_6 ↦ %z_7]]; + lkp= []; + cls= [[%z_7 ↦ {%v_3, %w_4, %x_5, %y_6, %z_7}]]; + use= []} + + {sat= true; + rep= [[%v_3 ↦ %x_5]; [%w_4 ↦ %x_5]; [%y_6 ↦ %x_5]; [%z_7 ↦ %x_5]]; + lkp= []; + cls= [[%x_5 ↦ {%v_3, %w_4, %x_5, %y_6, %z_7}]]; + use= []} |}] + let%test _ = normalize (merge r7 x z) w |> Exp.equal z let%test _ = - normalize (merge ~prefer:(fun e ~over:f -> Exp.compare f e) r7 x z) w - |> Exp.equal x + let prefer e ~over:f = Exp.compare e f in + prefer z ~over:x > 0 + + let%test _ = + let prefer e ~over:f = Exp.compare e f in + prefer z ~over:x > 0 + (* so x should not be the rep *) + && normalize (merge ~prefer r7 x z) w |> Exp.equal z + + let%test _ = + let prefer e ~over:f = Exp.compare f e in + prefer x ~over:z > 0 + (* so z should not be the rep *) + && normalize (merge ~prefer r7 x z) w |> Exp.equal x + + let%test _ = + mem_eq (g w x) (g w z) + (extend (of_eqs [(g w x, g y z); (x, z)]) (g w z)) + + let%test _ = + mem_eq (g w x) (g w z) + (extend (of_eqs [(g w x, g y w); (x, z)]) (g w z)) + + let r8 = of_eqs [(x + !42, (!3 * y) + (!13 * z)); (!13 * z, x)] + + let%expect_test _ = + pp_classes r8 ; + pp r8 ; + [%expect + {| + (3 × %y_6 + 13 × %z_7) = (%x_5 + 42) = (13 × %z_7 + 42) + + {sat= true; + rep= [[(3 × %y_6 + 13 × %z_7) ↦ ]; + [(13 × %z_7) ↦ (3 × %y_6 + 13 × %z_7 + -42)]; + [%x_5 ↦ (3 × %y_6 + 13 × %z_7 + -42)]]; + lkp= [[(3 × %y_6 + 13 × %z_7) ↦ ]; [(13 × %z_7) ↦ ]]; + cls= [[(3 × %y_6 + 13 × %z_7) + ↦ {(%x_5 + 42), (3 × %y_6 + 13 × %z_7), (13 × %z_7 + 42)}]]; + use= [[%y_6 ↦ {(3 × %y_6 + 13 × %z_7)}]; + [%z_7 ↦ {(13 × %z_7), (3 × %y_6 + 13 × %z_7)}]; + [+ ↦ {(13 × %z_7), (3 × %y_6 + 13 × %z_7)}]]} |}] + + (* (* incomplete *) + * let%test _ = mem_eq y !14 r8 *) + + let r9 = of_eqs [(x, z - !16)] + + let%expect_test _ = + pp_classes r9 ; + pp r9 ; + [%expect + {| + %z_7 = (%x_5 + 16) + + {sat= true; + rep= [[%x_5 ↦ (%z_7 + -16)]]; + lkp= []; + cls= [[%z_7 ↦ {(%x_5 + 16), %z_7}]]; + use= []} |}] - let%test _ = mem_eq (g w x) (g w z) (of_eqs [(g w x, g y z); (x, z)]) - let%test _ = mem_eq (g w x) (g w z) (of_eqs [(g w x, g y w); (x, z)]) + let%test _ = difference r9 z (x + !8) |> Poly.equal (Some (Z.of_int 8)) + + let r10 = of_eqs [(!16, z - x)] + + let%expect_test _ = + pp_classes r10 ; + pp r10 ; + Format.printf "@.%a@." Exp.pp (z - (x + !8)) ; + Format.printf "@.%a@." Exp.pp (normalize r10 (z - (x + !8))) ; + Format.printf "@.%a@." Exp.pp (x + !8 - z) ; + Format.printf "@.%a@." Exp.pp (normalize r10 (x + !8 - z)) ; + [%expect + {| + 16 = (-1 × %x_5 + %z_7) + + {sat= true; + rep= [[(-1 × %x_5 + %z_7) ↦ 16]]; + lkp= [[(-1 × %x_5 + %z_7) ↦ ]]; + cls= [[16 ↦ {(-1 × %x_5 + %z_7), 16}]]; + use= [[%x_5 ↦ {(-1 × %x_5 + %z_7)}]; + [%z_7 ↦ {(-1 × %x_5 + %z_7)}]; + [+ ↦ {(-1 × %x_5 + %z_7)}]]} + + (-1 × %x_5 + %z_7 + -8) + + 8 + + (%x_5 + -1 × %z_7 + 8) + + (%x_5 + -1 × %z_7 + 8) |}] + + let%test _ = difference r10 z (x + !8) |> Poly.equal (Some (Z.of_int 8)) + + (* (* incomplete *) + * let%test _ = + * difference r10 (x + !8) z |> Poly.equal (Some (Z.of_int (-8))) *) + + let r11 = of_eqs [(!16, z - x); (x + !8 - z, z - !16 + !8 - z)] + + let%expect_test _ = + pp_classes r11 ; + [%expect + {| + -16 = (%x_5 + -1 × %z_7) + ∧ 16 = (-1 × %x_5 + %z_7) |}] + + let r12 = of_eqs [(!16, z - x); (x + !8 - z, z + !16 + !8 - z)] + + let%expect_test _ = + pp_classes r12 ; + [%expect {| 16 = (-1 × %x_5 + %z_7) = (%x_5 + -1 × %z_7) |}] end )