diff --git a/sledge/src/fol/context.ml b/sledge/src/fol/context.ml index e08472c1a..195c20835 100644 --- a/sledge/src/fol/context.ml +++ b/sledge/src/fol/context.ml @@ -55,6 +55,7 @@ module Subst : sig val remove : Var.Set.t -> t -> t val map_entries : f:(Trm.t -> Trm.t) -> t -> t val to_iter : t -> (Trm.t * Trm.t) iter + val fv : t -> Var.Set.t val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t end = struct type t = Trm.t Trm.Map.t [@@deriving compare, equal, sexp_of] @@ -76,6 +77,14 @@ end = struct let for_alli = Trm.Map.for_alli let to_iter = Trm.Map.to_iter + let vars s = + s + |> to_iter + |> Iter.flat_map ~f:(fun (k, v) -> + Iter.append (Trm.vars k) (Trm.vars v) ) + + let fv s = Var.Set.of_iter (vars s) + (** look up a term in a substitution *) let apply s a = Trm.Map.find a s |> Option.value ~default:a @@ -332,7 +341,10 @@ and solve_ ?f d e s = let solve ?f ~us ~xs d e = [%Trace.call fun {pf} -> pf "%a@ %a" Trm.pp d Trm.pp e] ; - (solve_ ?f d e (us, xs, Subst.empty) |>= fun (_, xs, s) -> (xs, s)) + ( solve_ ?f d e (us, xs, Subst.empty) + |>= fun (_, xs, s) -> + let xs = Var.Set.inter xs (Subst.fv s) in + (xs, s) ) |> [%Trace.retn fun {pf} -> function @@ -795,7 +807,7 @@ let solve_seq_eq us e' f' subst = | None -> (Trm.sized ~siz:n ~seq:a, n) in let+ _, xs, s = solve_concat ~f ms a n (us, Var.Set.empty, subst) in - assert (Var.Set.is_empty xs) ; + assert (Var.Set.disjoint xs (Subst.fv s)) ; s in ( match ((e' : Trm.t), (f' : Trm.t)) with