diff --git a/sledge/lib/equality.ml b/sledge/lib/equality.ml index fa3d58546..b4918b276 100644 --- a/sledge/lib/equality.ml +++ b/sledge/lib/equality.ml @@ -89,10 +89,20 @@ end = struct (** compose two substitutions *) let compose r s = - let r' = Term.Map.map ~f:(norm s) r in - Term.Map.merge_skewed r' s ~combine:(fun ~key v1 v2 -> - if Term.equal v1 v2 then v1 - else fail "domains intersect: %a" Term.pp key () ) + [%Trace.call fun {pf} -> pf "%a@ %a" pp r pp s] + ; + let r' = Term.Map.map_endo ~f:(norm s) r in + Term.Map.merge_endo r' s ~f:(fun ~key -> function + | `Both (data_r, data_s) -> + assert ( + Term.equal data_s data_r + || fail "domains intersect: %a" Term.pp key () ) ; + Some data_r + | `Left data | `Right data -> Some data ) + |> + [%Trace.retn fun {pf} r' -> + pf "%a" pp_diff (r, r') ; + assert (r' != r ==> not (equal r' r))] (** compose a substitution with a mapping *) let compose1 ~key ~data s = diff --git a/sledge/lib/import/map.ml b/sledge/lib/import/map.ml index 65117e08c..bead4fd89 100644 --- a/sledge/lib/import/map.ml +++ b/sledge/lib/import/map.ml @@ -30,6 +30,19 @@ end) : S with type key = Key.t = struct let map_endo t ~f = map_endo map t ~f + let merge_endo t u ~f = + let change = ref false in + let t' = + merge t u ~f:(fun ~key side -> + let f_side = f ~key side in + ( match (side, f_side) with + | (`Both (data, _) | `Left data), Some data' when data' == data -> + () + | _ -> change := true ) ; + f_side ) + in + if !change then t' else t + let fold_until m ~init ~f ~finish = let fold m ~init ~f = let f ~key ~data s = f s (key, data) in diff --git a/sledge/lib/import/map_intf.ml b/sledge/lib/import/map_intf.ml index e2166afe7..747de3d8d 100644 --- a/sledge/lib/import/map_intf.ml +++ b/sledge/lib/import/map_intf.ml @@ -24,6 +24,18 @@ module type S = sig (** Like map, but specialized to require [f] to be an endofunction, which enables preserving [==] if [f] preserves [==] of every element. *) + val merge_endo : + 'a t + -> 'b t + -> f: + ( key:key + -> [`Both of 'a * 'b | `Left of 'a | `Right of 'b] + -> 'a option) + -> 'a t + (** Like merge, but specialized to require [f] to preserve the type of the + left argument, which enables preserving [==] if [f] preserves [==] of + every element. *) + val merge_skewed : 'a t -> 'a t -> combine:(key:key -> 'a -> 'a -> 'a) -> 'a t