[sledge] Represent recursive records non-recursively

Summary:
In LLVM it is possible for struct constant values to be directly
recursive, with no pointer dereference to close the cycle. These
appear for example as the values of vtables from C++ code.

Currently such recursive records in the Exp and Term languages are
represented as genuinely cyclic values. Compared to a standard term
representation, the presence of cyclic values is a significant
complication everywhere. Since the backend solver does not do anything
such as induction over these, they have to be treated as essentially
atomic.

This patch changes the representation to a standard non-recursive tree
term structure. Instead of cyclic references, an explicit constructor
is added for the "non-tree edges", which simply indicates which
ancestor record value to which the recursive reference points.

There is a potential issue with this representation, since for
mutually recursive records, the representation is not canonical: it
chooses one of the records in the cycle to start from and expresses
the cycles relative to that. Currently the choice of representation is
dictated by the frontend. For the case of vtables, the frontend
translates globals in the same order they appear in the LLVM IR, so
the representation choice is fixed.

It may turn out that other potential uses require more reasoning
support in the backend solver, which would involve a theory of
equality of record values induced by equating the representations
resulting from different rotations of the cycle of records.

Reviewed By: jvillard

Differential Revision: D21441533

fbshipit-source-id: 0c5a11378
master
Josh Berdine 5 years ago committed by Facebook GitHub Bot
parent 849c61221d
commit 9d9060d213

