[sledge] Represent recursive records non-recursively

Summary:
In LLVM it is possible for struct constant values to be directly
recursive, with no pointer dereference to close the cycle. These
appear for example as the values of vtables from C++ code.

Currently such recursive records in the Exp and Term languages are
represented as genuinely cyclic values. Compared to a standard term
representation, the presence of cyclic values is a significant
complication everywhere. Since the backend solver does not do anything
such as induction over these, they have to be treated as essentially
atomic.

This patch changes the representation to a standard non-recursive tree
term structure. Instead of cyclic references, an explicit constructor
is added for the "non-tree edges", which simply indicates which
ancestor record value to which the recursive reference points.

There is a potential issue with this representation, since for
mutually recursive records, the representation is not canonical: it
chooses one of the records in the cycle to start from and expresses
the cycles relative to that. Currently the choice of representation is
dictated by the frontend. For the case of vtables, the frontend
translates globals in the same order they appear in the LLVM IR, so
the representation choice is fixed.

It may turn out that other potential uses require more reasoning
support in the backend solver, which would involve a theory of
equality of record values induced by equating the representations
resulting from different rotations of the cycle of records.

Reviewed By: jvillard

Differential Revision: D21441533

fbshipit-source-id: 0c5a11378
master
Josh Berdine 5 years ago committed by Facebook GitHub Bot
parent 849c61221d
commit 9d9060d213

