diff --git a/sledge/src/exec.ml b/sledge/src/exec.ml index f0ebec935..f34cce01c 100644 --- a/sledge/src/exec.ml +++ b/sledge/src/exec.ml @@ -44,9 +44,8 @@ let eq_concat (siz, seq) ms = fresh. *) let assign ~ws ~rs ~us = let ovs = Var.Set.inter ws rs in - let sub = Var.Subst.freshen ovs ~wrt:us in - let us = Var.Set.union us (Var.Subst.range sub) in - let ms = Var.Set.diff ws (Var.Subst.domain sub) in + let {Var.Subst.sub; dom; rng= _}, us = Var.Subst.freshen ovs ~wrt:us in + let ms = Var.Set.diff ws dom in (sub, ms, us) (* diff --git a/sledge/src/sh.ml b/sledge/src/sh.ml index 3192cc857..34530aac9 100644 --- a/sledge/src/sh.ml +++ b/sledge/src/sh.ml @@ -332,15 +332,16 @@ let rec apply_subst sub q = |> check (fun q' -> assert (Var.Set.disjoint (fv q') (Var.Subst.domain sub)) ) -and rename sub q = - [%Trace.call fun {pf} -> pf "@[%a@]@ %a" Var.Subst.pp sub pp q] +and rename_ Var.Subst.{sub; dom; rng} q = + [%Trace.call fun {pf} -> + pf "@[%a@]@ %a" Var.Subst.pp sub pp q ; + assert (Var.Set.is_subset dom ~of_:q.us)] ; - let sub = Var.Subst.restrict sub q.us in ( if Var.Subst.is_empty sub then q else - let us = Var.Subst.apply_set sub q.us in + let us = Var.Set.union (Var.Set.diff q.us dom) rng in assert (not (Var.Set.equal us q.us)) ; - let q' = apply_subst sub (freshen_xs q ~wrt:(Var.Set.union q.us us)) in + let q' = apply_subst sub (freshen_xs q ~wrt:(Var.Set.union dom us)) in {q' with us} ) |> [%Trace.retn fun {pf} q' -> @@ -348,16 +349,26 @@ and rename sub q = invariant q' ; assert (Var.Set.disjoint q'.us (Var.Subst.domain sub))] +and rename sub q = + [%Trace.call fun {pf} -> pf "@[%a@]@ %a" Var.Subst.pp sub pp q] + ; + rename_ (Var.Subst.restrict sub q.us) q + |> + [%Trace.retn fun {pf} q' -> + pf "%a" pp q' ; + invariant q' ; + assert (Var.Set.disjoint q'.us (Var.Subst.domain sub))] + (** freshen existentials, preserving vocabulary *) and freshen_xs q ~wrt = [%Trace.call fun {pf} -> pf "{@[%a@]}@ %a" Var.Set.pp wrt pp q ; assert (Var.Set.is_subset q.us ~of_:wrt)] ; - let sub = Var.Subst.freshen q.xs ~wrt in + let Var.Subst.{sub; dom; rng}, _ = Var.Subst.freshen q.xs ~wrt in ( if Var.Subst.is_empty sub then q else - let xs = Var.Subst.apply_set sub q.xs in + let xs = Var.Set.union (Var.Set.diff q.xs dom) rng in let q' = apply_subst sub q in if xs == q.xs && q' == q then q else {q' with xs} ) |> @@ -374,9 +385,9 @@ let extend_us us q = (if us == q.us && q' == q then q else {q' with us}) |> check invariant let freshen ~wrt q = - let sub = Var.Subst.freshen q.us ~wrt:(Var.Set.union wrt q.xs) in - let q' = extend_us wrt (rename sub q) in - (if q' == q then (q, sub) else (q', sub)) + let xsub, _ = Var.Subst.freshen q.us ~wrt:(Var.Set.union wrt q.xs) in + let q' = extend_us wrt (rename_ xsub q) in + (if q' == q then (q, xsub.sub) else (q', xsub.sub)) |> check (fun (q', _) -> invariant q' ; assert (Var.Set.is_subset wrt ~of_:q'.us) ; diff --git a/sledge/src/term.ml b/sledge/src/term.ml index d9e39b5e5..fb02b7408 100644 --- a/sledge/src/term.ml +++ b/sledge/src/term.ml @@ -1072,6 +1072,7 @@ module Var = struct (** Variable renaming substitutions *) module Subst = struct type t = T.t Map.t [@@deriving compare, equal, sexp_of] + type x = {sub: t; dom: Set.t; rng: Set.t} let t_of_sexp = Map.t_of_sexp T.t_of_sexp @@ -1080,6 +1081,7 @@ module Var = struct let domain, range = Map.fold s ~init:(Set.empty, Set.empty) ~f:(fun ~key ~data (domain, range) -> + (* substs are injective *) assert (not (Set.mem range data)) ; (Set.add domain key, Set.add range data) ) in @@ -1090,28 +1092,25 @@ module Var = struct let is_empty = Map.is_empty let freshen vs ~wrt = - let xs = Set.inter wrt vs in - ( if Set.is_empty xs then empty + let dom = Set.inter wrt vs in + ( if Set.is_empty dom then + ({sub= empty; dom= Set.empty; rng= Set.empty}, wrt) else let wrt = Set.union wrt vs in - Set.fold xs ~init:(empty, wrt) ~f:(fun (sub, wrt) x -> - let x', wrt = fresh (name x) ~wrt in - let sub = Map.add_exn sub ~key:x ~data:x' in - (sub, wrt) ) - |> fst ) - |> check invariant + let sub, rng, wrt = + Set.fold dom ~init:(empty, Set.empty, wrt) + ~f:(fun (sub, rng, wrt) x -> + let x', wrt = fresh (name x) ~wrt in + let sub = Map.add_exn sub ~key:x ~data:x' in + let rng = Set.add rng x' in + (sub, rng, wrt) ) + in + ({sub; dom; rng}, wrt) ) + |> check (fun ({sub; _}, _) -> invariant sub) let fold sub ~init ~f = Map.fold sub ~init ~f:(fun ~key ~data s -> f key data s) - let invert sub = - Map.fold sub ~init:empty ~f:(fun ~key ~data sub' -> - Map.add_exn sub' ~key:data ~data:key ) - |> check invariant - - let restrict sub vs = - Map.filter_keys ~f:(Set.mem vs) sub |> check invariant - let domain sub = Map.fold sub ~init:Set.empty ~f:(fun ~key ~data:_ domain -> Set.add domain key ) @@ -1120,18 +1119,28 @@ module Var = struct Map.fold sub ~init:Set.empty ~f:(fun ~key:_ ~data range -> Set.add range data ) - let apply sub v = Map.find sub v |> Option.value ~default:v + let invert sub = + Map.fold sub ~init:empty ~f:(fun ~key ~data sub' -> + Map.add_exn sub' ~key:data ~data:key ) + |> check invariant - let apply_set sub vs = - Map.fold sub ~init:vs ~f:(fun ~key ~data vs -> - let vs' = Set.remove vs key in - if vs' == vs then vs + let restrict sub vs = + Map.fold sub ~init:{sub; dom= Set.empty; rng= Set.empty} + ~f:(fun ~key ~data z -> + if Set.mem vs key then + {z with dom= Set.add z.dom key; rng= Set.add z.rng data} else ( - assert (not (Set.equal vs' vs)) ; - Set.add vs' data ) ) - |> check (fun vs' -> - assert (Set.disjoint (domain sub) vs') ; - assert (Set.is_subset (range sub) ~of_:vs') ) + assert ( + (* all substs are injective, so the current mapping is the + only one that can cause [data] to be in [rng] *) + (not (Set.mem (range (Map.remove sub key)) data)) + || violates invariant sub ) ; + {z with sub= Map.remove z.sub key} ) ) + |> check (fun {sub; dom; rng} -> + assert (Set.equal dom (domain sub)) ; + assert (Set.equal rng (range sub)) ) + + let apply sub v = Map.find sub v |> Option.value ~default:v end end diff --git a/sledge/src/term.mli b/sledge/src/term.mli index cefe113ae..017e85605 100644 --- a/sledge/src/term.mli +++ b/sledge/src/term.mli @@ -139,16 +139,17 @@ module Var : sig module Subst : sig type var := t type t [@@deriving compare, equal, sexp] + type x = {sub: t; dom: Set.t; rng: Set.t} val pp : t pp val empty : t - val freshen : Set.t -> wrt:Set.t -> t + val freshen : Set.t -> wrt:Set.t -> x * Set.t val invert : t -> t - val restrict : t -> Set.t -> t + val restrict : t -> Set.t -> x val is_empty : t -> bool val domain : t -> Set.t val range : t -> Set.t - val apply_set : t -> Set.t -> Set.t + val apply : t -> var -> var val fold : t -> init:'a -> f:(var -> var -> 'a -> 'a) -> 'a end end