@ -343,16 +343,6 @@ let should_inline : Llvm.llvalue -> bool =
| None -> true ) | None -> true )
| None -> true | None -> true
module Llvalue = struct
type t = Llvm.llvalue
let hash = Hashtbl.hash
let compare = Poly.compare
let sexp_of_t llv = Sexp.Atom (Llvm.string_of_llvalue llv)
end
let struct_rec = Staged.unstage (Exp.struct_rec (module Llvalue))
let ptr_fld x ~ptr ~fld ~lltyp = let ptr_fld x ~ptr ~fld ~lltyp =
let offset = let offset =
Llvm_target.DataLayout.offset_of_element lltyp fld x.lldatalayout Llvm_target.DataLayout.offset_of_element lltyp fld x.lldatalayout
@ -377,34 +367,34 @@ let convert_to_siz =
let xlate_llvm_eh_typeid_for : x -> Typ.t -> Exp.t -> Exp.t = let xlate_llvm_eh_typeid_for : x -> Typ.t -> Exp.t -> Exp.t =
fun x typ arg -> Exp.convert typ ~to_:(i32 x) arg fun x typ arg -> Exp.convert typ ~to_:(i32 x) arg
let rec xlate_intrinsic_exp : string -> (x -> Llvm.llvalue -> Exp.t) option let rec xlate_intrinsic_exp stk :
= string -> (x -> Llvm.llvalue -> Exp.t) option =
fun name -> fun name ->
match name with match name with
| "llvm.eh.typeid.for" -> | "llvm.eh.typeid.for" ->
Some Some
(fun x llv -> (fun x llv ->
let rand = Llvm.operand llv 0 in let rand = Llvm.operand llv 0 in
let arg = xlate_value x rand in let arg = xlate_value stk x rand in
let src = xlate_type x (Llvm.type_of rand) in let src = xlate_type x (Llvm.type_of rand) in
xlate_llvm_eh_typeid_for x src arg ) xlate_llvm_eh_typeid_for x src arg )
| _ -> None | _ -> None
and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t = and xlate_value ?(inline = false) stk : x -> Llvm.llvalue -> Exp.t =
fun x llv -> fun x llv ->
let xlate_value_ llv = let xlate_value_ llv =
match Llvm.classify_value llv with match Llvm.classify_value llv with
| Instruction Call -> ( | Instruction Call -> (
let func = Llvm.operand llv (Llvm.num_arg_operands llv) in let func = Llvm.operand llv (Llvm.num_arg_operands llv) in
let fname = Llvm.value_name func in let fname = Llvm.value_name func in
match xlate_intrinsic_exp fname with match xlate_intrinsic_exp stk fname with
| Some intrinsic when inline || should_inline llv -> intrinsic x llv | Some intrinsic when inline || should_inline llv -> intrinsic x llv
| _ -> Exp.reg (xlate_name x llv) ) | _ -> Exp.reg (xlate_name x llv) )
| Instruction (Invoke | Alloca | Load | PHI | LandingPad | VAArg) | Instruction (Invoke | Alloca | Load | PHI | LandingPad | VAArg)
|Argument -> |Argument ->
Exp.reg (xlate_name x llv) Exp.reg (xlate_name x llv)
| Function | GlobalVariable -> Exp.reg (xlate_global x llv).reg | Function | GlobalVariable -> Exp.reg (xlate_global stk x llv).reg
| GlobalAlias -> xlate_value x (Llvm.operand llv 0) | GlobalAlias -> xlate_value stk x (Llvm.operand llv 0)
| ConstantInt -> xlate_int x llv | ConstantInt -> xlate_int x llv
| ConstantFP -> xlate_float x llv | ConstantFP -> xlate_float x llv
| ConstantPointerNull -> Exp.null | ConstantPointerNull -> Exp.null
@ -419,35 +409,27 @@ and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t =
| ConstantVector | ConstantArray -> | ConstantVector | ConstantArray ->
let typ = xlate_type x (Llvm.type_of llv) in let typ = xlate_type x (Llvm.type_of llv) in
let len = Llvm.num_operands llv in let len = Llvm.num_operands llv in
let f i = xlate_value x (Llvm.operand llv i) in let f i = xlate_value stk x (Llvm.operand llv i) in
Exp.record typ (IArray.init len ~f) Exp.record typ (IArray.init len ~f)
| ConstantDataVector -> | ConstantDataVector ->
let typ = xlate_type x (Llvm.type_of llv) in let typ = xlate_type x (Llvm.type_of llv) in
let len = Llvm.vector_size (Llvm.type_of llv) in let len = Llvm.vector_size (Llvm.type_of llv) in
let f i = xlate_value x (Llvm.const_element llv i) in let f i = xlate_value stk x (Llvm.const_element llv i) in
Exp.record typ (IArray.init len ~f) Exp.record typ (IArray.init len ~f)
| ConstantDataArray -> | ConstantDataArray ->
let typ = xlate_type x (Llvm.type_of llv) in let typ = xlate_type x (Llvm.type_of llv) in
let len = Llvm.array_length (Llvm.type_of llv) in let len = Llvm.array_length (Llvm.type_of llv) in
let f i = xlate_value x (Llvm.const_element llv i) in let f i = xlate_value stk x (Llvm.const_element llv i) in
Exp.record typ (IArray.init len ~f) Exp.record typ (IArray.init len ~f)
| ConstantStruct -> | ConstantStruct -> (
let typ = xlate_type x (Llvm.type_of llv) in let typ = xlate_type x (Llvm.type_of llv) in
let is_recursive = match List.findi llv stk with
Llvm.fold_left_uses | Some i -> Exp.rec_record i typ
(fun b use -> b || llv == Llvm.used_value use) | None ->
false llv let stk = llv :: stk in
in
if is_recursive then
let elt_thks =
IArray.init (Llvm.num_operands llv) ~f:(fun i ->
lazy (xlate_value x (Llvm.operand llv i)) )
in
struct_rec ~id:llv typ elt_thks
else
Exp.record typ Exp.record typ
(IArray.init (Llvm.num_operands llv) ~f:(fun i -> (IArray.init (Llvm.num_operands llv) ~f:(fun i ->
xlate_value x (Llvm.operand llv i) )) xlate_value stk x (Llvm.operand llv i) )) )
| BlockAddress -> | BlockAddress ->
let parent = find_name (Llvm.operand llv 0) in let parent = find_name (Llvm.operand llv 0) in
let name = find_name (Llvm.operand llv 1) in let name = find_name (Llvm.operand llv 1) in
@ -462,9 +444,9 @@ and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t =
| SRem | FRem | Shl | LShr | AShr | And | Or | Xor | ICmp | FCmp | SRem | FRem | Shl | LShr | AShr | And | Or | Xor | ICmp | FCmp
| Select | GetElementPtr | ExtractElement | InsertElement | Select | GetElementPtr | ExtractElement | InsertElement
| ShuffleVector | ExtractValue | InsertValue ) as opcode ) -> | ShuffleVector | ExtractValue | InsertValue ) as opcode ) ->
if inline || should_inline llv then xlate_opcode x llv opcode if inline || should_inline llv then xlate_opcode stk x llv opcode
else Exp.reg (xlate_name x llv) else Exp.reg (xlate_name x llv)
| ConstantExpr -> xlate_opcode x llv (Llvm.constexpr_opcode llv) | ConstantExpr -> xlate_opcode stk x llv (Llvm.constexpr_opcode llv)
| GlobalIFunc -> todo "ifuncs: %a" pp_llvalue llv () | GlobalIFunc -> todo "ifuncs: %a" pp_llvalue llv ()
| Instruction (CatchPad | CleanupPad | CatchSwitch) -> | Instruction (CatchPad | CleanupPad | CatchSwitch) ->
todo "windows exception handling: %a" pp_llvalue llv () todo "windows exception handling: %a" pp_llvalue llv ()
@ -482,11 +464,11 @@ and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t =
|> |>
[%Trace.retn fun {pf} exp -> pf "%a" Exp.pp exp] ) [%Trace.retn fun {pf} exp -> pf "%a" Exp.pp exp] )
and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t = and xlate_opcode stk : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t =
fun x llv opcode -> fun x llv opcode ->
[%Trace.call fun {pf} -> pf "%a" pp_llvalue llv] [%Trace.call fun {pf} -> pf "%a" pp_llvalue llv]
; ;
let xlate_rand i = xlate_value x (Llvm.operand llv i) in let xlate_rand i = xlate_value stk x (Llvm.operand llv i) in
let typ = lazy (xlate_type x (Llvm.type_of llv)) in let typ = lazy (xlate_type x (Llvm.type_of llv)) in
let check_vector = let check_vector =
lazy lazy
@ -497,7 +479,7 @@ and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t =
let dst = Lazy.force typ in let dst = Lazy.force typ in
let rand = Llvm.operand llv 0 in let rand = Llvm.operand llv 0 in
let src = xlate_type x (Llvm.type_of rand) in let src = xlate_type x (Llvm.type_of rand) in
let arg = xlate_value x rand in let arg = xlate_value stk x rand in
match (opcode : Llvm.Opcode.t) with match (opcode : Llvm.Opcode.t) with
| Trunc -> Exp.signed (Typ.bit_size_of dst) arg ~to_:dst | Trunc -> Exp.signed (Typ.bit_size_of dst) arg ~to_:dst
| SExt -> Exp.signed (Typ.bit_size_of src) arg ~to_:dst | SExt -> Exp.signed (Typ.bit_size_of src) arg ~to_:dst
@ -671,7 +653,7 @@ and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t =
| ShuffleVector -> ( | ShuffleVector -> (
(* translate shufflevector <N x t> %x, _, <N x i32> zeroinitializer to (* translate shufflevector <N x t> %x, _, <N x i32> zeroinitializer to
%x *) %x *)
let exp = xlate_value x (Llvm.operand llv 0) in let exp = xlate_value stk x (Llvm.operand llv 0) in
let exp_typ = xlate_type x (Llvm.type_of (Llvm.operand llv 0)) in let exp_typ = xlate_type x (Llvm.type_of (Llvm.operand llv 0)) in
let llmask = Llvm.operand llv 2 in let llmask = Llvm.operand llv 2 in
let mask_typ = xlate_type x (Llvm.type_of llmask) in let mask_typ = xlate_type x (Llvm.type_of llmask) in
@ -687,7 +669,7 @@ and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t =
|> |>
[%Trace.retn fun {pf} exp -> pf "%a" Exp.pp exp] [%Trace.retn fun {pf} exp -> pf "%a" Exp.pp exp]
and xlate_global : x -> Llvm.llvalue -> Global.t = and xlate_global stk : x -> Llvm.llvalue -> Global.t =
fun x llg -> fun x llg ->
Hashtbl.find_or_add memo_global llg ~default:(fun () -> Hashtbl.find_or_add memo_global llg ~default:(fun () ->
[%Trace.call fun {pf} -> pf "%a" pp_llvalue llg] [%Trace.call fun {pf} -> pf "%a" pp_llvalue llg]
@ -702,13 +684,18 @@ and xlate_global : x -> Llvm.llvalue -> Global.t =
let init = let init =
match Llvm.classify_value llg with match Llvm.classify_value llg with
| GlobalVariable -> | GlobalVariable ->
Option.map ~f:(xlate_value x) (Llvm.global_initializer llg) Option.map ~f:(xlate_value stk x) (Llvm.global_initializer llg)
| _ -> None | _ -> None
in in
Global.mk ?init g typ loc Global.mk ?init g typ loc
|> |>
[%Trace.retn fun {pf} -> pf "%a" Global.pp_defn] ) [%Trace.retn fun {pf} -> pf "%a" Global.pp_defn] )
let xlate_intrinsic_exp = xlate_intrinsic_exp []
let xlate_value ?inline = xlate_value ?inline []
let xlate_opcode = xlate_opcode []
let xlate_global = xlate_global []
type pop_thunk = Loc.t -> Llair.inst list type pop_thunk = Loc.t -> Llair.inst list
let pop_stack_frame_of_function : let pop_stack_frame_of_function :

@ -17,8 +17,8 @@ let classify e =
| Add _ | Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) -> | Add _ | Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) ->
Interpreted Interpreted
| Mul _ | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ | And _ | Or _ -> Uninterpreted | Mul _ | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ | And _ | Or _ -> Uninterpreted
| RecN _ | Var _ | Integer _ | Rational _ | Float _ | Nondet _ | Label _ | Var _ | Integer _ | Rational _ | Float _ | Nondet _ | Label _
-> |RecRecord _ ->
Atomic Atomic
let interpreted e = equal_kind (classify e) Interpreted let interpreted e = equal_kind (classify e) Interpreted

