diff --git a/sledge/bin/frontend.ml b/sledge/bin/frontend.ml index f7ad93b23..c6f7cd471 100644 --- a/sledge/bin/frontend.ml +++ b/sledge/bin/frontend.ml @@ -343,16 +343,6 @@ let should_inline : Llvm.llvalue -> bool = | None -> true ) | None -> true -module Llvalue = struct - type t = Llvm.llvalue - - let hash = Hashtbl.hash - let compare = Poly.compare - let sexp_of_t llv = Sexp.Atom (Llvm.string_of_llvalue llv) -end - -let struct_rec = Staged.unstage (Exp.struct_rec (module Llvalue)) - let ptr_fld x ~ptr ~fld ~lltyp = let offset = Llvm_target.DataLayout.offset_of_element lltyp fld x.lldatalayout @@ -377,34 +367,34 @@ let convert_to_siz = let xlate_llvm_eh_typeid_for : x -> Typ.t -> Exp.t -> Exp.t = fun x typ arg -> Exp.convert typ ~to_:(i32 x) arg -let rec xlate_intrinsic_exp : string -> (x -> Llvm.llvalue -> Exp.t) option - = +let rec xlate_intrinsic_exp stk : + string -> (x -> Llvm.llvalue -> Exp.t) option = fun name -> match name with | "llvm.eh.typeid.for" -> Some (fun x llv -> let rand = Llvm.operand llv 0 in - let arg = xlate_value x rand in + let arg = xlate_value stk x rand in let src = xlate_type x (Llvm.type_of rand) in xlate_llvm_eh_typeid_for x src arg ) | _ -> None -and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t = +and xlate_value ?(inline = false) stk : x -> Llvm.llvalue -> Exp.t = fun x llv -> let xlate_value_ llv = match Llvm.classify_value llv with | Instruction Call -> ( let func = Llvm.operand llv (Llvm.num_arg_operands llv) in let fname = Llvm.value_name func in - match xlate_intrinsic_exp fname with + match xlate_intrinsic_exp stk fname with | Some intrinsic when inline || should_inline llv -> intrinsic x llv | _ -> Exp.reg (xlate_name x llv) ) | Instruction (Invoke | Alloca | Load | PHI | LandingPad | VAArg) |Argument -> Exp.reg (xlate_name x llv) - | Function | GlobalVariable -> Exp.reg (xlate_global x llv).reg - | GlobalAlias -> xlate_value x (Llvm.operand llv 0) + | Function | GlobalVariable -> Exp.reg (xlate_global stk x llv).reg + | GlobalAlias -> xlate_value stk x (Llvm.operand llv 0) | ConstantInt -> xlate_int x llv | ConstantFP -> xlate_float x llv | ConstantPointerNull -> Exp.null @@ -419,35 +409,27 @@ and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t = | ConstantVector | ConstantArray -> let typ = xlate_type x (Llvm.type_of llv) in let len = Llvm.num_operands llv in - let f i = xlate_value x (Llvm.operand llv i) in + let f i = xlate_value stk x (Llvm.operand llv i) in Exp.record typ (IArray.init len ~f) | ConstantDataVector -> let typ = xlate_type x (Llvm.type_of llv) in let len = Llvm.vector_size (Llvm.type_of llv) in - let f i = xlate_value x (Llvm.const_element llv i) in + let f i = xlate_value stk x (Llvm.const_element llv i) in Exp.record typ (IArray.init len ~f) | ConstantDataArray -> let typ = xlate_type x (Llvm.type_of llv) in let len = Llvm.array_length (Llvm.type_of llv) in - let f i = xlate_value x (Llvm.const_element llv i) in + let f i = xlate_value stk x (Llvm.const_element llv i) in Exp.record typ (IArray.init len ~f) - | ConstantStruct -> + | ConstantStruct -> ( let typ = xlate_type x (Llvm.type_of llv) in - let is_recursive = - Llvm.fold_left_uses - (fun b use -> b || llv == Llvm.used_value use) - false llv - in - if is_recursive then - let elt_thks = - IArray.init (Llvm.num_operands llv) ~f:(fun i -> - lazy (xlate_value x (Llvm.operand llv i)) ) - in - struct_rec ~id:llv typ elt_thks - else - Exp.record typ - (IArray.init (Llvm.num_operands llv) ~f:(fun i -> - xlate_value x (Llvm.operand llv i) )) + match List.findi llv stk with + | Some i -> Exp.rec_record i typ + | None -> + let stk = llv :: stk in + Exp.record typ + (IArray.init (Llvm.num_operands llv) ~f:(fun i -> + xlate_value stk x (Llvm.operand llv i) )) ) | BlockAddress -> let parent = find_name (Llvm.operand llv 0) in let name = find_name (Llvm.operand llv 1) in @@ -462,9 +444,9 @@ and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t = | SRem | FRem | Shl | LShr | AShr | And | Or | Xor | ICmp | FCmp | Select | GetElementPtr | ExtractElement | InsertElement | ShuffleVector | ExtractValue | InsertValue ) as opcode ) -> - if inline || should_inline llv then xlate_opcode x llv opcode + if inline || should_inline llv then xlate_opcode stk x llv opcode else Exp.reg (xlate_name x llv) - | ConstantExpr -> xlate_opcode x llv (Llvm.constexpr_opcode llv) + | ConstantExpr -> xlate_opcode stk x llv (Llvm.constexpr_opcode llv) | GlobalIFunc -> todo "ifuncs: %a" pp_llvalue llv () | Instruction (CatchPad | CleanupPad | CatchSwitch) -> todo "windows exception handling: %a" pp_llvalue llv () @@ -482,11 +464,11 @@ and xlate_value ?(inline = false) : x -> Llvm.llvalue -> Exp.t = |> [%Trace.retn fun {pf} exp -> pf "%a" Exp.pp exp] ) -and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t = +and xlate_opcode stk : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t = fun x llv opcode -> [%Trace.call fun {pf} -> pf "%a" pp_llvalue llv] ; - let xlate_rand i = xlate_value x (Llvm.operand llv i) in + let xlate_rand i = xlate_value stk x (Llvm.operand llv i) in let typ = lazy (xlate_type x (Llvm.type_of llv)) in let check_vector = lazy @@ -497,7 +479,7 @@ and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t = let dst = Lazy.force typ in let rand = Llvm.operand llv 0 in let src = xlate_type x (Llvm.type_of rand) in - let arg = xlate_value x rand in + let arg = xlate_value stk x rand in match (opcode : Llvm.Opcode.t) with | Trunc -> Exp.signed (Typ.bit_size_of dst) arg ~to_:dst | SExt -> Exp.signed (Typ.bit_size_of src) arg ~to_:dst @@ -671,7 +653,7 @@ and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t = | ShuffleVector -> ( (* translate shufflevector %x, _, zeroinitializer to %x *) - let exp = xlate_value x (Llvm.operand llv 0) in + let exp = xlate_value stk x (Llvm.operand llv 0) in let exp_typ = xlate_type x (Llvm.type_of (Llvm.operand llv 0)) in let llmask = Llvm.operand llv 2 in let mask_typ = xlate_type x (Llvm.type_of llmask) in @@ -687,7 +669,7 @@ and xlate_opcode : x -> Llvm.llvalue -> Llvm.Opcode.t -> Exp.t = |> [%Trace.retn fun {pf} exp -> pf "%a" Exp.pp exp] -and xlate_global : x -> Llvm.llvalue -> Global.t = +and xlate_global stk : x -> Llvm.llvalue -> Global.t = fun x llg -> Hashtbl.find_or_add memo_global llg ~default:(fun () -> [%Trace.call fun {pf} -> pf "%a" pp_llvalue llg] @@ -702,13 +684,18 @@ and xlate_global : x -> Llvm.llvalue -> Global.t = let init = match Llvm.classify_value llg with | GlobalVariable -> - Option.map ~f:(xlate_value x) (Llvm.global_initializer llg) + Option.map ~f:(xlate_value stk x) (Llvm.global_initializer llg) | _ -> None in Global.mk ?init g typ loc |> [%Trace.retn fun {pf} -> pf "%a" Global.pp_defn] ) +let xlate_intrinsic_exp = xlate_intrinsic_exp [] +let xlate_value ?inline = xlate_value ?inline [] +let xlate_opcode = xlate_opcode [] +let xlate_global = xlate_global [] + type pop_thunk = Loc.t -> Llair.inst list let pop_stack_frame_of_function : diff --git a/sledge/src/equality.ml b/sledge/src/equality.ml index bac036fe8..0bc39562c 100644 --- a/sledge/src/equality.ml +++ b/sledge/src/equality.ml @@ -17,8 +17,8 @@ let classify e = | Add _ | Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) -> Interpreted | Mul _ | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ | And _ | Or _ -> Uninterpreted - | RecN _ | Var _ | Integer _ | Rational _ | Float _ | Nondet _ | Label _ - -> + | Var _ | Integer _ | Rational _ | Float _ | Nondet _ | Label _ + |RecRecord _ -> Atomic let interpreted e = equal_kind (classify e) Interpreted diff --git a/sledge/src/exp.ml b/sledge/src/exp.ml index 3f8ba7d49..6a22228a9 100644 --- a/sledge/src/exp.ml +++ b/sledge/src/exp.ml @@ -57,10 +57,8 @@ module T = struct | Conditional [@@deriving compare, equal, hash, sexp] - type opN = - (* array/struct constants *) + type opN = (* array/struct constants *) | Record - | Struct_rec (** NOTE: may be cyclic *) [@@deriving compare, equal, hash, sexp] type t = {desc: desc; term: Term.t} @@ -75,6 +73,7 @@ module T = struct | Ap2 of op2 * Typ.t * t * t | Ap3 of op3 * Typ.t * t * t * t | ApN of opN * Typ.t * t iarray + | RecRecord of int * Typ.t [@@deriving compare, equal, hash, sexp] end @@ -84,24 +83,6 @@ module Map = Map.Make (T) let term e = e.term -let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a = - let rec fix_f seen e = - match e.desc with - | ApN (Struct_rec, _, _) -> - if List.mem ~equal:( == ) seen e then f bot e - else f (fix_f (e :: seen)) e - | _ -> f (fix_f seen) e - in - let rec fix_f_seen_nil e = - match e.desc with - | ApN (Struct_rec, _, _) -> f (fix_f [e]) e - | _ -> f fix_f_seen_nil e - in - fix_f_seen_nil e - -let fix_flip (f : ('z -> t -> 'a as 'f) -> 'f) (bot : 'f) (z : 'z) (e : t) = - fix (fun f' e z -> f (fun z e -> f' e z) z e) (fun e z -> bot z e) e z - let pp_op2 fs op = let pf fmt = Format.fprintf fs fmt in match op with @@ -133,44 +114,41 @@ let pp_op2 fs op = | Update idx -> pf "[_|%i→_]" idx let rec pp fs exp = - let pp_ pp fs exp = - let pf fmt = - Format.pp_open_box fs 2 ; - Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt - in - match exp.desc with - | Reg {name} -> ( - match Var.of_term exp.term with - | Some v when Var.is_global v -> pf "%@%s" name - | _ -> pf "%%%s" name ) - | Nondet {msg} -> pf "nondet \"%s\"" msg - | Label {name} -> pf "%s" name - | Integer {data; typ= Pointer _} when Z.equal Z.zero data -> pf "null" - | Integer {data} -> Trace.pp_styled `Magenta "%a" fs Z.pp data - | Float {data} -> pf "%s" data - | Ap1 (Signed {bits}, dst, arg) -> - pf "((%a)(s%i)@ %a)" Typ.pp dst bits pp arg - | Ap1 (Unsigned {bits}, dst, arg) -> - pf "((%a)(u%i)@ %a)" Typ.pp dst bits pp arg - | Ap1 (Convert {src}, dst, arg) -> - pf "((%a)(%a)@ %a)" Typ.pp dst Typ.pp src pp arg - | Ap1 (Splat, _, byt) -> pf "%a^" pp byt - | Ap1 (Select idx, _, rcd) -> pf "%a[%i]" pp rcd idx - | Ap2 (Update idx, _, rcd, elt) -> - pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt - | Ap2 (Xor, Integer {bits= 1}, {desc= Integer {data}}, x) - when Z.is_true data -> - pf "¬%a" pp x - | Ap2 (Xor, Integer {bits= 1}, x, {desc= Integer {data}}) - when Z.is_true data -> - pf "¬%a" pp x - | Ap2 (op, _, x, y) -> pf "(%a@ %a %a)" pp x pp_op2 op pp y - | Ap3 (Conditional, _, cnd, thn, els) -> - pf "(%a@ ? %a@ : %a)" pp cnd pp thn pp els - | ApN (Record, _, elts) -> pf "{%a}" pp_record elts - | ApN (Struct_rec, _, elts) -> pf "{|%a|}" (IArray.pp ",@ " pp) elts + let pf fmt = + Format.pp_open_box fs 2 ; + Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt in - fix_flip pp_ (fun _ _ -> ()) fs exp + match exp.desc with + | Reg {name} -> ( + match Var.of_term exp.term with + | Some v when Var.is_global v -> pf "%@%s" name + | _ -> pf "%%%s" name ) + | Nondet {msg} -> pf "nondet \"%s\"" msg + | Label {name} -> pf "%s" name + | Integer {data; typ= Pointer _} when Z.equal Z.zero data -> pf "null" + | Integer {data} -> Trace.pp_styled `Magenta "%a" fs Z.pp data + | Float {data} -> pf "%s" data + | Ap1 (Signed {bits}, dst, arg) -> + pf "((%a)(s%i)@ %a)" Typ.pp dst bits pp arg + | Ap1 (Unsigned {bits}, dst, arg) -> + pf "((%a)(u%i)@ %a)" Typ.pp dst bits pp arg + | Ap1 (Convert {src}, dst, arg) -> + pf "((%a)(%a)@ %a)" Typ.pp dst Typ.pp src pp arg + | Ap1 (Splat, _, byt) -> pf "%a^" pp byt + | Ap1 (Select idx, _, rcd) -> pf "%a[%i]" pp rcd idx + | Ap2 (Update idx, _, rcd, elt) -> + pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt + | Ap2 (Xor, Integer {bits= 1}, {desc= Integer {data}}, x) + when Z.is_true data -> + pf "¬%a" pp x + | Ap2 (Xor, Integer {bits= 1}, x, {desc= Integer {data}}) + when Z.is_true data -> + pf "¬%a" pp x + | Ap2 (op, _, x, y) -> pf "(%a@ %a %a)" pp x pp_op2 op pp y + | Ap3 (Conditional, _, cnd, thn, els) -> + pf "(%a@ ? %a@ : %a)" pp cnd pp thn pp els + | ApN (Record, _, elts) -> pf "{%a}" pp_record elts + | RecRecord (i, _) -> pf "rec_record %i" i [@@warning "-9"] and pp_record fs elts = @@ -256,7 +234,7 @@ let rec invariant exp = assert (Typ.castable Typ.bool (typ_of cnd)) ; assert (Typ.castable typ (typ_of thn)) ; assert (Typ.castable typ (typ_of els)) - | ApN ((Record | Struct_rec), typ, args) -> ( + | ApN (Record, typ, args) -> ( match typ with | Array {elt} -> assert ( @@ -268,6 +246,7 @@ let rec invariant exp = IArray.for_all2_exn elts args ~f:(fun typ arg -> Typ.castable typ (typ_of arg) ) ) | _ -> assert false ) + | RecRecord _ -> () [@@warning "-9"] (** Type query *) @@ -295,7 +274,8 @@ and typ_of exp = , _ , _ ) |Ap3 (Conditional, typ, _, _, _) - |ApN ((Record | Struct_rec), typ, _) -> + |ApN (Record, typ, _) + |RecRecord (_, typ) -> typ [@@warning "-9"] @@ -472,35 +452,12 @@ let update typ ~rcd idx ~elt = ; term= Term.update ~rcd:rcd.term ~idx ~elt:elt.term } |> check invariant -let struct_rec key = - let memo_id = Hashtbl.create key in - let rec_app = (Staged.unstage (Term.rec_app key)) Term.Record in - Staged.stage - @@ fun ~id typ elt_thks -> - match Hashtbl.find memo_id id with - | None -> - (* Add placeholder to prevent computing [elts] in calls to - [struct_rec] from [elt_thks] for recursive occurrences of [id]. *) - let elta = Array.create ~len:(IArray.length elt_thks) null in - let elts = IArray.of_array elta in - Hashtbl.set memo_id ~key:id ~data:elts ; - let term = - rec_app ~id (IArray.map ~f:(fun elt -> lazy elt.term) elts) - in - IArray.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ; - {desc= ApN (Struct_rec, typ, elts); term} |> check invariant - | Some elts -> - (* Do not check invariant as invariant will be checked above after the - thunks are forced, before which invariant-checking may spuriously - fail. Note that it is important that the value constructed here - shares the array in the memo table, so that the update after - forcing the recursive thunks also updates this value. *) - {desc= ApN (Struct_rec, typ, elts); term= rec_app ~id IArray.empty} +let rec_record i typ = {desc= RecRecord (i, typ); term= Term.rec_record i} (** Traverse *) let fold_exps e ~init ~f = - let fold_exps_ fold_exps_ e z = + let rec fold_exps_ e z = let z = match e.desc with | Ap1 (_, _, x) -> fold_exps_ x z @@ -512,7 +469,7 @@ let fold_exps e ~init ~f = in f z e in - fix fold_exps_ (fun _ z -> z) e init + fold_exps_ e init let fold_regs e ~init ~f = fold_exps e ~init ~f:(fun z x -> diff --git a/sledge/src/exp.mli b/sledge/src/exp.mli index 34eba21e3..a27411904 100644 --- a/sledge/src/exp.mli +++ b/sledge/src/exp.mli @@ -68,11 +68,7 @@ type op2 = type op3 = Conditional (** If-then-else *) [@@deriving compare, equal, hash, sexp] -type opN = - | Record (** Record (array / struct) constant *) - | Struct_rec - (** Struct constant that may recursively refer to itself - (transitively) from [elts]. NOTE: represented by cyclic values. *) +type opN = Record (** Record (array / struct) constant *) [@@deriving compare, equal, hash, sexp] type t = private {desc: desc; term: Term.t} @@ -90,6 +86,7 @@ and desc = private | Ap2 of op2 * Typ.t * t * t | Ap3 of op3 * Typ.t * t * t * t | ApN of opN * Typ.t * t iarray + | RecRecord of int * Typ.t (** Reference to ancestor recursive record *) [@@deriving compare, equal, hash, sexp] val pp : t pp @@ -188,17 +185,7 @@ val splat : Typ.t -> t -> t val record : Typ.t -> t iarray -> t val select : Typ.t -> t -> int -> t val update : Typ.t -> rcd:t -> int -> elt:t -> t - -val struct_rec : - (module Hashtbl.Key_plain with type t = 'id) - -> (id:'id -> Typ.t -> t lazy_t iarray -> t) Staged.t -(** [struct_rec Id id element_thunks] constructs a possibly-cyclic [Struct] - value. Cycles are detected using [Id]. The caller of [struct_rec Id] - must ensure that a single unstaging of [struct_rec Id] is used for each - complete cyclic value. Also, the caller must ensure that recursive calls - to [struct_rec Id] provide [id] values that uniquely identify at least - one point on each cycle. Failure to obey these requirements will lead to - stack overflow. *) +val rec_record : int -> Typ.t -> t (** Traverse *) diff --git a/sledge/src/import/list.ml b/sledge/src/import/list.ml index 57f8962e5..ccf95310d 100644 --- a/sledge/src/import/list.ml +++ b/sledge/src/import/list.ml @@ -18,6 +18,15 @@ let rec pp ?pre ?suf sep pp_elt fs = function | xs -> Format.fprintf fs "%( %)%a" sep (pp sep pp_elt) xs ) ; Option.iter suf ~f:(Format.fprintf fs) +let findi x xs = + let rec findi_ i xs = + match xs with + | [] -> None + | x' :: _ when x == x' -> Some i + | _ :: xs -> findi_ (i + 1) xs + in + findi_ 0 xs + let pop_exn = function | x :: xs -> (x, xs) | [] -> raise (Not_found_s (Atom __LOC__)) diff --git a/sledge/src/import/list.mli b/sledge/src/import/list.mli index 040834fd7..a068ceb02 100644 --- a/sledge/src/import/list.mli +++ b/sledge/src/import/list.mli @@ -22,6 +22,9 @@ val pp_diff : -> 'a pp -> ('a list * 'a list) pp +val findi : 'a -> 'a t -> int option +(** [findi x xs] is [Some i] when [nth xs i == x], otherwise [None]. *) + val pop_exn : 'a list -> 'a * 'a list val find_map_remove : diff --git a/sledge/src/term.ml b/sledge/src/term.ml index 98788d883..319005577 100644 --- a/sledge/src/term.ml +++ b/sledge/src/term.ml @@ -36,7 +36,6 @@ type op2 = type op3 = Conditional | Extract [@@deriving compare, equal, hash, sexp] type opN = Concat | Record [@@deriving compare, equal, hash, sexp] -type recN = Record [@@deriving compare, equal, hash, sexp] module rec Set : sig include Import.Set.S with type elt := T.t @@ -77,7 +76,6 @@ and T : sig | Ap2 of op2 * t * t | Ap3 of op3 * t * t * t | ApN of opN * t iarray - | RecN of recN * t iarray (** NOTE: cyclic *) | And of set | Or of set | Add of qset @@ -87,6 +85,7 @@ and T : sig | Float of {data: string} | Integer of {data: Z.t} | Rational of {data: Q.t} + | RecRecord of int [@@deriving compare, equal, hash, sexp] end = struct type set = Set.t [@@deriving compare, equal, hash, sexp] @@ -98,7 +97,6 @@ end = struct | Ap2 of op2 * t * t | Ap3 of op3 * t * t * t | ApN of opN * t iarray - | RecN of recN * t iarray (** NOTE: cyclic *) | And of set | Or of set | Add of qset @@ -108,6 +106,7 @@ end = struct | Float of {data: string} | Integer of {data: Z.t} | Rational of {data: Q.t} + | RecRecord of int [@@deriving compare, equal, hash, sexp] (* Note: solve (and invariant) requires Qset.min_elt to return a @@ -133,24 +132,8 @@ end include T module Map = struct include Map.Make (T) include Provide_of_sexp (T) end -let fix (f : (t -> 'a as 'f) -> 'f) (bot : 'f) (e : t) : 'a = - let rec fix_f seen e = - match e with - | RecN _ -> - if List.mem ~equal:( == ) seen e then f bot e - else f (fix_f (e :: seen)) e - | _ -> f (fix_f seen) e - in - let rec fix_f_seen_nil e = - match e with RecN _ -> f (fix_f [e]) e | _ -> f fix_f_seen_nil e - in - fix_f_seen_nil e - -let fix_flip (f : ('z -> t -> 'a as 'f) -> 'f) (bot : 'f) (z : 'z) (e : t) = - fix (fun f' e z -> f (fun z e -> f' e z) z e) (fun e z -> bot z e) e z - let rec ppx strength fs term = - let pp_ pp fs term = + let rec pp fs term = let pf fmt = Format.pp_open_box fs 2 ; Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt @@ -212,12 +195,12 @@ let rec ppx strength fs term = | ApN (Concat, args) when IArray.is_empty args -> pf "@<2>⟨⟩" | ApN (Concat, args) -> pf "(%a)" (IArray.pp "@,^" pp) args | ApN (Record, elts) -> pf "{%a}" (pp_record strength) elts - | RecN (Record, elts) -> pf "{|%a|}" (IArray.pp ",@ " pp) elts | Ap1 (Select idx, rcd) -> pf "%a[%i]" pp rcd idx | Ap2 (Update idx, rcd, elt) -> pf "[%a@ @[| %i → %a@]]" pp rcd idx pp elt + | RecRecord i -> pf "(rec_record %i)" i in - fix_flip pp_ (fun _ _ -> ()) fs term + pp fs term [@@warning "-9"] and pp_record strength fs elts = @@ -325,8 +308,7 @@ let invariant e = | Mul _ -> assert_monomial e |> Fn.id | Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) -> assert_aggregate e - | ApN (Record, elts) | RecN (Record, elts) -> - assert (not (IArray.is_empty elts)) + | ApN (Record, elts) -> assert (not (IArray.is_empty elts)) | Ap1 (Convert {src= Integer _; dst= Integer _}, _) -> assert false | Ap1 (Convert {src; dst}, _) -> assert (Typ.convertible src dst) ; @@ -1023,28 +1005,7 @@ let simp_ashr x y = let simp_record elts = ApN (Record, elts) let simp_select idx rcd = Ap1 (Select idx, rcd) let simp_update idx rcd elt = Ap2 (Update idx, rcd, elt) - -let rec_app key = - let memo_id = Hashtbl.create key in - let dummy = null in - Staged.stage - @@ fun ~id op elt_thks -> - match Hashtbl.find memo_id id with - | None -> - (* Add placeholder to prevent computing [elts] in calls to [rec_app] - from [elt_thks] for recursive occurrences of [id]. *) - let elta = Array.create ~len:(IArray.length elt_thks) dummy in - let elts = IArray.of_array elta in - Hashtbl.set memo_id ~key:id ~data:elts ; - IArray.iteri elt_thks ~f:(fun i (lazy elt) -> elta.(i) <- elt) ; - RecN (op, elts) |> check invariant - | Some elts -> - (* Do not check invariant as invariant will be checked above after the - thunks are forced, before which invariant-checking may spuriously - fail. Note that it is important that the value constructed here - shares the array in the memo table, so that the update after - forcing the recursive thunks also updates this value. *) - RecN (op, elts) +let simp_rec_record i = RecRecord i (* dispatching for normalization and invariant checking *) @@ -1124,6 +1085,7 @@ let concat xs = normN Concat (IArray.of_array xs) let record elts = normN Record elts let select ~rcd ~idx = norm1 (Select idx) rcd let update ~rcd ~idx ~elt = norm2 (Update idx) rcd elt +let rec_record i = simp_rec_record i |> check invariant let eq_concat (siz, arr) ms = eq (memory ~siz ~arr) @@ -1168,12 +1130,9 @@ let map e ~f = | Ap2 (op, x, y) -> map2 op ~f x y | Ap3 (op, x, y, z) -> map3 op ~f x y z | ApN (op, xs) -> mapN op ~f xs - | RecN (_, xs) -> - assert ( - xs == IArray.map_endo ~f xs - || fail "Term.map does not support updating subterms of RecN." () ) ; + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ + |RecRecord _ -> e - | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> e let fold_map e ~init ~f = let s = ref init in @@ -1185,53 +1144,13 @@ let fold_map e ~init ~f = let e' = map e ~f in (!s, e') -let map_rec_pre e ~f = - let rec map_rec_pre_f memo e = - match f e with - | Some e' -> e' - | None -> ( - match e with - | RecN (op, xs) -> ( - match List.Assoc.find ~equal:( == ) memo e with - | None -> - let xs' = IArray.to_array xs in - let e' = RecN (op, IArray.of_array xs') in - let memo = List.Assoc.add ~equal:( == ) memo e e' in - let changed = ref false in - Array.map_inplace xs' ~f:(fun x -> - let x' = map_rec_pre_f memo x in - if x' != x then changed := true ; - x' ) ; - if !changed then e' else e - | Some e' -> e' ) - | _ -> map ~f:(map_rec_pre_f memo) e ) - in - map_rec_pre_f [] e - -let fold_map_rec_pre e ~init ~f = - let rec fold_map_rec_pre_f memo s e = - match f s e with - | Some (s, e') -> (s, e') - | None -> ( - match e with - | RecN (op, xs) -> ( - match List.Assoc.find ~equal:( == ) memo e with - | None -> - let xs' = IArray.to_array xs in - let e' = RecN (op, IArray.of_array xs') in - let memo = List.Assoc.add ~equal:( == ) memo e e' in - let changed = ref false in - let s = - Array.fold_map_inplace ~init:s xs' ~f:(fun s x -> - let s, x' = fold_map_rec_pre_f memo s x in - if x' != x then changed := true ; - (s, x') ) - in - if !changed then (s, e') else (s, e) - | Some e' -> (s, e') ) - | _ -> fold_map ~f:(fold_map_rec_pre_f memo) ~init:s e ) - in - fold_map_rec_pre_f [] init e +let rec map_rec_pre e ~f = + match f e with Some e' -> e' | None -> map ~f:(map_rec_pre ~f) e + +let rec fold_map_rec_pre e ~init:s ~f = + match f s e with + | Some (s, e') -> (s, e') + | None -> fold_map ~f:(fun s e -> fold_map_rec_pre ~f ~init:s e) ~init:s e let rename sub e = map_rec_pre e ~f:(function @@ -1245,75 +1164,80 @@ let iter e ~f = | Ap1 (_, x) -> f x | Ap2 (_, x, y) -> f x ; f y | Ap3 (_, x, y, z) -> f x ; f y ; f z - | ApN (_, xs) | RecN (_, xs) -> IArray.iter ~f xs + | ApN (_, xs) -> IArray.iter ~f xs | And args | Or args -> Set.iter ~f args | Add args | Mul args -> Qset.iter ~f:(fun arg _ -> f arg) args - | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> () + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ + |RecRecord _ -> + () let exists e ~f = match e with | Ap1 (_, x) -> f x | Ap2 (_, x, y) -> f x || f y | Ap3 (_, x, y, z) -> f x || f y || f z - | ApN (_, xs) | RecN (_, xs) -> IArray.exists ~f xs + | ApN (_, xs) -> IArray.exists ~f xs | And args | Or args -> Set.exists ~f args | Add args | Mul args -> Qset.exists ~f:(fun arg _ -> f arg) args - | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> false + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ + |RecRecord _ -> + false let for_all e ~f = match e with | Ap1 (_, x) -> f x | Ap2 (_, x, y) -> f x && f y | Ap3 (_, x, y, z) -> f x && f y && f z - | ApN (_, xs) | RecN (_, xs) -> IArray.for_all ~f xs + | ApN (_, xs) -> IArray.for_all ~f xs | And args | Or args -> Set.for_all ~f args | Add args | Mul args -> Qset.for_all ~f:(fun arg _ -> f arg) args - | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> true + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ + |RecRecord _ -> + true let fold e ~init:s ~f = match e with | Ap1 (_, x) -> f x s | Ap2 (_, x, y) -> f y (f x s) | Ap3 (_, x, y, z) -> f z (f y (f x s)) - | ApN (_, xs) | RecN (_, xs) -> - IArray.fold ~f:(fun s x -> f x s) xs ~init:s + | ApN (_, xs) -> IArray.fold ~f:(fun s x -> f x s) xs ~init:s | And args | Or args -> Set.fold ~f:(fun s e -> f e s) args ~init:s | Add args | Mul args -> Qset.fold ~f:(fun e _ s -> f e s) args ~init:s - | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> s - -let iter_terms e ~f = - let iter_terms_ iter_terms_ e = - ( match e with - | Ap1 (_, x) -> iter_terms_ x - | Ap2 (_, x, y) -> iter_terms_ x ; iter_terms_ y - | Ap3 (_, x, y, z) -> iter_terms_ x ; iter_terms_ y ; iter_terms_ z - | ApN (_, xs) | RecN (_, xs) -> IArray.iter ~f:iter_terms_ xs - | And args | Or args -> Set.iter args ~f:iter_terms_ + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ + |RecRecord _ -> + s + +let rec iter_terms e ~f = + ( match e with + | Ap1 (_, x) -> iter_terms ~f x + | Ap2 (_, x, y) -> iter_terms ~f x ; iter_terms ~f y + | Ap3 (_, x, y, z) -> iter_terms ~f x ; iter_terms ~f y ; iter_terms ~f z + | ApN (_, xs) -> IArray.iter ~f:(iter_terms ~f) xs + | And args | Or args -> Set.iter args ~f:(iter_terms ~f) + | Add args | Mul args -> + Qset.iter args ~f:(fun arg _ -> iter_terms ~f arg) + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ + |RecRecord _ -> + () ) ; + f e + +let rec fold_terms e ~init:s ~f = + let fold_terms f e s = fold_terms e ~init:s ~f in + let s = + match e with + | Ap1 (_, x) -> fold_terms f x s + | Ap2 (_, x, y) -> fold_terms f y (fold_terms f x s) + | Ap3 (_, x, y, z) -> fold_terms f z (fold_terms f y (fold_terms f x s)) + | ApN (_, xs) -> IArray.fold ~f:(fun s x -> fold_terms f x s) xs ~init:s + | And args | Or args -> + Set.fold args ~init:s ~f:(fun s x -> fold_terms f x s) | Add args | Mul args -> - Qset.iter args ~f:(fun arg _ -> iter_terms_ arg) - | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> () ) ; - f e - in - fix iter_terms_ (fun _ -> ()) e - -let fold_terms e ~init ~f = - let fold_terms_ fold_terms_ e s = - let s = - match e with - | Ap1 (_, x) -> fold_terms_ x s - | Ap2 (_, x, y) -> fold_terms_ y (fold_terms_ x s) - | Ap3 (_, x, y, z) -> fold_terms_ z (fold_terms_ y (fold_terms_ x s)) - | ApN (_, xs) | RecN (_, xs) -> - IArray.fold ~f:(fun s x -> fold_terms_ x s) xs ~init:s - | And args | Or args -> - Set.fold args ~init:s ~f:(fun s x -> fold_terms_ x s) - | Add args | Mul args -> - Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms_ arg s) - | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> s - in - f s e + Qset.fold args ~init:s ~f:(fun arg _ s -> fold_terms f arg s) + | Var _ | Label _ | Nondet _ | Float _ | Integer _ | Rational _ + |RecRecord _ -> + s in - fix fold_terms_ (fun _ s -> s) e init + f s e let iter_vars e ~f = iter_terms e ~f:(function Var _ as v -> f (v :> Var.t) | _ -> ()) @@ -1338,21 +1262,17 @@ let rec is_constant = function | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> true | a -> for_all ~f:is_constant a -let height e = - let height_ height_ = function - | Var _ -> 0 - | Ap1 (_, a) -> 1 + height_ a - | Ap2 (_, a, b) -> 1 + max (height_ a) (height_ b) - | Ap3 (_, a, b, c) -> 1 + max (height_ a) (max (height_ b) (height_ c)) - | ApN (_, v) | RecN (_, v) -> - 1 + IArray.fold v ~init:0 ~f:(fun m a -> max m (height_ a)) - | And bs | Or bs -> - 1 + Set.fold bs ~init:0 ~f:(fun m a -> max m (height_ a)) - | Add qs | Mul qs -> - 1 + Qset.fold qs ~init:0 ~f:(fun a _ m -> max m (height_ a)) - | Label _ | Nondet _ | Float _ | Integer _ | Rational _ -> 0 - in - fix height_ (fun _ -> 0) e +let rec height = function + | Var _ -> 0 + | Ap1 (_, a) -> 1 + height a + | Ap2 (_, a, b) -> 1 + max (height a) (height b) + | Ap3 (_, a, b, c) -> 1 + max (height a) (max (height b) (height c)) + | ApN (_, v) -> 1 + IArray.fold v ~init:0 ~f:(fun m a -> max m (height a)) + | And bs | Or bs -> + 1 + Set.fold bs ~init:0 ~f:(fun m a -> max m (height a)) + | Add qs | Mul qs -> + 1 + Qset.fold qs ~init:0 ~f:(fun a _ m -> max m (height a)) + | Label _ | Nondet _ | Float _ | Integer _ | Rational _ | RecRecord _ -> 0 (** Solve *) diff --git a/sledge/src/term.mli b/sledge/src/term.mli index 60f14e8cb..9d029f1cb 100644 --- a/sledge/src/term.mli +++ b/sledge/src/term.mli @@ -57,9 +57,6 @@ type opN = | Record (** Record (array / struct) constant *) [@@deriving compare, equal, hash, sexp] -type recN = Record (** Recursive record (array / struct) constant *) -[@@deriving compare, equal, hash, sexp] - module rec Set : sig include Import.Set.S with type elt := T.t @@ -86,10 +83,6 @@ and T : sig | Ap2 of op2 * t * t (** Binary application *) | Ap3 of op3 * t * t * t (** Ternary application *) | ApN of opN * t iarray (** N-ary application *) - | RecN of recN * t iarray - (** Recursive n-ary application, may recursively refer to itself - (transitively) from its args. NOTE: represented by cyclic - values. *) | And of set (** Conjunction, boolean or bitwise *) | Or of set (** Disjunction, boolean or bitwise *) | Add of qset (** Sum of terms with rational coefficients *) @@ -102,6 +95,7 @@ and T : sig | Float of {data: string} (** Floating-point constant *) | Integer of {data: Z.t} (** Integer constant *) | Rational of {data: Q.t} (** Rational constant *) + | RecRecord of int (** Reference to ancestor recursive record *) [@@deriving compare, equal, hash, sexp] end @@ -242,11 +236,7 @@ val eq_concat : t * t -> (t * t) array -> t val record : t iarray -> t val select : rcd:t -> idx:int -> t val update : rcd:t -> idx:int -> elt:t -> t - -(* recursive n-ary application *) -val rec_app : - (module Hashtbl.Key_plain with type t = 'id) - -> (id:'id -> recN -> t lazy_t iarray -> t) Staged.t +val rec_record : int -> t (** Transform *)