@ -343,16 +343,6 @@ let should_inline : Llvm.llvalue -> bool =
| None -> true )
| None -> true
module Llvalue = struct
type t = Llvm.llvalue
let hash = Hashtbl.hash
let compare = Poly.compare
let sexp_of_t llv = Sexp.Atom (Llvm.string_of_llvalue llv)
end
let struct_rec = Staged.unstage (Exp.struct_rec (module Llvalue))
let ptr_fld x ~ptr ~fld ~lltyp =
let offset =
Llvm_target.DataLayout.offset_of_element lltyp fld x.lldatalayout
@ -377,34 +367,34 @@ let convert_to_siz =
let xlate_llvm_eh_typeid_for : x -> Typ.t -> Exp.t -> Exp.t =
fun x typ arg -> Exp.convert typ ~to_:(i32 x) arg
let rec xlate_intrinsic_exp : string -> (x -> Llvm.llvalue -> Exp.t) option
=
let rec xlate_intrinsic_exp stk :
string -> (x -> Llvm.llvalue -> Exp.t) option =
fun name ->
match name with
| "llvm.eh.typeid.for" ->
Some
(fun x llv ->
let rand = Llvm.operand llv 0 in
let arg = xlate_value x rand in
let arg = xlate_value stk x rand in
let src = xlate_type x (Llvm.type_of rand) in
xlate_llvm_eh_typeid_for x src arg )
| _ -> None
and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t =
and xlate_value ?(inline = false) stk : x -> Llvm.llvalue -> Exp.t =
fun x llv ->
let xlate_value_ llv =
match Llvm.classify_value llv with
| Instruction Call -> (
let func = Llvm.operand llv (Llvm.num_arg_operands llv) in
let fname = Llvm.value_name func in
match xlate_intrinsic_exp fname with
match xlate_intrinsic_exp stk fname with
| Some intrinsic when inline || should_inline llv -> intrinsic x llv
| _ -> Exp.reg (xlate_name x llv) )
| Instruction (Invoke | Alloca | Load | PHI | LandingPad | VAArg)
|Argument ->
Exp.reg (xlate_name x llv)
| Function | GlobalVariable -> Exp.reg (xlate_global x llv).reg
| GlobalAlias -> xlate_value x (Llvm.operand llv 0)
| Function | GlobalVariable -> Exp.reg (xlate_global stk x llv).reg
| GlobalAlias -> xlate_value stk x (Llvm.operand llv 0)
| ConstantInt -> xlate_int x llv
| ConstantFP -> xlate_float x llv
| ConstantPointerNull -> Exp.null
@ -419,35 +409,27 @@ and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t =
| ConstantVector | ConstantArray ->
let typ = xlate_type x (Llvm.type_of llv) in
let len = Llvm.num_operands llv in
let f i = xlate_value x (Llvm.operand llv i) in
let f i = xlate_value stk x (Llvm.operand llv i) in
Exp.record typ (IArray.init len ~f)
| ConstantDataVector ->
let typ = xlate_type x (Llvm.type_of llv) in
let len = Llvm.vector_size (Llvm.type_of llv) in
let f i = xlate_value x (Llvm.const_element llv i) in
let f i = xlate_value stk x (Llvm.const_element llv i) in
Exp.record typ (IArray.init len ~f)
| ConstantDataArray ->
let typ = xlate_type x (Llvm.type_of llv) in
let len = Llvm.array_length (Llvm.type_of llv) in
let f i = xlate_value x (Llvm.const_element llv i) in
let f i = xlate_value stk x (Llvm.const_element llv i) in
Exp.record typ (IArray.init len ~f)
| ConstantStruct ->
| ConstantStruct -> (
let typ = xlate_type x (Llvm.type_of llv) in
let is_recursive =
Llvm.fold_left_uses
(fun b use -> b || llv == Llvm.used_value use)
false llv
in
if is_recursive then
let elt_thks =
IArray.init (Llvm.num_operands llv) ~f:(fun i ->
lazy (xlate_value x (Llvm.operand llv i)) )
in
struct_rec ~id:llv typ elt_thks
else
match List.findi llv stk with
| Some i -> Exp.rec_record i typ
| None ->
let stk = llv :: stk in
Exp.record typ
(IArray.init (Llvm.num_operands llv) ~f:(fun i ->
xlate_value x (Llvm.operand llv i) ))
xlate_value stk x (Llvm.operand llv i) )) )
| BlockAddress ->
let parent = find_name (Llvm.operand llv 0) in
let name = find_name (Llvm.operand llv 1) in
@ -462,9 +444,9 @@ and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t =
| SRem | FRem | Shl | LShr | AShr | And | Or | Xor | ICmp | FCmp
| Select | GetElementPtr | ExtractElement | InsertElement
| ShuffleVector | ExtractValue | InsertValue ) as opcode ) ->
if inline || should_inline llv then xlate_opcode x llv opcode
if inline || should_inline llv then xlate_opcode stk x llv opcode
else Exp.reg (xlate_name x llv)
| ConstantExpr -> xlate_opcode x llv (Llvm.constexpr_opcode llv)
| ConstantExpr -> xlate_opcode stk x llv (Llvm.constexpr_opcode llv)
| GlobalIFunc -> todo "ifuncs: %a" pp_llvalue llv ()
| Instruction (CatchPad | CleanupPad | CatchSwitch) ->
todo "windows exception handling: %a" pp_llvalue llv ()
@ -482,11 +464,11 @@ and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t =
|>
[%Trace.retn fun {pf} exp -> pf "%a" Exp.pp exp] )
and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t =
and xlate_opcode stk : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t =
fun x llv opcode ->
[%Trace.call fun {pf} -> pf "%a" pp_llvalue llv]
;
let xlate_rand i = xlate_value x (Llvm.operand llv i) in
let xlate_rand i = xlate_value stk x (Llvm.operand llv i) in
let typ = lazy (xlate_type x (Llvm.type_of llv)) in
let check_vector =
lazy
@ -497,7 +479,7 @@ and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t =
let dst = Lazy.force typ in
let rand = Llvm.operand llv 0 in
let src = xlate_type x (Llvm.type_of rand) in
let arg = xlate_value x rand in
let arg = xlate_value stk x rand in
match (opcode : Llvm.Opcode.t) with
| Trunc -> Exp.signed (Typ.bit_size_of dst) arg ~to_:dst
| SExt -> Exp.signed (Typ.bit_size_of src) arg ~to_:dst
@ -671,7 +653,7 @@ and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t =
| ShuffleVector -> (
(* translate shufflevector <N x t> %x, _, <N x i32> zeroinitializer to
%x *)
let exp = xlate_value x (Llvm.operand llv 0) in
let exp = xlate_value stk x (Llvm.operand llv 0) in
let exp_typ = xlate_type x (Llvm.type_of (Llvm.operand llv 0)) in
let llmask = Llvm.operand llv 2 in
let mask_typ = xlate_type x (Llvm.type_of llmask) in
@ -687,7 +669,7 @@ and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t =
|>
[%Trace.retn fun {pf} exp -> pf "%a" Exp.pp exp]
and xlate_global : x -> Llvm.llvalue -> Global.t =
and xlate_global stk : x -> Llvm.llvalue -> Global.t =
fun x llg ->
Hashtbl.find_or_add memo_global llg ~default:(fun () ->
[%Trace.call fun {pf} -> pf "%a" pp_llvalue llg]
@ -702,13 +684,18 @@ and xlate_global : x -> Llvm.llvalue -> Global.t =
let init =
match Llvm.classify_value llg with
| GlobalVariable ->
Option.map ~f:(xlate_value x) (Llvm.global_initializer llg)
Option.map ~f:(xlate_value stk x) (Llvm.global_initializer llg)
| _ -> None
in
Global.mk ?init g typ loc
|>
[%Trace.retn fun {pf} -> pf "%a" Global.pp_defn] )
let xlate_intrinsic_exp = xlate_intrinsic_exp []
let xlate_value ?inline = xlate_value ?inline []
let xlate_opcode = xlate_opcode []
let xlate_global = xlate_global []
type pop_thunk = Loc.t -> Llair.inst list
let pop_stack_frame_of_function :

@ -17,8 +17,8 @@ let classify e =
| Add _ | Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) ->
Interpreted
| Mul _ | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ | And _ | Or _ -> Uninterpreted
| RecN _ | Var _ | Integer _ | Rational _ | Float _ | Nondet _ | Label _
->
| Var _ | Integer _ | Rational _ | Float _ | Nondet _ | Label _
|RecRecord _ ->
Atomic
let interpreted e = equal_kind (classify e) Interpreted

