From de1689ac87ecc296ba8c0e75dde8c454c36aa6e5 Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Thu, 16 Apr 2020 03:39:49 -0700 Subject: [PATCH] [sledge] Change And and Or terms from binary to flattened n-ary Summary: The heights of And and Or terms can grow high. This interacts poorly with some unoptimized Equality operations such as normalization that do some processing at every subterm. Reviewed By: jvillard Differential Revision: D21042518 fbshipit-source-id: 55e6acbb1 --- sledge/lib/equality.ml | 4 +- sledge/lib/import/set.ml | 21 +++++++- sledge/lib/import/set_intf.ml | 16 +++++- sledge/lib/sh.ml | 4 +- sledge/lib/term.ml | 94 ++++++++++++++++++++++++++++------- sledge/lib/term.mli | 18 +++++-- sledge/lib/term_test.ml | 2 +- 7 files changed, 130 insertions(+), 29 deletions(-) diff --git a/sledge/lib/equality.ml b/sledge/lib/equality.ml index 113077702..086783cd4 100644 --- a/sledge/lib/equality.ml +++ b/sledge/lib/equality.ml @@ -19,7 +19,7 @@ let classify e = |Ap3 (Extract, _, _, _) |ApN (Concat, _) -> Interpreted - | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted + | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ | And _ | Or _ -> Uninterpreted | RecN _ | Var _ | Integer _ | Rational _ | Float _ | Nondet _ | Label _ -> Atomic @@ -673,7 +673,7 @@ let rec and_term_ us e r = let eq_false b r = and_eq_ us b Term.false_ r in match (e : Term.t) with | Integer {data} -> if Z.is_false data then false_ else r - | Ap2 (And, a, b) -> and_term_ us a (and_term_ us b r) + | And cs -> Term.Set.fold cs ~init:r ~f:(fun r e -> and_term_ us e r) | Ap2 (Eq, a, b) -> and_eq_ us a b r | Ap2 (Xor, Integer {data}, a) when Z.is_true data -> eq_false a r | Ap2 (Xor, a, Integer {data}) when Z.is_true data -> eq_false a r diff --git a/sledge/lib/import/set.ml b/sledge/lib/import/set.ml index 1f95fa93e..7ed66ed68 100644 --- a/sledge/lib/import/set.ml +++ b/sledge/lib/import/set.ml @@ -5,6 +5,7 @@ * LICENSE file in the root directory of this source tree. *) +open Import0 include Set_intf module Make (Elt : sig @@ -17,7 +18,11 @@ end) : S with type elt = Elt.t = struct include EltSet.Tree - let pp pp_elt fs x = List.pp ",@ " pp_elt fs (elements x) + let hash_fold_t hash_fold_elt s m = + fold ~f:hash_fold_elt ~init:(Hash.fold_int s (length m)) m + + let pp ?pre ?suf ?(sep = (",@ " : (unit, unit) fmt)) pp_elt fs x = + List.pp ?pre ?suf sep pp_elt fs (elements x) let pp_diff pp_elt fs (xs, ys) = let lose = diff xs ys and gain = diff ys xs in @@ -41,4 +46,18 @@ end) : S with type elt = Elt.t = struct match split s2 x with | _, Some _, _ -> false | l2, None, r2 -> disjoint l1 l2 && disjoint r1 r2 ) + + let choose_exn s = + with_return + @@ fun {return} -> + binary_search_segmented s `Last_on_left ~segment_of:return |> ignore ; + raise (Not_found_s (Atom __LOC__)) + + let choose s = try Some (choose_exn s) with Not_found_s _ -> None + + let pop_exn s = + let elt = choose_exn s in + (elt, remove s elt) + + let pop s = choose s |> Option.map ~f:(fun elt -> (elt, remove s elt)) end diff --git a/sledge/lib/import/set_intf.ml b/sledge/lib/import/set_intf.ml index 65a421583..91f02c5ef 100644 --- a/sledge/lib/import/set_intf.ml +++ b/sledge/lib/import/set_intf.ml @@ -18,7 +18,15 @@ module type S = sig include Core_kernel.Set_intf.Make_S_plain_tree(Elt).S - val pp : elt pp -> t pp + val hash_fold_t : elt Hash.folder -> t Hash.folder + + val pp : + ?pre:(unit, unit) fmt + -> ?suf:(unit, unit) fmt + -> ?sep:(unit, unit) fmt + -> elt pp + -> t pp + val pp_diff : elt pp -> (t * t) pp val of_ : elt -> t val of_option : elt option -> t @@ -27,4 +35,10 @@ module type S = sig val add_list : elt list -> t -> t val diff_inter : t -> t -> t * t val disjoint : t -> t -> bool + + val pop_exn : t -> elt * t + (** Find and remove an unspecified element. [O(1)]. *) + + val pop : t -> (elt * t) option + (** Find and remove an unspecified element. [O(1)]. *) end diff --git a/sledge/lib/sh.ml b/sledge/lib/sh.ml index e3293f58b..59cb1c03b 100644 --- a/sledge/lib/sh.ml +++ b/sledge/lib/sh.ml @@ -520,7 +520,9 @@ let rec pure (e : Term.t) = [%Trace.call fun {pf} -> pf "%a" Term.pp e] ; ( match e with - | Ap2 (Or, e1, e2) -> or_ (pure e1) (pure e2) + | Or es -> + let e0, e1N = Term.Set.pop_exn es in + Term.Set.fold e1N ~init:(pure e0) ~f:(fun q e -> or_ q (pure e)) | Ap3 (Conditional, cnd, thn, els) -> or_ (star (pure cnd) (pure thn)) diff --git a/sledge/lib/term.ml b/sledge/lib/term.ml index de1e2fc57..08fff4ee1 100644 --- a/sledge/lib/term.ml +++ b/sledge/lib/term.ml @@ -26,8 +26,6 @@ type op2 = | Uno | Div | Rem - | And - | Or | Xor | Shl | Lshr @@ -40,7 +38,22 @@ type op3 = Conditional | Extract [@@deriving compare, equal, hash, sexp] type opN = Concat | Record [@@deriving compare, equal, hash, sexp] type recN = Record [@@deriving compare, equal, hash, sexp] -module rec Qset : sig +module rec Set : sig + include Import.Set.S with type elt := T.t + + val hash : t -> int + val hash_fold_t : t Hash.folder + val t_of_sexp : Sexp.t -> t +end = struct + include Import.Set.Make (T) + + let hash_fold_t = hash_fold_t T.hash_fold_t + let hash = Hash.of_fold hash_fold_t + + include Provide_of_sexp (T) +end + +and Qset : sig include Import.Qset.S with type elt := T.t val hash : t -> int @@ -55,6 +68,7 @@ end = struct end and T : sig + type set = Set.t [@@deriving compare, equal, hash, sexp] type qset = Qset.t [@@deriving compare, equal, hash, sexp] type t = @@ -64,6 +78,8 @@ and T : sig | Ap3 of op3 * t * t * t | ApN of opN * t iarray | RecN of recN * t iarray (** NOTE: cyclic *) + | And of set + | Or of set | Add of qset | Mul of qset | Label of {parent: string; name: string} @@ -73,6 +89,7 @@ and T : sig | Rational of {data: Q.t} [@@deriving compare, equal, hash, sexp] end = struct + type set = Set.t [@@deriving compare, equal, hash, sexp] type qset = Qset.t [@@deriving compare, equal, hash, sexp] type t = @@ -82,6 +99,8 @@ end = struct | Ap3 of op3 * t * t * t | ApN of opN * t iarray | RecN of recN * t iarray (** NOTE: cyclic *) + | And of set + | Or of set | Add of qset | Mul of qset | Label of {parent: string; name: string} @@ -109,7 +128,6 @@ end include T module Map = struct include Map.Make (T) include Provide_of_sexp (T) end -module Set = struct include Set.Make (T) include Provide_of_sexp (T) end let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a = let rec fix_f seen e = @@ -174,8 +192,8 @@ let rec ppx strength fs term = pf "(%a)" (Qset.pp "@ @<2>× " pp_mono_term) args | Ap2 (Div, x, y) -> pf "(%a@ / %a)" pp x pp y | Ap2 (Rem, x, y) -> pf "(%a@ rem %a)" pp x pp y - | Ap2 (And, x, y) -> pf "(%a@ && %a)" pp x pp y - | Ap2 (Or, x, y) -> pf "(%a@ || %a)" pp x pp y + | And xs -> pf "(@[%a@])" (Set.pp ~sep:" &&@ " pp) xs + | Or xs -> pf "(@[%a@])" (Set.pp ~sep:" ||@ " pp) xs | Ap2 (Xor, x, Integer {data}) when Z.is_true data -> pf "¬%a" pp x | Ap2 (Xor, Integer {data}, x) when Z.is_true data -> pf "¬%a" pp x | Ap2 (Xor, x, y) -> pf "(%a@ xor %a)" pp x pp y @@ -221,6 +239,18 @@ let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y (** Invariant *) +let assert_conjunction = function + | And cs -> + Set.iter cs ~f:(fun c -> + assert (match c with And _ -> false | _ -> true) ) + | _ -> assert false + +let assert_disjunction = function + | Or cs -> + Set.iter cs ~f:(fun c -> + assert (match c with Or _ -> false | _ -> true) ) + | _ -> assert false + (* an indeterminate (factor of a monomial) is any non-Add/Mul/Integer/Rational term *) let assert_indeterminate = function @@ -285,6 +315,8 @@ let invariant e = Invariant.invariant [%here] e [%sexp_of: t] @@ fun () -> match e with + | And _ -> assert_conjunction e |> Fn.id + | Or _ -> assert_disjunction e |> Fn.id | Add _ -> assert_polynomial e |> Fn.id | Mul _ -> assert_monomial e |> Fn.id | Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) -> @@ -646,12 +678,13 @@ let rec is_boolean = function | Ap1 ((Unsigned {bits= 1} | Convert {dst= Integer {bits= 1; _}; _}), _) |Ap2 ((Eq | Dq | Lt | Le), _, _) -> true - | Ap2 ((Div | Rem | And | Or | Xor | Shl | Lshr | Ashr), x, y) + | Ap2 ((Div | Rem | Xor | Shl | Lshr | Ashr), x, y) |Ap3 (Conditional, _, x, y) -> is_boolean x || is_boolean y + | And xs | Or xs -> Set.for_all ~f:is_boolean xs | _ -> false -let rec simp_and x y = +let rec simp_and2 x y = match (x, y) with (* i && j *) | Integer {data= i}, Integer {data= j} -> integer (Z.logand i j) @@ -663,12 +696,16 @@ let rec simp_and x y = f (* e && (c ? t : f) ==> (c ? e && t : e && f) *) | e, Ap3 (Conditional, c, t, f) | Ap3 (Conditional, c, t, f), e -> - simp_cond c (simp_and e t) (simp_and e f) + simp_cond c (simp_and2 e t) (simp_and2 e f) (* e && e ==> e *) | _ when equal x y -> x - | _ -> Ap2 (And, x, y) + | _ -> + let add s = function And cs -> Set.union s cs | c -> Set.add s c in + And (add (add Set.empty x) y) -let rec simp_or x y = +let simp_and xs = Set.fold xs ~init:true_ ~f:simp_and2 + +let rec simp_or2 x y = match (x, y) with (* i || j *) | Integer {data= i}, Integer {data= j} -> integer (Z.logor i j) @@ -680,10 +717,14 @@ let rec simp_or x y = | (Integer {data}, e | e, Integer {data}) when Z.is_false data -> e (* e || (c ? t : f) ==> (c ? e || t : e || f) *) | e, Ap3 (Conditional, c, t, f) | Ap3 (Conditional, c, t, f), e -> - simp_cond c (simp_or e t) (simp_or e f) + simp_cond c (simp_or2 e t) (simp_or2 e f) (* e || e ==> e *) | _ when equal x y -> x - | _ -> Ap2 (Or, x, y) + | _ -> + let add s = function Or cs -> Set.union s cs | c -> Set.add s c in + Or (add (add Set.empty x) y) + +let simp_or xs = Set.fold xs ~init:false_ ~f:simp_or2 (* aggregate sizes *) @@ -920,9 +961,9 @@ and simp_not term = (* ¬(x = nan ∨ y = nan) ==> x ≠ nan ∧ y ≠ nan *) | Ap2 (Uno, x, y) -> simp_ord x y (* ¬(a ∧ b) ==> ¬a ∨ ¬b *) - | Ap2 (And, x, y) -> simp_or (simp_not x) (simp_not y) + | And xs -> simp_or (Set.map ~f:simp_not xs) (* ¬(a ∨ b) ==> ¬a ∧ ¬b *) - | Ap2 (Or, x, y) -> simp_and (simp_not x) (simp_not y) + | Or xs -> simp_and (Set.map ~f:simp_not xs) (* ¬¬e ==> e *) | Ap2 (Xor, Integer {data}, e) when Z.is_true data -> e | Ap2 (Xor, e, Integer {data}) when Z.is_true data -> e @@ -1024,8 +1065,6 @@ let norm2 op x y = | Uno -> simp_uno x y | Div -> simp_div x y | Rem -> simp_rem x y - | And -> simp_and x y - | Or -> simp_or x y | Xor -> simp_xor x y | Shl -> simp_shl x y | Lshr -> simp_lshr x y @@ -1065,8 +1104,10 @@ let mul e f = simp_mul2 e f |> check invariant let mulN args = simp_mul args |> check invariant let div = norm2 Div let rem = norm2 Rem -let and_ = norm2 And -let or_ = norm2 Or +let and_ e f = simp_and2 e f |> check invariant +let or_ e f = simp_or2 e f |> check invariant +let andN es = simp_and es |> check invariant +let orN es = simp_or es |> check invariant let not_ e = simp_not e |> check invariant let xor = norm2 Xor let shl = norm2 Shl @@ -1108,11 +1149,17 @@ let map e ~f = let xs' = IArray.map_endo ~f xs in if xs' == xs then e else normN op xs' in + let map_set mk ~f args = + let args' = Set.map ~f args in + if args' == args then e else mk args' + in let map_qset mk ~f args = let args' = Qset.map ~f:(fun arg q -> (f arg, q)) args in if args' == args then e else mk args' in match e with + | And args -> map_set andN ~f args + | Or args -> map_set orN ~f args | Add args -> map_qset addN ~f args | Mul args -> map_qset mulN ~f args | Ap1 (op, x) -> map1 op ~f x @@ -1197,6 +1244,7 @@ let iter e ~f = | Ap2 (_, x, y) -> f x ; f y | Ap3 (_, x, y, z) -> f x ; f y ; f z | ApN (_, xs) | RecN (_, xs) -> IArray.iter ~f xs + | And args | Or args -> Set.iter ~f args | Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> () @@ -1206,6 +1254,7 @@ let exists e ~f = | Ap2 (_, x, y) -> f x || f y | Ap3 (_, x, y, z) -> f x || f y || f z | ApN (_, xs) | RecN (_, xs) -> IArray.exists ~f xs + | And args | Or args -> Set.exists ~f args | Add args | Mul args -> Qset.exists ~f:(fun arg _ -> f arg) args | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> false @@ -1215,6 +1264,7 @@ let for_all e ~f = | Ap2 (_, x, y) -> f x && f y | Ap3 (_, x, y, z) -> f x && f y && f z | ApN (_, xs) | RecN (_, xs) -> IArray.for_all ~f xs + | And args | Or args -> Set.for_all ~f args | Add args | Mul args -> Qset.for_all ~f:(fun arg _ -> f arg) args | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> true @@ -1225,6 +1275,7 @@ let fold e ~init:s ~f = | Ap3 (_, x, y, z) -> f z (f y (f x s)) | ApN (_, xs) | RecN (_, xs) -> IArray.fold ~f:(fun s x -> f x s) xs ~init:s + | And args | Or args -> Set.fold ~f:(fun s e -> f e s) args ~init:s | Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> s @@ -1235,6 +1286,7 @@ let iter_terms e ~f = | Ap2 (_, x, y) -> iter_terms_ x ; iter_terms_ y | Ap3 (_, x, y, z) -> iter_terms_ x ; iter_terms_ y ; iter_terms_ z | ApN (_, xs) | RecN (_, xs) -> IArray.iter ~f:iter_terms_ xs + | And args | Or args -> Set.iter args ~f:iter_terms_ | Add args | Mul args -> Qset.iter args ~f:(fun arg _ -> iter_terms_ arg) | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> () ) ; @@ -1251,6 +1303,8 @@ let fold_terms e ~init ~f = | Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s)) | ApN (_, xs) | RecN (_, xs) -> IArray.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s + | And args | Or args -> + Set.fold args ~init:s ~f:(fun s x -> fold_terms_ x s) | Add args | Mul args -> Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s) | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> s @@ -1290,6 +1344,8 @@ let height e = | Ap3 (_, a, b, c) -> 1 + max (height_ a) (max (height_ b) (height_ c)) | ApN (_, v) | RecN (_, v) -> 1 + IArray.fold v ~init:0 ~f:(fun m a -> max m (height_ a)) + | And bs | Or bs -> + 1 + Set.fold bs ~init:0 ~f:(fun m a -> max m (height_ a)) | Add qs | Mul qs -> 1 + Qset.fold qs ~init:0 ~f:(fun a _ m -> max m (height_ a)) | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> 0 diff --git a/sledge/lib/term.mli b/sledge/lib/term.mli index afe6fd53e..2d131c414 100644 --- a/sledge/lib/term.mli +++ b/sledge/lib/term.mli @@ -39,8 +39,6 @@ type op2 = | Rem (** Remainder of division, satisfies [a = b * div a b + rem a b] and for integers [rem a b] has same sign as [a], and [|rem a b| < |b|] *) - | And (** Conjunction, boolean or bitwise *) - | Or (** Disjunction, boolean or bitwise *) | Xor (** Exclusive-or, bitwise *) | Shl (** Shift left, bitwise *) | Lshr (** Logical shift right, bitwise *) @@ -62,7 +60,14 @@ type opN = type recN = Record (** Recursive record (array / struct) constant *) [@@deriving compare, equal, hash, sexp] -module rec Qset : sig +module rec Set : sig + include Import.Set.S with type elt := T.t + + val hash_fold_t : t Hash.folder + val t_of_sexp : Sexp.t -> t +end + +and Qset : sig include Import.Qset.S with type elt := T.t val hash_fold_t : t Hash.folder @@ -70,6 +75,8 @@ module rec Qset : sig end and T : sig + type set = Set.t [@@deriving compare, equal, hash, sexp] + type qset = Qset.t [@@deriving compare, equal, hash, sexp] and t = private @@ -83,6 +90,8 @@ and T : sig (** Recursive n-ary application, may recursively refer to itself (transitively) from its args. NOTE: represented by cyclic values. *) + | And of set (** Conjunction, boolean or bitwise *) + | Or of set (** Disjunction, boolean or bitwise *) | Add of qset (** Sum of terms with rational coefficients *) | Mul of qset (** Product of terms with rational exponents *) | Label of {parent: string; name: string} @@ -107,8 +116,9 @@ module Var : sig module Map : Map.S with type key := t module Set : sig - include Set.S with type elt := t + include Import.Set.S with type elt := t + val hash_fold_t : t Hash.folder val sexp_of_t : t -> Sexp.t val t_of_sexp : Sexp.t -> t val ppx : strength -> t pp diff --git a/sledge/lib/term_test.ml b/sledge/lib/term_test.ml index 7be43cc99..c069d2196 100644 --- a/sledge/lib/term_test.ml +++ b/sledge/lib/term_test.ml @@ -230,7 +230,7 @@ let%test_module _ = let%expect_test _ = pp ~~(!2 < y && z <= !3) ; - [%expect {| ((%y_1 ≤ 2) || (3 < %z_2)) |}] + [%expect {| ((3 < %z_2) || (%y_1 ≤ 2)) |}] let%expect_test _ = pp ~~(!2 <= y || z < !3) ;