From 356b4f0b4eb0fb56b60b21b6ee81085f87ca622c Mon Sep 17 00:00:00 2001 From: Josh Berdine Date: Wed, 9 Oct 2019 08:36:34 -0700 Subject: [PATCH] [sledge] Uncurry Record term constructor Reviewed By: bennostein Differential Revision: D17665260 fbshipit-source-id: 080f47739 --- sledge/src/symbheap/term.ml | 105 +++++++++-------------------------- sledge/src/symbheap/term.mli | 11 ++-- 2 files changed, 30 insertions(+), 86 deletions(-) diff --git a/sledge/src/symbheap/term.ml b/sledge/src/symbheap/term.ml index b444acdea..2c3ae489e 100644 --- a/sledge/src/symbheap/term.ml +++ b/sledge/src/symbheap/term.ml @@ -55,8 +55,7 @@ module rec T : sig | Conditional [@@deriving compare, equal, hash, sexp] - type opN = (* memory *) - | Concat [@@deriving compare, equal, hash, sexp] + type opN = Concat | Record [@@deriving compare, equal, hash, sexp] type t = (* nary arithmetic *) @@ -71,9 +70,7 @@ module rec T : sig | Ap2 of op2 * t * t | Ap3 of op3 * t * t * t | ApN of opN * t vector - | App of {op: t; arg: t} (* array/struct constants and operations *) - | Record | Struct_rec of {elts: t vector} (** NOTE: may be cyclic *) (* numeric constants *) | Integer of {data: Z.t} @@ -122,7 +119,7 @@ and T0 : sig [@@deriving compare, equal, hash, sexp] type op3 = Conditional [@@deriving compare, equal, hash, sexp] - type opN = Concat [@@deriving compare, equal, hash, sexp] + type opN = Concat | Record [@@deriving compare, equal, hash, sexp] type t = | Add of qset @@ -134,8 +131,6 @@ and T0 : sig | Ap2 of op2 * t * t | Ap3 of op3 * t * t * t | ApN of opN * t vector - | App of {op: t; arg: t} - | Record | Struct_rec of {elts: t vector} | Integer of {data: Z.t} | Float of {data: string} @@ -170,7 +165,7 @@ end = struct [@@deriving compare, equal, hash, sexp] type op3 = Conditional [@@deriving compare, equal, hash, sexp] - type opN = Concat [@@deriving compare, equal, hash, sexp] + type opN = Concat | Record [@@deriving compare, equal, hash, sexp] type t = | Add of qset @@ -182,8 +177,6 @@ end = struct | Ap2 of op2 * t * t | Ap3 of op3 * t * t * t | ApN of opN * t vector - | App of {op: t; arg: t} - | Record | Struct_rec of {elts: t vector} | Integer of {data: Z.t} | Float of {data: string} @@ -214,13 +207,6 @@ let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a = let fix_flip (f : ('z -> t -> 'a as 'f) -> 'f) (bot : 'f) (z : 'z) (e : t) = fix (fun f' e z -> f (fun z e -> f' e z) z e) (fun e z -> bot z e) e z -let uncurry = - let rec uncurry_ acc_args = function - | App {op; arg} -> uncurry_ (arg :: acc_args) op - | op -> (op, acc_args) - in - uncurry_ [] - let rec pp ?is_x fs term = let get_var_style var = match is_x with @@ -280,11 +266,7 @@ let rec pp ?is_x fs term = | Ap1 (Select idx, rcd) -> pf "%a[%i]" pp rcd idx | Ap2 (Update idx, rcd, elt) -> pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt - | Record -> pf "{_}" - | App {op; arg} -> ( - match uncurry term with - | Record, elts -> pf "{%a}" pp_record elts - | _ -> pf "(%a@ %a)" pp op pp arg ) + | ApN (Record, elts) -> pf "{%a}" pp_record elts | Struct_rec {elts} -> pf "{|%a|}" (Vector.pp ",@ " pp) elts | Ap1 (Extract {unsigned; bits}, arg) -> pf "(%s%i)@ %a" (if unsigned then "u" else "i") bits pp arg @@ -301,16 +283,15 @@ and pp_record fs elts = [%Trace.fprintf fs "%a" (fun fs elts -> - let elta = Array.of_list elts in match - String.init (Array.length elta) ~f:(fun i -> - match elta.(i) with + String.init (Vector.length elts) ~f:(fun i -> + match Vector.get elts i with | Integer {data} -> Char.of_int_exn (Z.to_int data) | _ -> raise (Invalid_argument "not a string") ) with | s -> Format.fprintf fs "@[%s@]" (String.escaped s) | exception _ -> - Format.fprintf fs "@[%a@]" (List.pp ",@ " pp) elts ) + Format.fprintf fs "@[%a@]" (Vector.pp ",@ " pp) elts ) elts] type term = t @@ -322,8 +303,7 @@ let pp = pp_t (** Invariant *) (* an indeterminate (factor of a monomial) is any non-Add/Mul/Integer term *) -let rec assert_indeterminate = function - | App {op} -> assert_indeterminate op +let assert_indeterminate = function | Integer _ | Add _ | Mul _ -> assert false | _ -> assert true @@ -369,19 +349,12 @@ let assert_polynomial poly = Qset.iter args ~f:(fun m c -> assert_poly_term m c |> Fn.id) | _ -> assert false -let invariant ?(partial = false) e = +let invariant e = Invariant.invariant [%here] e [%sexp_of: t] @@ fun () -> - let op, args = uncurry e in - let assert_arity arity = - let nargs = List.length args in - assert (nargs = arity || (partial && nargs < arity)) - in - match op with - | App _ -> fail "uncurry cannot return App" () - | Integer _ -> assert_arity 0 - | Var _ | Nondet _ | Label _ | Float _ -> assert_arity 0 - | Ap1 (Extract _, _) -> assert true + match e with + | Var _ | Nondet _ | Label _ | Integer _ | Float _ -> () + | Ap1 (Extract _, _) -> () | Ap1 (Convert {dst; src}, _) -> assert (Typ.convertible src dst) | Add _ -> assert_polynomial e |> Fn.id | Mul _ -> assert_monomial e |> Fn.id @@ -390,20 +363,18 @@ let invariant ?(partial = false) e = | Lshr | Ashr ) , _ , _ ) -> - assert true + () | ApN (Concat, args) -> assert (Vector.length args <> 1) | Ap2 (Splat, _, siz) -> ( match siz with | Integer {data} -> assert (not (Z.equal Z.zero data)) | _ -> () ) - | Ap2 (Memory, _, _) -> assert true - | Ap1 (Select _, _) -> assert true - | Ap3 (Conditional, _, _, _) -> assert true - | Ap2 (Update _, _, _) -> assert true - | Record -> assert (partial || not (List.is_empty args)) - | Struct_rec {elts} -> - assert (not (Vector.is_empty elts)) ; - assert_arity 0 + | Ap2 (Memory, _, _) -> () + | Ap1 (Select _, _) -> () + | Ap3 (Conditional, _, _, _) -> () + | Ap2 (Update _, _, _) -> () + | ApN (Record, elts) -> assert (not (Vector.is_empty elts)) + | Struct_rec {elts} -> assert (not (Vector.is_empty elts)) (** Variables are the terms constructed by [Var] *) module Var = struct @@ -529,7 +500,6 @@ let fold_terms e ~init ~f = | Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s)) | ApN (_, xs) -> Vector.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s - | App {op= x; arg= y} -> fold_terms_ y (fold_terms_ x s) | Add args | Mul args -> Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s) | Struct_rec {elts= args} -> @@ -540,8 +510,6 @@ let fold_terms e ~init ~f = in fix fold_terms_ (fun _ s -> s) e init -let iter_terms e ~f = fold_terms e ~init:() ~f:(fun () e -> f e) - let fold_vars e ~init ~f = fold_terms e ~init ~f:(fun z -> function | Var _ as v -> f z (v :> Var.t) | _ -> z ) @@ -576,6 +544,7 @@ let simp_convert ~unsigned dst src arg = integer (Z.extract ~unsigned (min m n) data) | _ -> Ap1 (Convert {unsigned; dst; src}, arg) +let simp_record elts = ApN (Record, elts) let simp_select idx rcd = Ap1 (Select idx, rcd) let simp_update idx rcd elt = Ap2 (Update idx, rcd, elt) @@ -913,7 +882,6 @@ let iter e ~f = | Ap2 (_, x, y) -> f x ; f y | Ap3 (_, x, y, z) -> f x ; f y ; f z | ApN (_, xs) -> Vector.iter ~f xs - | App {op= x; arg= y} -> f x ; f y | Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args | Struct_rec {elts= args} -> Vector.iter ~f args | _ -> () @@ -924,18 +892,11 @@ let fold e ~init:s ~f = | Ap2 (_, x, y) -> f y (f x s) | Ap3 (_, x, y, z) -> f z (f y (f x s)) | ApN (_, xs) -> Vector.fold ~f:(fun s x -> f x s) xs ~init:s - | App {op= x; arg= y} -> f y (f x s) | Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s | Struct_rec {elts= args} -> Vector.fold ~f:(fun s e -> f e s) args ~init:s | _ -> s -let is_subterm ~sub ~sup = - With_return.with_return - @@ fun {return} -> - iter_terms sup ~f:(fun e -> if equal sub e then return true) ; - false - let norm1 op x = ( match op with | Extract {unsigned; bits} -> simp_extract ~unsigned bits x @@ -968,22 +929,8 @@ let norm3 op x y z = (match op with Conditional -> simp_cond x y z) |> check invariant let normN op xs = - (match op with Concat -> simp_concat xs) |> check invariant - -let app1 ?(partial = false) op arg = - App {op; arg} - |> check (invariant ~partial) - |> check (fun e -> - (* every App subterm of output appears in input *) - iter e ~f:(function - | App _ as a -> - assert ( - is_subterm ~sub:a ~sup:op - || is_subterm ~sub:a ~sup:arg - || Trace.fail - "simplifying %a %a@ yields %a@ with new subterm %a" pp - op pp arg pp e pp a ) - | _ -> () ) ) + (match op with Concat -> simp_concat xs | Record -> simp_record xs) + |> check invariant let addN args = simp_add args |> check invariant let mulN args = simp_mul args |> check invariant @@ -1010,7 +957,7 @@ let shl = norm2 Shl let lshr = norm2 Lshr let ashr = norm2 Ashr let conditional ~cnd ~thn ~els = norm3 Conditional cnd thn els -let record elts = List.fold ~f:app1 ~init:Record elts +let record elts = normN Record elts let select ~rcd ~idx = norm1 (Select idx) rcd let update ~rcd ~idx ~elt = norm2 (Update idx) rcd elt @@ -1089,8 +1036,7 @@ let rec of_exp (e : Exp.t) = update ~rcd:(of_exp rcd) ~idx ~elt:(of_exp elt) | Ap3 (Conditional, _, cnd, thn, els) -> conditional ~cnd:(of_exp cnd) ~thn:(of_exp thn) ~els:(of_exp els) - | ApN (Record, _, elts) -> - record (Vector.to_list (Vector.map ~f:of_exp elts)) + | ApN (Record, _, elts) -> record (Vector.map ~f:of_exp elts) | ApN (Struct_rec, _, elts) -> Staged.unstage (struct_rec (module Exp)) @@ -1143,7 +1089,7 @@ let rename sub e = | Var _ -> Var.Subst.apply sub e | _ -> map ~f:(rename_ sub) e in - rename_ sub e |> check (invariant ~partial:true) + rename_ sub e |> check invariant (** Query *) @@ -1157,7 +1103,6 @@ let rec is_constant e = | Ap2 (_, x, y) -> is_constant x && is_constant y | Ap3 (_, x, y, z) -> is_constant x && is_constant y && is_constant z | ApN (_, xs) -> Vector.for_all ~f:is_constant xs - | App {op= x; arg= y} -> is_constant x && is_constant y | Add args | Mul args -> Qset.for_all ~f:(fun arg _ -> is_constant arg) args | Struct_rec {elts= args} -> Vector.for_all ~f:is_constant args diff --git a/sledge/src/symbheap/term.mli b/sledge/src/symbheap/term.mli index 5144a665b..ff55396ba 100644 --- a/sledge/src/symbheap/term.mli +++ b/sledge/src/symbheap/term.mli @@ -54,7 +54,9 @@ type op2 = type op3 = Conditional (** If-then-else *) [@@deriving compare, equal, hash, sexp] -type opN = Concat (** Byte-array concatenation *) +type opN = + | Concat (** Byte-array concatenation *) + | Record (** Record (array / struct) constant *) [@@deriving compare, equal, hash, sexp] type qset = (t, comparator_witness) Qset.t @@ -72,9 +74,6 @@ and t = private | Ap2 of op2 * t * t | Ap3 of op3 * t * t * t | ApN of opN * t vector - | App of {op: t; arg: t} - (** Application of function symbol to argument, curried *) - | Record (** Record (array / struct) constant *) | Struct_rec of {elts: t vector} (** Struct constant that may recursively refer to itself (transitively) from [elts]. NOTE: represented by cyclic values. *) @@ -90,7 +89,7 @@ type term = t val pp_full : ?is_x:(term -> bool) -> t pp val pp : t pp -val invariant : ?partial:bool -> t -> unit +val invariant : t -> unit (** Term.Var is re-exported as Var *) module Var : sig @@ -174,7 +173,7 @@ val shl : t -> t -> t val lshr : t -> t -> t val ashr : t -> t -> t val conditional : cnd:t -> thn:t -> els:t -> t -val record : t list -> t +val record : t vector -> t val select : rcd:t -> idx:int -> t val update : rcd:t -> idx:int -> elt:t -> t val extract : ?unsigned:bool -> bits:int -> t -> t