diff --git a/sledge/src/fol/context.ml b/sledge/src/fol/context.ml index 27a3d321d..44d99c6ec 100644 --- a/sledge/src/fol/context.ml +++ b/sledge/src/fol/context.ml @@ -32,7 +32,7 @@ let uninterpreted e = equal_kind (classify e) Uninterpreted let rec max_solvables e = if non_interpreted e then Iter.return e - else Iter.flat_map ~f:max_solvables (Trm.subtrms e) + else Iter.flat_map ~f:max_solvables (Trm.trms e) let fold_max_solvables e s ~f = Iter.fold ~f (max_solvables e) s @@ -168,8 +168,8 @@ end = struct in ( is_var_in xs e || is_var_in xs f - || (uninterpreted e && Iter.exists ~f:(is_var_in xs) (Trm.subtrms e)) - || (uninterpreted f && Iter.exists ~f:(is_var_in xs) (Trm.subtrms f)) ) + || (uninterpreted e && Iter.exists ~f:(is_var_in xs) (Trm.trms e)) + || (uninterpreted f && Iter.exists ~f:(is_var_in xs) (Trm.trms f)) ) $> fun b -> [%Trace.info "is_valid_eq %a%a=%a = %b" Var.Set.pp_xs xs Trm.pp e Trm.pp f b] @@ -457,7 +457,7 @@ let pre_invariant r = (* no interpreted terms in carrier *) assert (non_interpreted trm || fail "non-interp %a" Trm.pp trm ()) ; (* carrier is closed under subterms *) - Iter.iter (Trm.subtrms trm) ~f:(fun subtrm -> + Iter.iter (Trm.trms trm) ~f:(fun subtrm -> assert ( (not (non_interpreted subtrm)) || (match subtrm with Z _ | Q _ -> true | _ -> false) @@ -525,12 +525,12 @@ let rec extend_ a r = match (a : Trm.t) with | Z _ | Q _ -> r | _ -> ( - if interpreted a then Iter.fold ~f:extend_ (Trm.subtrms a) r + if interpreted a then Iter.fold ~f:extend_ (Trm.trms a) r else (* add uninterpreted terms *) match Subst.extend a r with (* and their subterms if newly added *) - | Some r -> Iter.fold ~f:extend_ (Trm.subtrms a) r + | Some r -> Iter.fold ~f:extend_ (Trm.trms a) r | None -> r ) (** add a term to the carrier *) @@ -640,10 +640,10 @@ let ppx_diff var_strength fs parent_ctx fml ctx = let fold_uses_of r t s ~f = let rec fold_ e s ~f = let s = - Iter.fold (Trm.subtrms e) s ~f:(fun sub s -> + Iter.fold (Trm.trms e) s ~f:(fun sub s -> if Trm.equal t sub then f s e else s ) in - if interpreted e then Iter.fold ~f:(fold_ ~f) (Trm.subtrms e) s else s + if interpreted e then Iter.fold ~f:(fold_ ~f) (Trm.trms e) s else s in Subst.fold r.rep s ~f:(fun ~key:trm ~data:rep s -> fold_ ~f trm (fold_ ~f rep s) ) diff --git a/sledge/src/fol/exp.ml b/sledge/src/fol/exp.ml index a2bc3dde5..81f857f11 100644 --- a/sledge/src/fol/exp.ml +++ b/sledge/src/fol/exp.ml @@ -259,20 +259,16 @@ module Term = struct (** Traverse *) - let rec iter_vars_c c ~f = - match c with - | `Ite (cnd, thn, els) -> - Iter.iter ~f (Fml.vars cnd) ; - iter_vars_c ~f thn ; - iter_vars_c ~f els - | `Trm t -> Iter.iter ~f (Trm.vars t) - - let iter_vars e ~f = - match e with - | `Fml p -> Iter.iter ~f (Fml.vars p) - | #cnd as c -> iter_vars_c ~f c + let iter e ~f:iter_t = + let iter_f f = Iter.flat_map ~f:iter_t (Fml.trms f) in + let rec iter_c = function + | `Ite (cnd, thn, els) -> + Iter.(append (iter_f cnd) (append (iter_c thn) (iter_c els))) + | `Trm e -> iter_t e + in + match e with `Fml f -> iter_f f | #cnd as c -> iter_c c - let vars e = Iter.from_labelled_iter (iter_vars e) + let vars = iter ~f:Trm.vars (** Transform *) diff --git a/sledge/src/fol/fml.mli b/sledge/src/fol/fml.mli index d5f85d345..a5c281f90 100644 --- a/sledge/src/fol/fml.mli +++ b/sledge/src/fol/fml.mli @@ -83,6 +83,7 @@ val fold_dnf : -> 'disjunction val vars : t -> Var.t iter +val trms : t -> Trm.t iter (** Query *) diff --git a/sledge/src/fol/trm.ml b/sledge/src/fol/trm.ml index f39d98eb0..29f17f921 100644 --- a/sledge/src/fol/trm.ml +++ b/sledge/src/fol/trm.ml @@ -428,23 +428,16 @@ let map e ~f = (** Traverse *) -let iter_subtrms e ~f = - match e with - | Var _ | Z _ | Q _ -> () - | Arith a -> Iter.iter ~f (Arith.trms a) - | Splat x -> f x - | Sized {seq= x; siz= y} -> - f x ; - f y +let trms = function + | Var _ | Z _ | Q _ -> Iter.empty + | Arith a -> Arith.trms a + | Splat x -> Iter.(cons x empty) + | Sized {seq= x; siz= y} -> Iter.(cons x (cons y empty)) | Extract {seq= x; off= y; len= z} -> - f x ; - f y ; - f z - | Concat xs | Apply (_, xs) -> Array.iter ~f xs - -let subtrms e = Iter.from_labelled_iter (iter_subtrms e) + Iter.(cons x (cons y (cons z empty))) + | Concat xs | Apply (_, xs) -> Iter.of_array xs (** Query *) let fv e = Var.Set.of_iter (vars e) -let rec height e = 1 + Iter.fold ~f:(fun x -> max (height x)) (subtrms e) 0 +let rec height e = 1 + Iter.fold ~f:(fun x -> max (height x)) (trms e) 0 diff --git a/sledge/src/fol/trm.mli b/sledge/src/fol/trm.mli index 1ed83df66..14a0dbe9b 100644 --- a/sledge/src/fol/trm.mli +++ b/sledge/src/fol/trm.mli @@ -94,5 +94,8 @@ val height : t -> int (** Traverse *) +val trms : t -> t iter +(** The immediate subterms of a term. *) + val vars : t -> Var.t iter -val subtrms : t -> t iter +(** The variables that occur in a term. *)