@ -57,10 +57,8 @@ module T = struct
| Conditional
[@@deriving compare, equal, hash, sexp]
type opN =
(* array/struct constants *)
type opN = (* array/struct constants *)
| Record
| Struct_rec (** NOTE: may be cyclic *)
[@@deriving compare, equal, hash, sexp]
type t = {desc: desc; term: Term.t}
@ -75,6 +73,7 @@ module T = struct
| Ap2 of op2 * Typ.t * t * t
| Ap3 of op3 * Typ.t * t * t * t
| ApN of opN * Typ.t * t iarray
| RecRecord of int * Typ.t
[@@deriving compare, equal, hash, sexp]
end
@ -84,24 +83,6 @@ module Map = Map.Make (T)
let term e = e.term
let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a =
let rec fix_f seen e =
match e.desc with
| ApN (Struct_rec, _, _) ->
if List.mem ~equal:( == ) seen e then f bot e
else f (fix_f (e :: seen)) e
| _ -> f (fix_f seen) e
in
let rec fix_f_seen_nil e =
match e.desc with
| ApN (Struct_rec, _, _) -> f (fix_f [e]) e
| _ -> f fix_f_seen_nil e
in
fix_f_seen_nil e
let fix_flip (f : ('z -> t -> 'a as 'f) -> 'f) (bot : 'f) (z : 'z) (e : t) =
fix (fun f' e z -> f (fun z e -> f' e z) z e) (fun e z -> bot z e) e z
let pp_op2 fs op =
let pf fmt = Format.fprintf fs fmt in
match op with
@ -133,7 +114,6 @@ let pp_op2 fs op =
| Update idx -> pf "[_|%i→_]" idx
let rec pp fs exp =
let pp_ pp fs exp =
let pf fmt =
Format.pp_open_box fs 2 ;
Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt
@ -168,9 +148,7 @@ let rec pp fs exp =
| Ap3 (Conditional, _, cnd, thn, els) ->
pf "(%a@ ? %a@ : %a)" pp cnd pp thn pp els
| ApN (Record, _, elts) -> pf "{%a}" pp_record elts
| ApN (Struct_rec, _, elts) -> pf "{|%a|}" (IArray.pp ",@ " pp) elts
in
fix_flip pp_ (fun _ _ -> ()) fs exp
| RecRecord (i, _) -> pf "rec_record %i" i
[@@warning "-9"]
and pp_record fs elts =
@ -256,7 +234,7 @@ let rec invariant exp =
assert (Typ.castable Typ.bool (typ_of cnd)) ;
assert (Typ.castable typ (typ_of thn)) ;
assert (Typ.castable typ (typ_of els))
| ApN ((Record | Struct_rec), typ, args) -> (
| ApN (Record, typ, args) -> (
match typ with
| Array {elt} ->
assert (
@ -268,6 +246,7 @@ let rec invariant exp =
IArray.for_all2_exn elts args ~f:(fun typ arg ->
Typ.castable typ (typ_of arg) ) )
| _ -> assert false )
| RecRecord _ -> ()
[@@warning "-9"]
(** Type query *)
@ -295,7 +274,8 @@ and typ_of exp =
, _
, _ )
|Ap3 (Conditional, typ, _, _, _)
|ApN ((Record | Struct_rec), typ, _) ->
|ApN (Record, typ, _)
|RecRecord (_, typ) ->
typ
[@@warning "-9"]
@ -472,35 +452,12 @@ let update typ ~rcd idx ~elt =
; term= Term.update ~rcd:rcd.term ~idx ~elt:elt.term }
|> check invariant
let struct_rec key =
let memo_id = Hashtbl.create key in
let rec_app = (Staged.unstage (Term.rec_app key)) Term.Record in
Staged.stage
@@ fun ~id typ elt_thks ->
match Hashtbl.find memo_id id with
| None ->
(* Add placeholder to prevent computing [elts] in calls to
[struct_rec] from [elt_thks] for recursive occurrences of [id]. *)
let elta = Array.create ~len:(IArray.length elt_thks) null in
let elts = IArray.of_array elta in
Hashtbl.set memo_id ~key:id ~data:elts ;
let term =
rec_app ~id (IArray.map ~f:(fun elt -> lazy elt.term) elts)
in
IArray.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ;
{desc= ApN (Struct_rec, typ, elts); term} |> check invariant
| Some elts ->
(* Do not check invariant as invariant will be checked above after the
thunks are forced, before which invariant-checking may spuriously
fail. Note that it is important that the value constructed here
shares the array in the memo table, so that the update after
forcing the recursive thunks also updates this value. *)
{desc= ApN (Struct_rec, typ, elts); term= rec_app ~id IArray.empty}
let rec_record i typ = {desc= RecRecord (i, typ); term= Term.rec_record i}
(** Traverse *)
let fold_exps e ~init ~f =
let fold_exps_ fold_exps_ e z =
let rec fold_exps_ e z =
let z =
match e.desc with
| Ap1 (_, _, x) -> fold_exps_ x z
@ -512,7 +469,7 @@ let fold_exps e ~init ~f =
in
f z e
in
fix fold_exps_ (fun _ z -> z) e init
fold_exps_ e init
let fold_regs e ~init ~f =
fold_exps e ~init ~f:(fun z x ->

@ -68,11 +68,7 @@ type op2 =
type op3 = Conditional (** If-then-else *)
[@@deriving compare, equal, hash, sexp]
type opN =
| Record (** Record (array / struct) constant *)
| Struct_rec
(** Struct constant that may recursively refer to itself
(transitively) from [elts]. NOTE: represented by cyclic values. *)
type opN = Record (** Record (array / struct) constant *)
[@@deriving compare, equal, hash, sexp]
type t = private {desc: desc; term: Term.t}
@ -90,6 +86,7 @@ and desc = private
| Ap2 of op2 * Typ.t * t * t
| Ap3 of op3 * Typ.t * t * t * t
| ApN of opN * Typ.t * t iarray
| RecRecord of int * Typ.t (** Reference to ancestor recursive record *)
[@@deriving compare, equal, hash, sexp]
val pp : t pp
@ -188,17 +185,7 @@ val splat : Typ.t -> t -> t
val record : Typ.t -> t iarray -> t
val select : Typ.t -> t -> int -> t
val update : Typ.t -> rcd:t -> int -> elt:t -> t
val struct_rec :
(module Hashtbl.Key_plain with type t = 'id)
-> (id:'id -> Typ.t -> t lazy_t iarray -> t) Staged.t
(** [struct_rec Id id element_thunks] constructs a possibly-cyclic [Struct]
value. Cycles are detected using [Id]. The caller of [struct_rec Id]
must ensure that a single unstaging of [struct_rec Id] is used for each
complete cyclic value. Also, the caller must ensure that recursive calls
to [struct_rec Id] provide [id] values that uniquely identify at least
one point on each cycle. Failure to obey these requirements will lead to
stack overflow. *)
val rec_record : int -> Typ.t -> t
(** Traverse *)

@ -18,6 +18,15 @@ let rec pp ?pre ?suf sep pp_elt fs = function
| xs -> Format.fprintf fs "%( %)%a" sep (pp sep pp_elt) xs ) ;
Option.iter suf ~f:(Format.fprintf fs)
let findi x xs =
let rec findi_ i xs =
match xs with
| [] -> None
| x' :: _ when x == x' -> Some i
| _ :: xs -> findi_ (i + 1) xs
in
findi_ 0 xs
let pop_exn = function
| x :: xs -> (x, xs)
| [] -> raise (Not_found_s (Atom __LOC__))

@ -22,6 +22,9 @@ val pp_diff :
-> 'a pp
-> ('a list * 'a list) pp
val findi : 'a -> 'a t -> int option
(** [findi x xs] is [Some i] when [nth xs i == x], otherwise [None]. *)
val pop_exn : 'a list -> 'a * 'a list
val find_map_remove :

@ -36,7 +36,6 @@ type op2 =
type op3 = Conditional | Extract [@@deriving compare, equal, hash, sexp]
type opN = Concat | Record [@@deriving compare, equal, hash, sexp]
type recN = Record [@@deriving compare, equal, hash, sexp]
module rec Set : sig
include Import.Set.S with type elt := T.t
@ -77,7 +76,6 @@ and T : sig
| Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t
| ApN of opN * t iarray
| RecN of recN * t iarray (** NOTE: cyclic *)
| And of set
| Or of set
| Add of qset
@ -87,6 +85,7 @@ and T : sig
| Float of {data: string}
| Integer of {data: Z.t}
| Rational of {data: Q.t}
| RecRecord of int
[@@deriving compare, equal, hash, sexp]
end = struct
type set = Set.t [@@deriving compare, equal, hash, sexp]
@ -98,7 +97,6 @@ end = struct
| Ap2 of op2 * t * t
| Ap3 of op3 * t * t * t
| ApN of opN * t iarray
| RecN of recN * t iarray (** NOTE: cyclic *)
| And of set
| Or of set
| Add of qset
@ -108,6 +106,7 @@ end = struct
| Float of {data: string}
| Integer of {data: Z.t}
| Rational of {data: Q.t}
| RecRecord of int
[@@deriving compare, equal, hash, sexp]
(* Note: solve (and invariant) requires Qset.min_elt to return a
@ -133,24 +132,8 @@ end
include T
module Map = struct include Map.Make (T) include Provide_of_sexp (T) end
let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a =
let rec fix_f seen e =
match e with
| RecN _ ->
if List.mem ~equal:( == ) seen e then f bot e
else f (fix_f (e :: seen)) e
| _ -> f (fix_f seen) e
in
let rec fix_f_seen_nil e =
match e with RecN _ -> f (fix_f [e]) e | _ -> f fix_f_seen_nil e
in
fix_f_seen_nil e
let fix_flip (f : ('z -> t -> 'a as 'f) -> 'f) (bot : 'f) (z : 'z) (e : t) =
fix (fun f' e z -> f (fun z e -> f' e z) z e) (fun e z -> bot z e) e z
let rec ppx strength fs term =
let pp_ pp fs term =
let rec pp fs term =
let pf fmt =
Format.pp_open_box fs 2 ;
Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt
@ -212,12 +195,12 @@ let rec ppx strength fs term =
| ApN (Concat, args) when IArray.is_empty args -> pf "@<2>⟨⟩"
| ApN (Concat, args) -> pf "(%a)" (IArray.pp "@,^" pp) args
| ApN (Record, elts) -> pf "{%a}" (pp_record strength) elts
| RecN (Record, elts) -> pf "{|%a|}" (IArray.pp ",@ " pp) elts
| Ap1 (Select idx, rcd) -> pf "%a[%i]" pp rcd idx
| Ap2 (Update idx, rcd, elt) ->
pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt
| RecRecord i -> pf "(rec_record %i)" i
in
fix_flip pp_ (fun _ _ -> ()) fs term
pp fs term
[@@warning "-9"]
and pp_record strength fs elts =
@ -325,8 +308,7 @@ let invariant e =
| Mul _ -> assert_monomial e |> Fn.id
| Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) ->
assert_aggregate e
| ApN (Record, elts) | RecN (Record, elts) ->
assert (not (IArray.is_empty elts))
| ApN (Record, elts) -> assert (not (IArray.is_empty elts))
| Ap1 (Convert {src= Integer _; dst= Integer _}, _) -> assert false
| Ap1 (Convert {src; dst}, _) ->
assert (Typ.convertible src dst) ;
@ -1023,28 +1005,7 @@ let simp_ashr x y =
let simp_record elts = ApN (Record, elts)
let simp_select idx rcd = Ap1 (Select idx, rcd)
let simp_update idx rcd elt = Ap2 (Update idx, rcd, elt)
let rec_app key =
let memo_id = Hashtbl.create key in
let dummy = null in
Staged.stage
@@ fun ~id op elt_thks ->
match Hashtbl.find memo_id id with
| None ->
(* Add placeholder to prevent computing [elts] in calls to [rec_app]
from [elt_thks] for recursive occurrences of [id]. *)
let elta = Array.create ~len:(IArray.length elt_thks) dummy in
let elts = IArray.of_array elta in
Hashtbl.set memo_id ~key:id ~data:elts ;
IArray.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ;
RecN (op, elts) |> check invariant
| Some elts ->
(* Do not check invariant as invariant will be checked above after the
thunks are forced, before which invariant-checking may spuriously
fail. Note that it is important that the value constructed here
shares the array in the memo table, so that the update after
forcing the recursive thunks also updates this value. *)
RecN (op, elts)
let simp_rec_record i = RecRecord i
(* dispatching for normalization and invariant checking *)
@ -1124,6 +1085,7 @@ let concat xs = normN Concat (IArray.of_array xs)
let record elts = normN Record elts
let select ~rcd ~idx = norm1 (Select idx) rcd
let update ~rcd ~idx ~elt = norm2 (Update idx) rcd elt
let rec_record i = simp_rec_record i |> check invariant
let eq_concat (siz, arr) ms =
eq (memory ~siz ~arr)
@ -1168,12 +1130,9 @@ let map e ~f =
| Ap2 (op, x, y) -> map2 op ~f x y
| Ap3 (op, x, y, z) -> map3 op ~f x y z
| ApN (op, xs) -> mapN op ~f xs
| RecN (_, xs) ->
assert (
xs == IArray.map_endo ~f xs
|| fail "Term.map does not support updating subterms of RecN." () ) ;
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
e
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> e
let fold_map e ~init ~f =
let s = ref init in
@ -1185,53 +1144,13 @@ let fold_map e ~init ~f =
let e' = map e ~f in
(!s, e')
let map_rec_pre e ~f =
let rec map_rec_pre_f memo e =
match f e with
| Some e' -> e'
| None -> (
match e with
| RecN (op, xs) -> (
match List.Assoc.find ~equal:( == ) memo e with
| None ->
let xs' = IArray.to_array xs in
let e' = RecN (op, IArray.of_array xs') in
let memo = List.Assoc.add ~equal:( == ) memo e e' in
let changed = ref false in
Array.map_inplace xs' ~f:(fun x ->
let x' = map_rec_pre_f memo x in
if x' != x then changed := true ;
x' ) ;
if !changed then e' else e
| Some e' -> e' )
| _ -> map ~f:(map_rec_pre_f memo) e )
in
map_rec_pre_f [] e
let rec map_rec_pre e ~f =
match f e with Some e' -> e' | None -> map ~f:(map_rec_pre ~f) e
let fold_map_rec_pre e ~init ~f =
let rec fold_map_rec_pre_f memo s e =
let rec fold_map_rec_pre e ~init:s ~f =
match f s e with
| Some (s, e') -> (s, e')
| None -> (
match e with
| RecN (op, xs) -> (
match List.Assoc.find ~equal:( == ) memo e with
| None ->
let xs' = IArray.to_array xs in
let e' = RecN (op, IArray.of_array xs') in
let memo = List.Assoc.add ~equal:( == ) memo e e' in
let changed = ref false in
let s =
Array.fold_map_inplace ~init:s xs' ~f:(fun s x ->
let s, x' = fold_map_rec_pre_f memo s x in
if x' != x then changed := true ;
(s, x') )
in
if !changed then (s, e') else (s, e)
| Some e' -> (s, e') )
| _ -> fold_map ~f:(fold_map_rec_pre_f memo) ~init:s e )
in
fold_map_rec_pre_f [] init e
| None -> fold_map ~f:(fun s e -> fold_map_rec_pre ~f ~init:s e) ~init:s e
let rename sub e =
map_rec_pre e ~f:(function
@ -1245,75 +1164,80 @@ let iter e ~f =
| Ap1 (_, x) -> f x
| Ap2 (_, x, y) -> f x ; f y
| Ap3 (_, x, y, z) -> f x ; f y ; f z
| ApN (_, xs) | RecN (_, xs) -> IArray.iter ~f xs
| ApN (_, xs) -> IArray.iter ~f xs
| And args | Or args -> Set.iter ~f args
| Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> ()
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
()
let exists e ~f =
match e with
| Ap1 (_, x) -> f x
| Ap2 (_, x, y) -> f x || f y
| Ap3 (_, x, y, z) -> f x || f y || f z
| ApN (_, xs) | RecN (_, xs) -> IArray.exists ~f xs
| ApN (_, xs) -> IArray.exists ~f xs
| And args | Or args -> Set.exists ~f args
| Add args | Mul args -> Qset.exists ~f:(fun arg _ -> f arg) args
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> false
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
false
let for_all e ~f =
match e with
| Ap1 (_, x) -> f x
| Ap2 (_, x, y) -> f x && f y
| Ap3 (_, x, y, z) -> f x && f y && f z
| ApN (_, xs) | RecN (_, xs) -> IArray.for_all ~f xs
| ApN (_, xs) -> IArray.for_all ~f xs
| And args | Or args -> Set.for_all ~f args
| Add args | Mul args -> Qset.for_all ~f:(fun arg _ -> f arg) args
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> true
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
true
let fold e ~init:s ~f =
match e with
| Ap1 (_, x) -> f x s
| Ap2 (_, x, y) -> f y (f x s)
| Ap3 (_, x, y, z) -> f z (f y (f x s))
| ApN (_, xs) | RecN (_, xs) ->
IArray.fold ~f:(fun s x -> f x s) xs ~init:s
| ApN (_, xs) -> IArray.fold ~f:(fun s x -> f x s) xs ~init:s
| And args | Or args -> Set.fold ~f:(fun s e -> f e s) args ~init:s
| Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> s
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
s
let iter_terms e ~f =
let iter_terms_ iter_terms_ e =
let rec iter_terms e ~f =
( match e with
| Ap1 (_, x) -> iter_terms_ x
| Ap2 (_, x, y) -> iter_terms_ x ; iter_terms_ y
| Ap3 (_, x, y, z) -> iter_terms_ x ; iter_terms_ y ; iter_terms_ z
| ApN (_, xs) | RecN (_, xs) -> IArray.iter ~f:iter_terms_ xs
| And args | Or args -> Set.iter args ~f:iter_terms_
| Ap1 (_, x) -> iter_terms ~f x
| Ap2 (_, x, y) -> iter_terms ~f x ; iter_terms ~f y
| Ap3 (_, x, y, z) -> iter_terms ~f x ; iter_terms ~f y ; iter_terms ~f z
| ApN (_, xs) -> IArray.iter ~f:(iter_terms ~f) xs
| And args | Or args -> Set.iter args ~f:(iter_terms ~f)
| Add args | Mul args ->
Qset.iter args ~f:(fun arg _ -> iter_terms_ arg)
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> () ) ;
Qset.iter args ~f:(fun arg _ -> iter_terms ~f arg)
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
() ) ;
f e
in
fix iter_terms_ (fun _ -> ()) e
let fold_terms e ~init ~f =
let fold_terms_ fold_terms_ e s =
let rec fold_terms e ~init:s ~f =
let fold_terms f e s = fold_terms e ~init:s ~f in
let s =
match e with
| Ap1 (_, x) -> fold_terms_ x s
| Ap2 (_, x, y) -> fold_terms_ y (fold_terms_ x s)
| Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s))
| ApN (_, xs) | RecN (_, xs) ->
IArray.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s
| Ap1 (_, x) -> fold_terms f x s
| Ap2 (_, x, y) -> fold_terms f y (fold_terms f x s)
| Ap3 (_, x, y, z) -> fold_terms f z (fold_terms f y (fold_terms f x s))
| ApN (_, xs) -> IArray.fold ~f:(fun s x -> fold_terms f x s) xs ~init:s
| And args | Or args ->
Set.fold args ~init:s ~f:(fun s x -> fold_terms_ x s)
Set.fold args ~init:s ~f:(fun s x -> fold_terms f x s)
| Add args | Mul args ->
Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s)
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> s
Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms f arg s)
| Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _
|RecRecord _ ->
s
in
f s e
in
fix fold_terms_ (fun _ s -> s) e init
let iter_vars e ~f =
iter_terms e ~f:(function Var _ as v -> f (v :> Var.t) | _ -> ())
@ -1338,21 +1262,17 @@ let rec is_constant = function
| Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> true
| a -> for_all ~f:is_constant a
let height e =
let height_ height_ = function
let rec height = function
| Var _ -> 0
| Ap1 (_, a) -> 1 + height_ a
| Ap2 (_, a, b) -> 1 + max (height_ a) (height_ b)
| Ap3 (_, a, b, c) -> 1 + max (height_ a) (max (height_ b) (height_ c))
| ApN (_, v) | RecN (_, v) ->
1 + IArray.fold v ~init:0 ~f:(fun m a -> max m (height_ a))
| Ap1 (_, a) -> 1 + height a
| Ap2 (_, a, b) -> 1 + max (height a) (height b)
| Ap3 (_, a, b, c) -> 1 + max (height a) (max (height b) (height c))
| ApN (_, v) -> 1 + IArray.fold v ~init:0 ~f:(fun m a -> max m (height a))
| And bs | Or bs ->
1 + Set.fold bs ~init:0 ~f:(fun m a -> max m (height_ a))
1 + Set.fold bs ~init:0 ~f:(fun m a -> max m (height a))
| Add qs | Mul qs ->
1 + Qset.fold qs ~init:0 ~f:(fun a _ m -> max m (height_ a))
| Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> 0
in
fix height_ (fun _ -> 0) e
1 + Qset.fold qs ~init:0 ~f:(fun a _ m -> max m (height a))
| Label _ | Nondet _ | Float _ | Integer _ | Rational _ | RecRecord _ -> 0
(** Solve *)