@ -57,10 +57,8 @@ module T = struct
| Conditional | Conditional
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
type opN = type opN = (* array/struct constants *)
(* array/struct constants *)
| Record | Record
| Struct_rec (** NOTE: may be cyclic *)
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
type t = {desc: desc; term: Term.t} type t = {desc: desc; term: Term.t}
@ -75,6 +73,7 @@ module T = struct
| Ap2 of op2 * Typ.t * t * t | Ap2 of op2 * Typ.t * t * t
| Ap3 of op3 * Typ.t * t * t * t | Ap3 of op3 * Typ.t * t * t * t
| ApN of opN * Typ.t * t iarray | ApN of opN * Typ.t * t iarray
| RecRecord of int * Typ.t
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
end end
@ -84,24 +83,6 @@ module Map = Map.Make (T)
let term e = e.term let term e = e.term
let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a =
let rec fix_f seen e =
match e.desc with
| ApN (Struct_rec, _, _) ->
if List.mem ~equal:( == ) seen e then f bot e
else f (fix_f (e :: seen)) e
| _ -> f (fix_f seen) e
in
let rec fix_f_seen_nil e =
match e.desc with
| ApN (Struct_rec, _, _) -> f (fix_f [e]) e
| _ -> f fix_f_seen_nil e
in
fix_f_seen_nil e
let fix_flip (f : ('z -> t -> 'a as 'f) -> 'f) (bot : 'f) (z : 'z) (e : t) =
fix (fun f' e z -> f (fun z e -> f' e z) z e) (fun e z -> bot z e) e z
let pp_op2 fs op = let pp_op2 fs op =
let pf fmt = Format.fprintf fs fmt in let pf fmt = Format.fprintf fs fmt in
match op with match op with
@ -133,7 +114,6 @@ let pp_op2 fs op =
| Update idx -> pf "[_|%i→_]" idx | Update idx -> pf "[_|%i→_]" idx
let rec pp fs exp = let rec pp fs exp =
let pp_ pp fs exp =
let pf fmt = let pf fmt =
Format.pp_open_box fs 2 ; Format.pp_open_box fs 2 ;
Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt
@ -168,9 +148,7 @@ let rec pp fs exp =
| Ap3 (Conditional, _, cnd, thn, els) -> | Ap3 (Conditional, _, cnd, thn, els) ->
pf "(%a@ ? %a@ : %a)" pp cnd pp thn pp els pf "(%a@ ? %a@ : %a)" pp cnd pp thn pp els
| ApN (Record, _, elts) -> pf "{%a}" pp_record elts | ApN (Record, _, elts) -> pf "{%a}" pp_record elts
| ApN (Struct_rec, _, elts) -> pf "{|%a|}" (IArray.pp ",@ " pp) elts | RecRecord (i, _) -> pf "rec_record %i" i
in
fix_flip pp_ (fun _ _ -> ()) fs exp
[@@warning "-9"] [@@warning "-9"]
and pp_record fs elts = and pp_record fs elts =
@ -256,7 +234,7 @@ let rec invariant exp =
assert (Typ.castable Typ.bool (typ_of cnd)) ; assert (Typ.castable Typ.bool (typ_of cnd)) ;
assert (Typ.castable typ (typ_of thn)) ; assert (Typ.castable typ (typ_of thn)) ;
assert (Typ.castable typ (typ_of els)) assert (Typ.castable typ (typ_of els))
| ApN ((Record | Struct_rec), typ, args) -> ( | ApN (Record, typ, args) -> (
match typ with match typ with
| Array {elt} -> | Array {elt} ->
assert ( assert (
@ -268,6 +246,7 @@ let rec invariant exp =
IArray.for_all2_exn elts args ~f:(fun typ arg -> IArray.for_all2_exn elts args ~f:(fun typ arg ->
Typ.castable typ (typ_of arg) ) ) Typ.castable typ (typ_of arg) ) )
| _ -> assert false ) | _ -> assert false )
| RecRecord _ -> ()
[@@warning "-9"] [@@warning "-9"]
(** Type query *) (** Type query *)
@ -295,7 +274,8 @@ and typ_of exp =
, _ , _
, _ ) , _ )
|Ap3 (Conditional, typ, _, _, _) |Ap3 (Conditional, typ, _, _, _)
|ApN ((Record | Struct_rec), typ, _) -> |ApN (Record, typ, _)
|RecRecord (_, typ) ->
typ typ
[@@warning "-9"] [@@warning "-9"]
@ -472,35 +452,12 @@ let update typ ~rcd idx ~elt =
; term= Term.update ~rcd:rcd.term ~idx ~elt:elt.term } ; term= Term.update ~rcd:rcd.term ~idx ~elt:elt.term }
|> check invariant |> check invariant
let struct_rec key = let rec_record i typ = {desc= RecRecord (i, typ); term= Term.rec_record i}
let memo_id = Hashtbl.create key in
let rec_app = (Staged.unstage (Term.rec_app key)) Term.Record in
Staged.stage
@@ fun ~id typ elt_thks ->
match Hashtbl.find memo_id id with
| None ->
(* Add placeholder to prevent computing [elts] in calls to
[struct_rec] from [elt_thks] for recursive occurrences of [id]. *)
let elta = Array.create ~len:(IArray.length elt_thks) null in
let elts = IArray.of_array elta in
Hashtbl.set memo_id ~key:id ~data:elts ;
let term =
rec_app ~id (IArray.map ~f:(fun elt -> lazy elt.term) elts)
in
IArray.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ;
{desc= ApN (Struct_rec, typ, elts); term} |> check invariant
| Some elts ->
(* Do not check invariant as invariant will be checked above after the
thunks are forced, before which invariant-checking may spuriously
fail. Note that it is important that the value constructed here
shares the array in the memo table, so that the update after
forcing the recursive thunks also updates this value. *)
{desc= ApN (Struct_rec, typ, elts); term= rec_app ~id IArray.empty}
(** Traverse *) (** Traverse *)
let fold_exps e ~init ~f = let fold_exps e ~init ~f =
let fold_exps_ fold_exps_ e z = let rec fold_exps_ e z =
let z = let z =
match e.desc with match e.desc with
| Ap1 (_, _, x) -> fold_exps_ x z | Ap1 (_, _, x) -> fold_exps_ x z
@ -512,7 +469,7 @@ let fold_exps e ~init ~f =
in in
f z e f z e
in in
fix fold_exps_ (fun _ z -> z) e init fold_exps_ e init
let fold_regs e ~init ~f = let fold_regs e ~init ~f =
fold_exps e ~init ~f:(fun z x -> fold_exps e ~init ~f:(fun z x ->

@ -68,11 +68,7 @@ type op2 =
type op3 = Conditional (** If-then-else *) type op3 = Conditional (** If-then-else *)
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
type opN = type opN = Record (** Record (array / struct) constant *)
| Record (** Record (array / struct) constant *)
| Struct_rec
(** Struct constant that may recursively refer to itself
(transitively) from [elts]. NOTE: represented by cyclic values. *)
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
type t = private {desc: desc; term: Term.t} type t = private {desc: desc; term: Term.t}
@ -90,6 +86,7 @@ and desc = private
| Ap2 of op2 * Typ.t * t * t | Ap2 of op2 * Typ.t * t * t
| Ap3 of op3 * Typ.t * t * t * t | Ap3 of op3 * Typ.t * t * t * t
| ApN of opN * Typ.t * t iarray | ApN of opN * Typ.t * t iarray
| RecRecord of int * Typ.t (** Reference to ancestor recursive record *)
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
val pp : t pp val pp : t pp
@ -188,17 +185,7 @@ val splat : Typ.t -> t -> t
val record : Typ.t -> t iarray -> t val record : Typ.t -> t iarray -> t
val select : Typ.t -> t -> int -> t val select : Typ.t -> t -> int -> t
val update : Typ.t -> rcd:t -> int -> elt:t -> t val update : Typ.t -> rcd:t -> int -> elt:t -> t
val rec_record : int -> Typ.t -> t
val struct_rec :
(module Hashtbl.Key_plain with type t = 'id)
-> (id:'id -> Typ.t -> t lazy_t iarray -> t) Staged.t
(** [struct_rec Id id element_thunks] constructs a possibly-cyclic [Struct]
value. Cycles are detected using [Id]. The caller of [struct_rec Id]
must ensure that a single unstaging of [struct_rec Id] is used for each
complete cyclic value. Also, the caller must ensure that recursive calls
to [struct_rec Id] provide [id] values that uniquely identify at least
one point on each cycle. Failure to obey these requirements will lead to
stack overflow. *)
(** Traverse *) (** Traverse *)

@ -18,6 +18,15 @@ let rec pp ?pre ?suf sep pp_elt fs = function
| xs -> Format.fprintf fs "%( %)%a" sep (pp sep pp_elt) xs ) ; | xs -> Format.fprintf fs "%( %)%a" sep (pp sep pp_elt) xs ) ;
Option.iter suf ~f:(Format.fprintf fs) Option.iter suf ~f:(Format.fprintf fs)
let findi x xs =
let rec findi_ i xs =
match xs with
| [] -> None
| x' :: _ when x == x' -> Some i
| _ :: xs -> findi_ (i + 1) xs
in
findi_ 0 xs
let pop_exn = function let pop_exn = function
| x :: xs -> (x, xs) | x :: xs -> (x, xs)
| [] -> raise (Not_found_s (Atom __LOC__)) | [] -> raise (Not_found_s (Atom __LOC__))

@ -22,6 +22,9 @@ val pp_diff :
-> 'a pp -> 'a pp
-> ('a list * 'a list) pp -> ('a list * 'a list) pp
val findi : 'a -> 'a t -> int option
(** [findi x xs] is [Some i] when [nth xs i == x], otherwise [None]. *)
val pop_exn : 'a list -> 'a * 'a list val pop_exn : 'a list -> 'a * 'a list
val find_map_remove : val find_map_remove :

@ -36,7 +36,6 @@ type op2 =
type op3 = Conditional | Extract [@@deriving compare, equal, hash, sexp] type op3 = Conditional | Extract [@@deriving compare, equal, hash, sexp]
type opN = Concat | Record [@@deriving compare, equal, hash, sexp] type opN = Concat | Record [@@deriving compare, equal, hash, sexp]
type recN = Record [@@deriving compare, equal, hash, sexp]
module rec Set : sig module rec Set : sig
include Import.Set.S with type elt := T.t include Import.Set.S with type elt := T.t
@ -77,7 +76,6 @@ and T : sig
| Ap2 of op2 * t * t | Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t | Ap3 of op3 * t * t * t
| ApN of opN * t iarray | ApN of opN * t iarray
| RecN of recN * t iarray (** NOTE: cyclic *)
| And of set | And of set
| Or of set | Or of set
| Add of qset | Add of qset
@ -87,6 +85,7 @@ and T : sig
| Float of {data: string} | Float of {data: string}
| Integer of {data: Z.t} | Integer of {data: Z.t}
| Rational of {data: Q.t} | Rational of {data: Q.t}
| RecRecord of int
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
end = struct end = struct
type set = Set.t [@@deriving compare, equal, hash, sexp] type set = Set.t [@@deriving compare, equal, hash, sexp]
@ -98,7 +97,6 @@ end = struct
| Ap2 of op2 * t * t | Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t | Ap3 of op3 * t * t * t
| ApN of opN * t iarray | ApN of opN * t iarray
| RecN of recN * t iarray (** NOTE: cyclic *)
| And of set | And of set
| Or of set | Or of set
| Add of qset | Add of qset
@ -108,6 +106,7 @@ end = struct
| Float of {data: string} | Float of {data: string}
| Integer of {data: Z.t} | Integer of {data: Z.t}
| Rational of {data: Q.t} | Rational of {data: Q.t}
| RecRecord of int
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
(* Note: solve (and invariant) requires Qset.min_elt to return a (* Note: solve (and invariant) requires Qset.min_elt to return a
@ -133,24 +132,8 @@ end
include T include T
module Map = struct include Map.Make (T) include Provide_of_sexp (T) end module Map = struct include Map.Make (T) include Provide_of_sexp (T) end
let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a =
let rec fix_f seen e =
match e with
| RecN _ ->
if List.mem ~equal:( == ) seen e then f bot e
else f (fix_f (e :: seen)) e
| _ -> f (fix_f seen) e
in
let rec fix_f_seen_nil e =
match e with RecN _ -> f (fix_f [e]) e | _ -> f fix_f_seen_nil e
in
fix_f_seen_nil e
let fix_flip (f : ('z -> t -> 'a as 'f) -> 'f) (bot : 'f) (z : 'z) (e : t) =
fix (fun f' e z -> f (fun z e -> f' e z) z e) (fun e z -> bot z e) e z
let rec ppx strength fs term = let rec ppx strength fs term =
let pp_ pp fs term = let rec pp fs term =
let pf fmt = let pf fmt =
Format.pp_open_box fs 2 ; Format.pp_open_box fs 2 ;
Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt
@ -212,12 +195,12 @@ let rec ppx strength fs term =
| ApN (Concat, args) when IArray.is_empty args -> pf "@<2>⟨⟩" | ApN (Concat, args) when IArray.is_empty args -> pf "@<2>⟨⟩"
| ApN (Concat, args) -> pf "(%a)" (IArray.pp "@,^" pp) args | ApN (Concat, args) -> pf "(%a)" (IArray.pp "@,^" pp) args
| ApN (Record, elts) -> pf "{%a}" (pp_record strength) elts | ApN (Record, elts) -> pf "{%a}" (pp_record strength) elts
| RecN (Record, elts) -> pf "{|%a|}" (IArray.pp ",@ " pp) elts
| Ap1 (Select idx, rcd) -> pf "%a[%i]" pp rcd idx | Ap1 (Select idx, rcd) -> pf "%a[%i]" pp rcd idx
| Ap2 (Update idx, rcd, elt) -> | Ap2 (Update idx, rcd, elt) ->
pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt
| RecRecord i -> pf "(rec_record %i)" i
in in
fix_flip pp_ (fun _ _ -> ()) fs term pp fs term
[@@warning "-9"] [@@warning "-9"]
and pp_record strength fs elts = and pp_record strength fs elts =
@ -325,8 +308,7 @@ let invariant e =
| Mul _ -> assert_monomial e |> Fn.id | Mul _ -> assert_monomial e |> Fn.id
| Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) -> | Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) ->
assert_aggregate e assert_aggregate e
| ApN (Record, elts) | RecN (Record, elts) -> | ApN (Record, elts) -> assert (not (IArray.is_empty elts))
assert (not (IArray.is_empty elts))
| Ap1 (Convert {src= Integer _; dst= Integer _}, _) -> assert false | Ap1 (Convert {src= Integer _; dst= Integer _}, _) -> assert false
| Ap1 (Convert {src; dst}, _) -> | Ap1 (Convert {src; dst}, _) ->
assert (Typ.convertible src dst) ; assert (Typ.convertible src dst) ;
@ -1023,28 +1005,7 @@ let simp_ashr x y =
let simp_record elts = ApN (Record, elts) let simp_record elts = ApN (Record, elts)
let simp_select idx rcd = Ap1 (Select idx, rcd) let simp_select idx rcd = Ap1 (Select idx, rcd)
let simp_update idx rcd elt = Ap2 (Update idx, rcd, elt) let simp_update idx rcd elt = Ap2 (Update idx, rcd, elt)
let simp_rec_record i = RecRecord i
let rec_app key =
let memo_id = Hashtbl.create key in
let dummy = null in
Staged.stage
@@ fun ~id op elt_thks ->
match Hashtbl.find memo_id id with
| None ->
(* Add placeholder to prevent computing [elts] in calls to [rec_app]
from [elt_thks] for recursive occurrences of [id]. *)
let elta = Array.create ~len:(IArray.length elt_thks) dummy in
let elts = IArray.of_array elta in
Hashtbl.set memo_id ~key:id ~data:elts ;
IArray.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ;
RecN (op, elts) |> check invariant
| Some elts ->
(* Do not check invariant as invariant will be checked above after the
thunks are forced, before which invariant-checking may spuriously
fail. Note that it is important that the value constructed here
shares the array in the memo table, so that the update after
forcing the recursive thunks also updates this value. *)
RecN (op, elts)
(* dispatching for normalization and invariant checking *) (* dispatching for normalization and invariant checking *)
@ -1124,6 +1085,7 @@ let concat xs = normN Concat (IArray.of_array xs)
let record elts = normN Record elts let record elts = normN Record elts
let select ~rcd ~idx = norm1 (Select idx) rcd let select ~rcd ~idx = norm1 (Select idx) rcd
let update ~rcd ~idx ~elt = norm2 (Update idx) rcd elt let update ~rcd ~idx ~elt = norm2 (Update idx) rcd elt
let rec_record i = simp_rec_record i |> check invariant
let eq_concat (siz, arr) ms = let eq_concat (siz, arr) ms =
eq (memory ~siz ~arr) eq (memory ~siz ~arr)
@ -1168,12 +1130,9 @@ let map e ~f =
| Ap2 (op, x, y) -> map2 op ~f x y | Ap2 (op, x, y) -> map2 op ~f x y
| Ap3 (op, x, y, z) -> map3 op ~f x y z | Ap3 (op, x, y, z) -> map3 op ~f x y z
| ApN (op, xs) -> mapN op ~f xs | ApN (op, xs) -> mapN op ~f xs
| RecN (_, xs) -> | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
assert ( |RecRecord _ ->
xs == IArray.map_endo ~f xs
|| fail "Term.map does not support updating subterms of RecN." () ) ;
e e
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> e
let fold_map e ~init ~f = let fold_map e ~init ~f =
let s = ref init in let s = ref init in
@ -1185,53 +1144,13 @@ let fold_map e ~init ~f =
let e' = map e ~f in let e' = map e ~f in
(!s, e') (!s, e')
let map_rec_pre e ~f = let rec map_rec_pre e ~f =
let rec map_rec_pre_f memo e = match f e with Some e' -> e' | None -> map ~f:(map_rec_pre ~f) e
match f e with
| Some e' -> e'
| None -> (
match e with
| RecN (op, xs) -> (
match List.Assoc.find ~equal:( == ) memo e with
| None ->
let xs' = IArray.to_array xs in
let e' = RecN (op, IArray.of_array xs') in
let memo = List.Assoc.add ~equal:( == ) memo e e' in
let changed = ref false in
Array.map_inplace xs' ~f:(fun x ->
let x' = map_rec_pre_f memo x in
if x' != x then changed := true ;
x' ) ;
if !changed then e' else e
| Some e' -> e' )
| _ -> map ~f:(map_rec_pre_f memo) e )
in
map_rec_pre_f [] e
let fold_map_rec_pre e ~init ~f = let rec fold_map_rec_pre e ~init:s ~f =
let rec fold_map_rec_pre_f memo s e =
match f s e with match f s e with
| Some (s, e') -> (s, e') | Some (s, e') -> (s, e')
| None -> ( | None -> fold_map ~f:(fun s e -> fold_map_rec_pre ~f ~init:s e) ~init:s e
match e with
| RecN (op, xs) -> (
match List.Assoc.find ~equal:( == ) memo e with
| None ->
let xs' = IArray.to_array xs in
let e' = RecN (op, IArray.of_array xs') in
let memo = List.Assoc.add ~equal:( == ) memo e e' in
let changed = ref false in
let s =
Array.fold_map_inplace ~init:s xs' ~f:(fun s x ->
let s, x' = fold_map_rec_pre_f memo s x in
if x' != x then changed := true ;
(s, x') )
in
if !changed then (s, e') else (s, e)
| Some e' -> (s, e') )
| _ -> fold_map ~f:(fold_map_rec_pre_f memo) ~init:s e )
in
fold_map_rec_pre_f [] init e
let rename sub e = let rename sub e =
map_rec_pre e ~f:(function map_rec_pre e ~f:(function
@ -1245,75 +1164,80 @@ let iter e ~f =
| Ap1 (_, x) -> f x | Ap1 (_, x) -> f x
| Ap2 (_, x, y) -> f x ; f y | Ap2 (_, x, y) -> f x ; f y
| Ap3 (_, x, y, z) -> f x ; f y ; f z | Ap3 (_, x, y, z) -> f x ; f y ; f z
| ApN (_, xs) | RecN (_, xs) -> IArray.iter ~f xs | ApN (_, xs) -> IArray.iter ~f xs
| And args | Or args -> Set.iter ~f args | And args | Or args -> Set.iter ~f args
| Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args | Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> () | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
()
let exists e ~f = let exists e ~f =
match e with match e with
| Ap1 (_, x) -> f x | Ap1 (_, x) -> f x
| Ap2 (_, x, y) -> f x || f y | Ap2 (_, x, y) -> f x || f y
| Ap3 (_, x, y, z) -> f x || f y || f z | Ap3 (_, x, y, z) -> f x || f y || f z
| ApN (_, xs) | RecN (_, xs) -> IArray.exists ~f xs | ApN (_, xs) -> IArray.exists ~f xs
| And args | Or args -> Set.exists ~f args | And args | Or args -> Set.exists ~f args
| Add args | Mul args -> Qset.exists ~f:(fun arg _ -> f arg) args | Add args | Mul args -> Qset.exists ~f:(fun arg _ -> f arg) args
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> false | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
false
let for_all e ~f = let for_all e ~f =
match e with match e with
| Ap1 (_, x) -> f x | Ap1 (_, x) -> f x
| Ap2 (_, x, y) -> f x && f y | Ap2 (_, x, y) -> f x && f y
| Ap3 (_, x, y, z) -> f x && f y && f z | Ap3 (_, x, y, z) -> f x && f y && f z
| ApN (_, xs) | RecN (_, xs) -> IArray.for_all ~f xs | ApN (_, xs) -> IArray.for_all ~f xs
| And args | Or args -> Set.for_all ~f args | And args | Or args -> Set.for_all ~f args
| Add args | Mul args -> Qset.for_all ~f:(fun arg _ -> f arg) args | Add args | Mul args -> Qset.for_all ~f:(fun arg _ -> f arg) args
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> true | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
true
let fold e ~init:s ~f = let fold e ~init:s ~f =
match e with match e with
| Ap1 (_, x) -> f x s | Ap1 (_, x) -> f x s
| Ap2 (_, x, y) -> f y (f x s) | Ap2 (_, x, y) -> f y (f x s)
| Ap3 (_, x, y, z) -> f z (f y (f x s)) | Ap3 (_, x, y, z) -> f z (f y (f x s))
| ApN (_, xs) | RecN (_, xs) -> | ApN (_, xs) -> IArray.fold ~f:(fun s x -> f x s) xs ~init:s
IArray.fold ~f:(fun s x -> f x s) xs ~init:s
| And args | Or args -> Set.fold ~f:(fun s e -> f e s) args ~init:s | And args | Or args -> Set.fold ~f:(fun s e -> f e s) args ~init:s
| Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s | Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> s | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
s
let iter_terms e ~f = let rec iter_terms e ~f =
let iter_terms_ iter_terms_ e =
( match e with ( match e with
| Ap1 (_, x) -> iter_terms_ x | Ap1 (_, x) -> iter_terms ~f x
| Ap2 (_, x, y) -> iter_terms_ x ; iter_terms_ y | Ap2 (_, x, y) -> iter_terms ~f x ; iter_terms ~f y
| Ap3 (_, x, y, z) -> iter_terms_ x ; iter_terms_ y ; iter_terms_ z | Ap3 (_, x, y, z) -> iter_terms ~f x ; iter_terms ~f y ; iter_terms ~f z
| ApN (_, xs) | RecN (_, xs) -> IArray.iter ~f:iter_terms_ xs | ApN (_, xs) -> IArray.iter ~f:(iter_terms ~f) xs
| And args | Or args -> Set.iter args ~f:iter_terms_ | And args | Or args -> Set.iter args ~f:(iter_terms ~f)
| Add args | Mul args -> | Add args | Mul args ->
Qset.iter args ~f:(fun arg _ -> iter_terms_ arg) Qset.iter args ~f:(fun arg _ -> iter_terms ~f arg)
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> () ) ; | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
() ) ;
f e f e
in
fix iter_terms_ (fun _ -> ()) e
let fold_terms e ~init ~f = let rec fold_terms e ~init:s ~f =
let fold_terms_ fold_terms_ e s = let fold_terms f e s = fold_terms e ~init:s ~f in
let s = let s =
match e with match e with
| Ap1 (_, x) -> fold_terms_ x s | Ap1 (_, x) -> fold_terms f x s
| Ap2 (_, x, y) -> fold_terms_ y (fold_terms_ x s) | Ap2 (_, x, y) -> fold_terms f y (fold_terms f x s)
| Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s)) | Ap3 (_, x, y, z) -> fold_terms f z (fold_terms f y (fold_terms f x s))
| ApN (_, xs) | RecN (_, xs) -> | ApN (_, xs) -> IArray.fold ~f:(fun s x -> fold_terms f x s) xs ~init:s
IArray.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s
| And args | Or args -> | And args | Or args ->
Set.fold args ~init:s ~f:(fun s x -> fold_terms_ x s) Set.fold args ~init:s ~f:(fun s x -> fold_terms f x s)
| Add args | Mul args -> | Add args | Mul args ->
Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s) Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms f arg s)
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> s | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
s
in in
f s e f s e
in
fix fold_terms_ (fun _ s -> s) e init
let iter_vars e ~f = let iter_vars e ~f =
iter_terms e ~f:(function Var _ as v -> f (v :> Var.t) | _ -> ()) iter_terms e ~f:(function Var _ as v -> f (v :> Var.t) | _ -> ())
@ -1338,21 +1262,17 @@ let rec is_constant = function
| Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> true | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> true
| a -> for_all ~f:is_constant a | a -> for_all ~f:is_constant a
let height e = let rec height = function
let height_ height_ = function
| Var _ -> 0 | Var _ -> 0
| Ap1 (_, a) -> 1 + height_ a | Ap1 (_, a) -> 1 + height a
| Ap2 (_, a, b) -> 1 + max (height_ a) (height_ b) | Ap2 (_, a, b) -> 1 + max (height a) (height b)
| Ap3 (_, a, b, c) -> 1 + max (height_ a) (max (height_ b) (height_ c)) | Ap3 (_, a, b, c) -> 1 + max (height a) (max (height b) (height c))
| ApN (_, v) | RecN (_, v) -> | ApN (_, v) -> 1 + IArray.fold v ~init:0 ~f:(fun m a -> max m (height a))
1 + IArray.fold v ~init:0 ~f:(fun m a -> max m (height_ a))
| And bs | Or bs -> | And bs | Or bs ->
1 + Set.fold bs ~init:0 ~f:(fun m a -> max m (height_ a)) 1 + Set.fold bs ~init:0 ~f:(fun m a -> max m (height a))
| Add qs | Mul qs -> | Add qs | Mul qs ->
1 + Qset.fold qs ~init:0 ~f:(fun a _ m -> max m (height_ a)) 1 + Qset.fold qs ~init:0 ~f:(fun a _ m -> max m (height a))
| Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> 0 | Label _ | Nondet _ | Float _ | Integer _ | Rational _ | RecRecord _ -> 0
in
fix height_ (fun _ -> 0) e
(** Solve *) (** Solve *)

