[sledge] Change Struct_rec to a generic n-ary recursive application

Reviewed By: bennostein

Differential Revision: D17665266

fbshipit-source-id: dd938ac31
master
Josh Berdine 5 years ago committed by Facebook Github Bot
parent 356b4f0b4e
commit 7ecd091ff3

@ -56,6 +56,7 @@ module rec T : sig
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
type opN = Concat | Record [@@deriving compare, equal, hash, sexp] type opN = Concat | Record [@@deriving compare, equal, hash, sexp]
type recN = Record [@@deriving compare, equal, hash, sexp]
type t = type t =
(* nary arithmetic *) (* nary arithmetic *)
@ -70,8 +71,8 @@ module rec T : sig
| Ap2 of op2 * t * t | Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t | Ap3 of op3 * t * t * t
| ApN of opN * t vector | ApN of opN * t vector
(* array/struct constants and operations *) (* recursive application *)
| Struct_rec of {elts: t vector} (** NOTE: may be cyclic *) | RecN of recN * t vector (** NOTE: cyclic *)
(* numeric constants *) (* numeric constants *)
| Integer of {data: Z.t} | Integer of {data: Z.t}
| Float of {data: string} | Float of {data: string}
@ -120,6 +121,7 @@ and T0 : sig
type op3 = Conditional [@@deriving compare, equal, hash, sexp] type op3 = Conditional [@@deriving compare, equal, hash, sexp]
type opN = Concat | Record [@@deriving compare, equal, hash, sexp] type opN = Concat | Record [@@deriving compare, equal, hash, sexp]
type recN = Record [@@deriving compare, equal, hash, sexp]
type t = type t =
| Add of qset | Add of qset
@ -131,7 +133,7 @@ and T0 : sig
| Ap2 of op2 * t * t | Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t | Ap3 of op3 * t * t * t
| ApN of opN * t vector | ApN of opN * t vector
| Struct_rec of {elts: t vector} | RecN of recN * t vector
| Integer of {data: Z.t} | Integer of {data: Z.t}
| Float of {data: string} | Float of {data: string}
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
@ -166,6 +168,7 @@ end = struct
type op3 = Conditional [@@deriving compare, equal, hash, sexp] type op3 = Conditional [@@deriving compare, equal, hash, sexp]
type opN = Concat | Record [@@deriving compare, equal, hash, sexp] type opN = Concat | Record [@@deriving compare, equal, hash, sexp]
type recN = Record [@@deriving compare, equal, hash, sexp]
type t = type t =
| Add of qset | Add of qset
@ -177,7 +180,7 @@ end = struct
| Ap2 of op2 * t * t | Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t | Ap3 of op3 * t * t * t
| ApN of opN * t vector | ApN of opN * t vector
| Struct_rec of {elts: t vector} | RecN of recN * t vector
| Integer of {data: Z.t} | Integer of {data: Z.t}
| Float of {data: string} | Float of {data: string}
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
@ -194,13 +197,13 @@ let empty_qset = Qset.empty (module T)
let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a = let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a =
let rec fix_f seen e = let rec fix_f seen e =
match e with match e with
| Struct_rec _ -> | RecN _ ->
if List.mem ~equal:( == ) seen e then f bot e if List.mem ~equal:( == ) seen e then f bot e
else f (fix_f (e :: seen)) e else f (fix_f (e :: seen)) e
| _ -> f (fix_f seen) e | _ -> f (fix_f seen) e
in in
let rec fix_f_seen_nil e = let rec fix_f_seen_nil e =
match e with Struct_rec _ -> f (fix_f [e]) e | _ -> f fix_f_seen_nil e match e with RecN _ -> f (fix_f [e]) e | _ -> f fix_f_seen_nil e
in in
fix_f_seen_nil e fix_f_seen_nil e
@ -267,7 +270,7 @@ let rec pp ?is_x fs term =
| Ap2 (Update idx, rcd, elt) -> | Ap2 (Update idx, rcd, elt) ->
pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt
| ApN (Record, elts) -> pf "{%a}" pp_record elts | ApN (Record, elts) -> pf "{%a}" pp_record elts
| Struct_rec {elts} -> pf "{|%a|}" (Vector.pp ",@ " pp) elts | RecN (Record, elts) -> pf "{|%a|}" (Vector.pp ",@ " pp) elts
| Ap1 (Extract {unsigned; bits}, arg) -> | Ap1 (Extract {unsigned; bits}, arg) ->
pf "(%s%i)@ %a" (if unsigned then "u" else "i") bits pp arg pf "(%s%i)@ %a" (if unsigned then "u" else "i") bits pp arg
| Ap1 (Convert {dst; unsigned= true; src= Integer {bits}}, arg) -> | Ap1 (Convert {dst; unsigned= true; src= Integer {bits}}, arg) ->
@ -373,8 +376,8 @@ let invariant e =
| Ap1 (Select _, _) -> () | Ap1 (Select _, _) -> ()
| Ap3 (Conditional, _, _, _) -> () | Ap3 (Conditional, _, _, _) -> ()
| Ap2 (Update _, _, _) -> () | Ap2 (Update _, _, _) -> ()
| ApN (Record, elts) -> assert (not (Vector.is_empty elts)) | ApN (Record, elts) | RecN (Record, elts) ->
| Struct_rec {elts} -> assert (not (Vector.is_empty elts)) assert (not (Vector.is_empty elts))
(** Variables are the terms constructed by [Var] *) (** Variables are the terms constructed by [Var] *)
module Var = struct module Var = struct
@ -498,12 +501,10 @@ let fold_terms e ~init ~f =
| Ap1 (_, x) -> fold_terms_ x s | Ap1 (_, x) -> fold_terms_ x s
| Ap2 (_, x, y) -> fold_terms_ y (fold_terms_ x s) | Ap2 (_, x, y) -> fold_terms_ y (fold_terms_ x s)
| Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s)) | Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s))
| ApN (_, xs) -> | ApN (_, xs) | RecN (_, xs) ->
Vector.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s Vector.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s
| Add args | Mul args -> | Add args | Mul args ->
Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s) Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s)
| Struct_rec {elts= args} ->
Vector.fold args ~init:s ~f:(fun s elt -> fold_terms_ elt s)
| _ -> s | _ -> s
in in
f s e f s e
@ -881,9 +882,8 @@ let iter e ~f =
| Ap1 (_, x) -> f x | Ap1 (_, x) -> f x
| Ap2 (_, x, y) -> f x ; f y | Ap2 (_, x, y) -> f x ; f y
| Ap3 (_, x, y, z) -> f x ; f y ; f z | Ap3 (_, x, y, z) -> f x ; f y ; f z
| ApN (_, xs) -> Vector.iter ~f xs | ApN (_, xs) | RecN (_, xs) -> Vector.iter ~f xs
| Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args | Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args
| Struct_rec {elts= args} -> Vector.iter ~f args
| _ -> () | _ -> ()
let fold e ~init:s ~f = let fold e ~init:s ~f =
@ -891,10 +891,9 @@ let fold e ~init:s ~f =
| Ap1 (_, x) -> f x s | Ap1 (_, x) -> f x s
| Ap2 (_, x, y) -> f y (f x s) | Ap2 (_, x, y) -> f y (f x s)
| Ap3 (_, x, y, z) -> f z (f y (f x s)) | Ap3 (_, x, y, z) -> f z (f y (f x s))
| ApN (_, xs) -> Vector.fold ~f:(fun s x -> f x s) xs ~init:s | ApN (_, xs) | RecN (_, xs) ->
Vector.fold ~f:(fun s x -> f x s) xs ~init:s
| Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s | Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s
| Struct_rec {elts= args} ->
Vector.fold ~f:(fun s e -> f e s) args ~init:s
| _ -> s | _ -> s
let norm1 op x = let norm1 op x =
@ -961,27 +960,27 @@ let record elts = normN Record elts
let select ~rcd ~idx = norm1 (Select idx) rcd let select ~rcd ~idx = norm1 (Select idx) rcd
let update ~rcd ~idx ~elt = norm2 (Update idx) rcd elt let update ~rcd ~idx ~elt = norm2 (Update idx) rcd elt
let struct_rec key = let rec_app key =
let memo_id = Hashtbl.create key in let memo_id = Hashtbl.create key in
let dummy = null in let dummy = null in
Staged.stage Staged.stage
@@ fun ~id elt_thks -> @@ fun ~id op elt_thks ->
match Hashtbl.find memo_id id with match Hashtbl.find memo_id id with
| None -> | None ->
(* Add placeholder to prevent computing [elts] in calls to (* Add placeholder to prevent computing [elts] in calls to [rec_app]
[struct_rec] from [elt_thks] for recursive occurrences of [id]. *) from [elt_thks] for recursive occurrences of [id]. *)
let elta = Array.create ~len:(Vector.length elt_thks) dummy in let elta = Array.create ~len:(Vector.length elt_thks) dummy in
let elts = Vector.of_array elta in let elts = Vector.of_array elta in
Hashtbl.set memo_id ~key:id ~data:elts ; Hashtbl.set memo_id ~key:id ~data:elts ;
Vector.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ; Vector.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ;
Struct_rec {elts} |> check invariant RecN (op, elts) |> check invariant
| Some elts -> | Some elts ->
(* Do not check invariant as invariant will be checked above after the (* Do not check invariant as invariant will be checked above after the
thunks are forced, before which invariant-checking may spuriously thunks are forced, before which invariant-checking may spuriously
fail. Note that it is important that the value constructed here fail. Note that it is important that the value constructed here
shares the array in the memo table, so that the update after shares the array in the memo table, so that the update after
forcing the recursive thunks also updates this value. *) forcing the recursive thunks also updates this value. *)
Struct_rec {elts} RecN (op, elts)
let extract ?(unsigned = false) ~bits term = let extract ?(unsigned = false) ~bits term =
norm1 (Extract {unsigned; bits}) term norm1 (Extract {unsigned; bits}) term
@ -993,7 +992,9 @@ let size_of t =
Option.bind (Typ.prim_bit_size_of t) ~f:(fun n -> Option.bind (Typ.prim_bit_size_of t) ~f:(fun n ->
if n % 8 = 0 then Some (integer (Z.of_int (n / 8))) else None ) if n % 8 = 0 then Some (integer (Z.of_int (n / 8))) else None )
let rec of_exp (e : Exp.t) = let of_exp e =
let rec_app = Staged.unstage (rec_app (module Exp)) in
let rec of_exp e =
let unsigned op typ x y = let unsigned op typ x y =
match Typ.prim_bit_size_of typ with match Typ.prim_bit_size_of typ with
| Some bits -> | Some bits ->
@ -1038,10 +1039,9 @@ let rec of_exp (e : Exp.t) =
conditional ~cnd:(of_exp cnd) ~thn:(of_exp thn) ~els:(of_exp els) conditional ~cnd:(of_exp cnd) ~thn:(of_exp thn) ~els:(of_exp els)
| ApN (Record, _, elts) -> record (Vector.map ~f:of_exp elts) | ApN (Record, _, elts) -> record (Vector.map ~f:of_exp elts)
| ApN (Struct_rec, _, elts) -> | ApN (Struct_rec, _, elts) ->
Staged.unstage rec_app ~id:e Record (Vector.map ~f:(fun e -> lazy (of_exp e)) elts)
(struct_rec (module Exp)) in
~id:e of_exp e
(Vector.map elts ~f:(fun e -> lazy (of_exp e)))
(** Transform *) (** Transform *)
@ -1074,11 +1074,11 @@ let map e ~f =
match e with match e with
| Add args -> map_qset addN ~f args | Add args -> map_qset addN ~f args
| Mul args -> map_qset mulN ~f args | Mul args -> map_qset mulN ~f args
| Struct_rec {elts= args} -> Struct_rec {elts= Vector.map args ~f:map_}
| Ap1 (op, x) -> map1 op ~f x | Ap1 (op, x) -> map1 op ~f x
| Ap2 (op, x, y) -> map2 op ~f x y | Ap2 (op, x, y) -> map2 op ~f x y
| Ap3 (op, x, y, z) -> map3 op ~f x y z | Ap3 (op, x, y, z) -> map3 op ~f x y z
| ApN (op, xs) -> mapN op ~f xs | ApN (op, xs) -> mapN op ~f xs
| RecN (op, xs) -> RecN (op, Vector.map ~f:map_ xs)
| _ -> e | _ -> e
in in
fix map_ (fun e -> e) e fix map_ (fun e -> e) e
@ -1102,17 +1102,16 @@ let rec is_constant e =
| Ap1 (_, x) -> is_constant x | Ap1 (_, x) -> is_constant x
| Ap2 (_, x, y) -> is_constant x && is_constant y | Ap2 (_, x, y) -> is_constant x && is_constant y
| Ap3 (_, x, y, z) -> is_constant x && is_constant y && is_constant z | Ap3 (_, x, y, z) -> is_constant x && is_constant y && is_constant z
| ApN (_, xs) -> Vector.for_all ~f:is_constant xs | ApN (_, xs) | RecN (_, xs) -> Vector.for_all ~f:is_constant xs
| Add args | Mul args -> | Add args | Mul args ->
Qset.for_all ~f:(fun arg _ -> is_constant arg) args Qset.for_all ~f:(fun arg _ -> is_constant arg) args
| Struct_rec {elts= args} -> Vector.for_all ~f:is_constant args
| _ -> true | _ -> true
let classify = function let classify = function
| Add _ | Mul _ -> `Interpreted | Add _ | Mul _ -> `Interpreted
| Ap2 ((Eq | Dq), _, _) -> `Simplified | Ap2 ((Eq | Dq), _, _) -> `Simplified
| Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> `Uninterpreted | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> `Uninterpreted
| _ -> `Atomic | RecN _ | Var _ | Nondet _ | Label _ | Integer _ | Float _ -> `Atomic
let solve e f = let solve e f =
[%Trace.call fun {pf} -> pf "%a@ %a" pp e pp f] [%Trace.call fun {pf} -> pf "%a@ %a" pp e pp f]

@ -59,6 +59,12 @@ type opN =
| Record (** Record (array / struct) constant *) | Record (** Record (array / struct) constant *)
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
type recN =
| Record
(** Record constant that may recursively refer to itself
(transitively) from its args. NOTE: represented by cyclic values. *)
[@@deriving compare, equal, hash, sexp]
type qset = (t, comparator_witness) Qset.t type qset = (t, comparator_witness) Qset.t
and t = private and t = private
@ -74,9 +80,7 @@ and t = private
| Ap2 of op2 * t * t | Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t | Ap3 of op3 * t * t * t
| ApN of opN * t vector | ApN of opN * t vector
| Struct_rec of {elts: t vector} | RecN of recN * t vector
(** Struct constant that may recursively refer to itself
(transitively) from [elts]. NOTE: represented by cyclic values. *)
| Integer of {data: Z.t} | Integer of {data: Z.t}
(** Integer constant, or if [typ] is a [Pointer], null pointer value (** Integer constant, or if [typ] is a [Pointer], null pointer value
that never refers to an object *) that never refers to an object *)

Loading…
Cancel
Save