@ -57,9 +57,6 @@ type opN =
| Record (** Record (array / struct) constant *)
[@@deriving compare, equal, hash, sexp]
type recN = Record (** Recursive record (array / struct) constant *)
[@@deriving compare, equal, hash, sexp]
module rec Set : sig
include Import.Set.S with type elt := T.t
@ -86,10 +83,6 @@ and T : sig
| Ap2 of op2 * t * t (** Binary application *)
| Ap3 of op3 * t * t * t (** Ternary application *)
| ApN of opN * t iarray (** N-ary application *)
| RecN of recN * t iarray
(** Recursive n-ary application, may recursively refer to itself
(transitively) from its args. NOTE: represented by cyclic
values. *)
| And of set (** Conjunction, boolean or bitwise *)
| Or of set (** Disjunction, boolean or bitwise *)
| Add of qset (** Sum of terms with rational coefficients *)
@ -102,6 +95,7 @@ and T : sig
| Float of {data: string} (** Floating-point constant *)
| Integer of {data: Z.t} (** Integer constant *)
| Rational of {data: Q.t} (** Rational constant *)
| RecRecord of int (** Reference to ancestor recursive record *)
[@@deriving compare, equal, hash, sexp]
end
@ -242,11 +236,7 @@ val eq_concat : t * t -> (t * t) array -> t
val record : t iarray -> t
val select : rcd:t -> idx:int -> t
val update : rcd:t -> idx:int -> elt:t -> t
(* recursive n-ary application *)
val rec_app :
(module Hashtbl.Key_plain with type t = 'id)
-> (id:'id -> recN -> t lazy_t iarray -> t) Staged.t
val rec_record : int -> t
(** Transform *)

Loading…
Cancel
Save