@ -57,9 +57,6 @@ type opN =
| Record (** Record (array / struct) constant *) | Record (** Record (array / struct) constant *)
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
type recN = Record (** Recursive record (array / struct) constant *)
[@@deriving compare, equal, hash, sexp]
module rec Set : sig module rec Set : sig
include Import.Set.S with type elt := T.t include Import.Set.S with type elt := T.t
@ -86,10 +83,6 @@ and T : sig
| Ap2 of op2 * t * t (** Binary application *) | Ap2 of op2 * t * t (** Binary application *)
| Ap3 of op3 * t * t * t (** Ternary application *) | Ap3 of op3 * t * t * t (** Ternary application *)
| ApN of opN * t iarray (** N-ary application *) | ApN of opN * t iarray (** N-ary application *)
| RecN of recN * t iarray
(** Recursive n-ary application, may recursively refer to itself
(transitively) from its args. NOTE: represented by cyclic
values. *)
| And of set (** Conjunction, boolean or bitwise *) | And of set (** Conjunction, boolean or bitwise *)
| Or of set (** Disjunction, boolean or bitwise *) | Or of set (** Disjunction, boolean or bitwise *)
| Add of qset (** Sum of terms with rational coefficients *) | Add of qset (** Sum of terms with rational coefficients *)
@ -102,6 +95,7 @@ and T : sig
| Float of {data: string} (** Floating-point constant *) | Float of {data: string} (** Floating-point constant *)
| Integer of {data: Z.t} (** Integer constant *) | Integer of {data: Z.t} (** Integer constant *)
| Rational of {data: Q.t} (** Rational constant *) | Rational of {data: Q.t} (** Rational constant *)
| RecRecord of int (** Reference to ancestor recursive record *)
[@@deriving compare, equal, hash, sexp] [@@deriving compare, equal, hash, sexp]
end end
@ -242,11 +236,7 @@ val eq_concat : t * t -> (t * t) array -> t
val record : t iarray -> t val record : t iarray -> t
val select : rcd:t -> idx:int -> t val select : rcd:t -> idx:int -> t
val update : rcd:t -> idx:int -> elt:t -> t val update : rcd:t -> idx:int -> elt:t -> t
val rec_record : int -> t
(* recursive n-ary application *)
val rec_app :
(module Hashtbl.Key_plain with type t = 'id)
-> (id:'id -> recN -> t lazy_t iarray -> t) Staged.t
(** Transform *) (** Transform *)

Loading…
Cancel
Save