[sledge] Uncurry Record term constructor

Reviewed By: bennostein

Differential Revision: D17665260

fbshipit-source-id: 080f47739
master
Josh Berdine 5 years ago committed by Facebook Github Bot
parent 99b60d191a
commit 356b4f0b4e

@ -55,8 +55,7 @@ module rec T : sig
| Conditional
[@@deriving compare, equal, hash, sexp]
type opN = (* memory *)
| Concat [@@deriving compare, equal, hash, sexp]
type opN = Concat | Record [@@deriving compare, equal, hash, sexp]
type t =
(* nary arithmetic *)
@ -71,9 +70,7 @@ module rec T : sig
| Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t
| ApN of opN * t vector
| App of {op: t; arg: t}
(* array/struct constants and operations *)
| Record
| Struct_rec of {elts: t vector} (** NOTE: may be cyclic *)
(* numeric constants *)
| Integer of {data: Z.t}
@ -122,7 +119,7 @@ and T0 : sig
[@@deriving compare, equal, hash, sexp]
type op3 = Conditional [@@deriving compare, equal, hash, sexp]
type opN = Concat [@@deriving compare, equal, hash, sexp]
type opN = Concat | Record [@@deriving compare, equal, hash, sexp]
type t =
| Add of qset
@ -134,8 +131,6 @@ and T0 : sig
| Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t
| ApN of opN * t vector
| App of {op: t; arg: t}
| Record
| Struct_rec of {elts: t vector}
| Integer of {data: Z.t}
| Float of {data: string}
@ -170,7 +165,7 @@ end = struct
[@@deriving compare, equal, hash, sexp]
type op3 = Conditional [@@deriving compare, equal, hash, sexp]
type opN = Concat [@@deriving compare, equal, hash, sexp]
type opN = Concat | Record [@@deriving compare, equal, hash, sexp]
type t =
| Add of qset
@ -182,8 +177,6 @@ end = struct
| Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t
| ApN of opN * t vector
| App of {op: t; arg: t}
| Record
| Struct_rec of {elts: t vector}
| Integer of {data: Z.t}
| Float of {data: string}
@ -214,13 +207,6 @@ let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a =
let fix_flip (f : ('z -> t -> 'a as 'f) -> 'f) (bot : 'f) (z : 'z) (e : t) =
fix (fun f' e z -> f (fun z e -> f' e z) z e) (fun e z -> bot z e) e z
let uncurry =
let rec uncurry_ acc_args = function
| App {op; arg} -> uncurry_ (arg :: acc_args) op
| op -> (op, acc_args)
in
uncurry_ []
let rec pp ?is_x fs term =
let get_var_style var =
match is_x with
@ -280,11 +266,7 @@ let rec pp ?is_x fs term =
| Ap1 (Select idx, rcd) -> pf "%a[%i]" pp rcd idx
| Ap2 (Update idx, rcd, elt) ->
pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt
| Record -> pf "{_}"
| App {op; arg} -> (
match uncurry term with
| Record, elts -> pf "{%a}" pp_record elts
| _ -> pf "(%a@ %a)" pp op pp arg )
| ApN (Record, elts) -> pf "{%a}" pp_record elts
| Struct_rec {elts} -> pf "{|%a|}" (Vector.pp ",@ " pp) elts
| Ap1 (Extract {unsigned; bits}, arg) ->
pf "(%s%i)@ %a" (if unsigned then "u" else "i") bits pp arg
@ -301,16 +283,15 @@ and pp_record fs elts =
[%Trace.fprintf
fs "%a"
(fun fs elts ->
let elta = Array.of_list elts in
match
String.init (Array.length elta) ~f:(fun i ->
match elta.(i) with
String.init (Vector.length elts) ~f:(fun i ->
match Vector.get elts i with
| Integer {data} -> Char.of_int_exn (Z.to_int data)
| _ -> raise (Invalid_argument "not a string") )
with
| s -> Format.fprintf fs "@[<h>%s@]" (String.escaped s)
| exception _ ->
Format.fprintf fs "@[<h>%a@]" (List.pp ",@ " pp) elts )
Format.fprintf fs "@[<h>%a@]" (Vector.pp ",@ " pp) elts )
elts]
type term = t
@ -322,8 +303,7 @@ let pp = pp_t
(** Invariant *)
(* an indeterminate (factor of a monomial) is any non-Add/Mul/Integer term *)
let rec assert_indeterminate = function
| App {op} -> assert_indeterminate op
let assert_indeterminate = function
| Integer _ | Add _ | Mul _ -> assert false
| _ -> assert true
@ -369,19 +349,12 @@ let assert_polynomial poly =
Qset.iter args ~f:(fun m c -> assert_poly_term m c |> Fn.id)
| _ -> assert false
let invariant ?(partial = false) e =
let invariant e =
Invariant.invariant [%here] e [%sexp_of: t]
@@ fun () ->
let op, args = uncurry e in
let assert_arity arity =
let nargs = List.length args in
assert (nargs = arity || (partial && nargs < arity))
in
match op with
| App _ -> fail "uncurry cannot return App" ()
| Integer _ -> assert_arity 0
| Var _ | Nondet _ | Label _ | Float _ -> assert_arity 0
| Ap1 (Extract _, _) -> assert true
match e with
| Var _ | Nondet _ | Label _ | Integer _ | Float _ -> ()
| Ap1 (Extract _, _) -> ()
| Ap1 (Convert {dst; src}, _) -> assert (Typ.convertible src dst)
| Add _ -> assert_polynomial e |> Fn.id
| Mul _ -> assert_monomial e |> Fn.id
@ -390,20 +363,18 @@ let invariant ?(partial = false) e =
| Lshr | Ashr )
, _
, _ ) ->
assert true
()
| ApN (Concat, args) -> assert (Vector.length args <> 1)
| Ap2 (Splat, _, siz) -> (
match siz with
| Integer {data} -> assert (not (Z.equal Z.zero data))
| _ -> () )
| Ap2 (Memory, _, _) -> assert true
| Ap1 (Select _, _) -> assert true
| Ap3 (Conditional, _, _, _) -> assert true
| Ap2 (Update _, _, _) -> assert true
| Record -> assert (partial || not (List.is_empty args))
| Struct_rec {elts} ->
assert (not (Vector.is_empty elts)) ;
assert_arity 0
| Ap2 (Memory, _, _) -> ()
| Ap1 (Select _, _) -> ()
| Ap3 (Conditional, _, _, _) -> ()
| Ap2 (Update _, _, _) -> ()
| ApN (Record, elts) -> assert (not (Vector.is_empty elts))
| Struct_rec {elts} -> assert (not (Vector.is_empty elts))
(** Variables are the terms constructed by [Var] *)
module Var = struct
@ -529,7 +500,6 @@ let fold_terms e ~init ~f =
| Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s))
| ApN (_, xs) ->
Vector.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s
| App {op= x; arg= y} -> fold_terms_ y (fold_terms_ x s)
| Add args | Mul args ->
Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s)
| Struct_rec {elts= args} ->
@ -540,8 +510,6 @@ let fold_terms e ~init ~f =
in
fix fold_terms_ (fun _ s -> s) e init
let iter_terms e ~f = fold_terms e ~init:() ~f:(fun () e -> f e)
let fold_vars e ~init ~f =
fold_terms e ~init ~f:(fun z -> function
| Var _ as v -> f z (v :> Var.t) | _ -> z )
@ -576,6 +544,7 @@ let simp_convert ~unsigned dst src arg =
integer (Z.extract ~unsigned (min m n) data)
| _ -> Ap1 (Convert {unsigned; dst; src}, arg)
let simp_record elts = ApN (Record, elts)
let simp_select idx rcd = Ap1 (Select idx, rcd)
let simp_update idx rcd elt = Ap2 (Update idx, rcd, elt)
@ -913,7 +882,6 @@ let iter e ~f =
| Ap2 (_, x, y) -> f x ; f y
| Ap3 (_, x, y, z) -> f x ; f y ; f z
| ApN (_, xs) -> Vector.iter ~f xs
| App {op= x; arg= y} -> f x ; f y
| Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args
| Struct_rec {elts= args} -> Vector.iter ~f args
| _ -> ()
@ -924,18 +892,11 @@ let fold e ~init:s ~f =
| Ap2 (_, x, y) -> f y (f x s)
| Ap3 (_, x, y, z) -> f z (f y (f x s))
| ApN (_, xs) -> Vector.fold ~f:(fun s x -> f x s) xs ~init:s
| App {op= x; arg= y} -> f y (f x s)
| Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s
| Struct_rec {elts= args} ->
Vector.fold ~f:(fun s e -> f e s) args ~init:s
| _ -> s
let is_subterm ~sub ~sup =
With_return.with_return
@@ fun {return} ->
iter_terms sup ~f:(fun e -> if equal sub e then return true) ;
false
let norm1 op x =
( match op with
| Extract {unsigned; bits} -> simp_extract ~unsigned bits x
@ -968,22 +929,8 @@ let norm3 op x y z =
(match op with Conditional -> simp_cond x y z) |> check invariant
let normN op xs =
(match op with Concat -> simp_concat xs) |> check invariant
let app1 ?(partial = false) op arg =
App {op; arg}
|> check (invariant ~partial)
|> check (fun e ->
(* every App subterm of output appears in input *)
iter e ~f:(function
| App _ as a ->
assert (
is_subterm ~sub:a ~sup:op
|| is_subterm ~sub:a ~sup:arg
|| Trace.fail
"simplifying %a %a@ yields %a@ with new subterm %a" pp
op pp arg pp e pp a )
| _ -> () ) )
(match op with Concat -> simp_concat xs | Record -> simp_record xs)
|> check invariant
let addN args = simp_add args |> check invariant
let mulN args = simp_mul args |> check invariant
@ -1010,7 +957,7 @@ let shl = norm2 Shl
let lshr = norm2 Lshr
let ashr = norm2 Ashr
let conditional ~cnd ~thn ~els = norm3 Conditional cnd thn els
let record elts = List.fold ~f:app1 ~init:Record elts
let record elts = normN Record elts
let select ~rcd ~idx = norm1 (Select idx) rcd
let update ~rcd ~idx ~elt = norm2 (Update idx) rcd elt
@ -1089,8 +1036,7 @@ let rec of_exp (e : Exp.t) =
update ~rcd:(of_exp rcd) ~idx ~elt:(of_exp elt)
| Ap3 (Conditional, _, cnd, thn, els) ->
conditional ~cnd:(of_exp cnd) ~thn:(of_exp thn) ~els:(of_exp els)
| ApN (Record, _, elts) ->
record (Vector.to_list (Vector.map ~f:of_exp elts))
| ApN (Record, _, elts) -> record (Vector.map ~f:of_exp elts)
| ApN (Struct_rec, _, elts) ->
Staged.unstage
(struct_rec (module Exp))
@ -1143,7 +1089,7 @@ let rename sub e =
| Var _ -> Var.Subst.apply sub e
| _ -> map ~f:(rename_ sub) e
in
rename_ sub e |> check (invariant ~partial:true)
rename_ sub e |> check invariant
(** Query *)
@ -1157,7 +1103,6 @@ let rec is_constant e =
| Ap2 (_, x, y) -> is_constant x && is_constant y
| Ap3 (_, x, y, z) -> is_constant x && is_constant y && is_constant z
| ApN (_, xs) -> Vector.for_all ~f:is_constant xs
| App {op= x; arg= y} -> is_constant x && is_constant y
| Add args | Mul args ->
Qset.for_all ~f:(fun arg _ -> is_constant arg) args
| Struct_rec {elts= args} -> Vector.for_all ~f:is_constant args

@ -54,7 +54,9 @@ type op2 =
type op3 = Conditional (** If-then-else *)
[@@deriving compare, equal, hash, sexp]
type opN = Concat (** Byte-array concatenation *)
type opN =
| Concat (** Byte-array concatenation *)
| Record (** Record (array / struct) constant *)
[@@deriving compare, equal, hash, sexp]
type qset = (t, comparator_witness) Qset.t
@ -72,9 +74,6 @@ and t = private
| Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t
| ApN of opN * t vector
| App of {op: t; arg: t}
(** Application of function symbol to argument, curried *)
| Record (** Record (array / struct) constant *)
| Struct_rec of {elts: t vector}
(** Struct constant that may recursively refer to itself
(transitively) from [elts]. NOTE: represented by cyclic values. *)
@ -90,7 +89,7 @@ type term = t
val pp_full : ?is_x:(term -> bool) -> t pp
val pp : t pp
val invariant : ?partial:bool -> t -> unit
val invariant : t -> unit
(** Term.Var is re-exported as Var *)
module Var : sig
@ -174,7 +173,7 @@ val shl : t -> t -> t
val lshr : t -> t -> t
val ashr : t -> t -> t
val conditional : cnd:t -> thn:t -> els:t -> t
val record : t list -> t
val record : t vector -> t
val select : rcd:t -> idx:int -> t
val update : rcd:t -> idx:int -> elt:t -> t
val extract : ?unsigned:bool -> bits:int -> t -> t

Loading…
Cancel
Save