(* * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. *) (** Equality over uninterpreted functions and linear rational arithmetic *) open Exp (** Classification of Terms by Theory *) type kind = Interpreted | Atomic | Uninterpreted [@@deriving compare, equal] let classify e = match (e : Trm.t) with | Var _ | Z _ | Q _ | Concat [||] | Apply (_, [||]) -> Atomic | Arith a -> ( match Trm.Arith.classify a with | Trm _ | Const _ -> violates Trm.invariant e | Interpreted -> Interpreted | Uninterpreted -> Uninterpreted ) | Splat _ | Sized _ | Extract _ | Concat _ -> Interpreted | Apply _ -> Uninterpreted let is_interpreted e = equal_kind (classify e) Interpreted let is_uninterpreted e = equal_kind (classify e) Uninterpreted let rec max_solvables e = if not (is_interpreted e) then Iter.return e else Iter.flat_map ~f:max_solvables (Trm.trms e) let fold_max_solvables e s ~f = Iter.fold ~f (max_solvables e) s (** Solution Substitutions *) module Subst : sig type t [@@deriving compare, equal, sexp] val pp : t pp val pp_diff : (t * t) pp val empty : t val is_empty : t -> bool val length : t -> int val mem : Trm.t -> t -> bool val find : Trm.t -> t -> Trm.t option val fold : t -> 's -> f:(key:Trm.t -> data:Trm.t -> 's -> 's) -> 's val fold_eqs : t -> 's -> f:(Fml.t -> 's -> 's) -> 's val iteri : t -> f:(key:Trm.t -> data:Trm.t -> unit) -> unit val for_alli : t -> f:(key:Trm.t -> data:Trm.t -> bool) -> bool val apply : t -> Trm.t -> Trm.t val subst_ : t -> Trm.t -> Trm.t val subst : t -> Term.t -> Term.t val norm : t -> Trm.t -> Trm.t val compose : t -> t -> t val compose1 : key:Trm.t -> data:Trm.t -> t -> t val extend : Trm.t -> t -> t option val map_entries : f:(Trm.t -> Trm.t) -> t -> t val to_iter : t -> (Trm.t * Trm.t) iter val fv : t -> Var.Set.t val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t (* direct representation manipulation *) val add : key:Trm.t -> data:Trm.t -> t -> t val remove : Trm.t -> t -> t end = struct type t = Trm.t Trm.Map.t [@@deriving compare, equal, sexp_of] let t_of_sexp = Trm.Map.t_of_sexp Trm.t_of_sexp let pp = Trm.Map.pp Trm.pp Trm.pp let pp_diff = Trm.Map.pp_diff ~eq:Trm.equal Trm.pp Trm.pp Trm.pp_diff let empty = Trm.Map.empty let is_empty = Trm.Map.is_empty let length = Trm.Map.length let mem = Trm.Map.mem let find = Trm.Map.find let fold = Trm.Map.fold let fold_eqs s z ~f = Trm.Map.fold ~f:(fun ~key ~data -> f (Fml.eq key data)) s z let iteri = Trm.Map.iteri let for_alli = Trm.Map.for_alli let to_iter = Trm.Map.to_iter let vars s = s |> to_iter |> Iter.flat_map ~f:(fun (k, v) -> Iter.append (Trm.vars k) (Trm.vars v) ) let fv s = Var.Set.of_iter (vars s) (** look up a term in a substitution *) let apply s a = Trm.Map.find a s |> Option.value ~default:a let rec subst_ s a = apply s (Trm.map ~f:(subst_ s) a) let subst s e = Term.map_trms ~f:(subst_ s) e (** apply a substitution to maximal non-interpreted subterms *) let rec norm s a = [%trace] ~call:(fun {pf} -> pf "@ %a" Trm.pp a) ~retn:(fun {pf} -> pf "%a" Trm.pp) @@ fun () -> if is_interpreted a then Trm.map ~f:(norm s) a else apply s a (** compose two substitutions *) let compose r s = [%Trace.call fun {pf} -> pf "@ %a@ %a" pp r pp s] ; ( if is_empty s then r else let r' = Trm.Map.map_endo ~f:(norm s) r in Trm.Map.union_absent r' s ) |> [%Trace.retn fun {pf} r' -> pf "%a" pp_diff (r, r') ; assert (r' == r || not (equal r' r))] (** compose a substitution with a mapping *) let compose1 ~key ~data r = match (key : Trm.t) with | Z _ | Q _ -> r | _ when Trm.equal key data -> r | _ -> assert ( (not (Trm.Map.mem key r)) || fail "domains intersect: %a" Trm.pp key () ) ; let s = Trm.Map.singleton key data in let r' = Trm.Map.map_endo ~f:(norm s) r in Trm.Map.add ~key ~data r' (** add an identity entry if the term is not already present *) let extend e s = let exception Found in match Trm.Map.update e s ~f:(function | Some _ -> raise_notrace Found | None -> e ) with | exception Found -> None | s -> Some s (** map over a subst, applying [f] to both domain and range, requires that [f] is injective and for any set of terms [E], [f\[E\]] is disjoint from [E] *) let map_entries ~f s = Trm.Map.fold s s ~f:(fun ~key ~data s -> let key' = f key in let data' = f data in if Trm.equal key' key then if Trm.equal data' data then s else Trm.Map.add ~key ~data:data' s else let s = Trm.Map.remove key s in match (key : Trm.t) with | Z _ | Q _ -> s | _ -> Trm.Map.add_exn ~key:key' ~data:data' s ) (** Holds only if [true ⊢ ∃xs. e=f]. Clients assume [not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for [ys ⊆ xs]. *) let is_valid_eq xs e f = let is_var_in xs e = Option.exists ~f:(fun x -> Var.Set.mem x xs) (Var.of_trm e) in ( is_var_in xs e || is_var_in xs f || (is_uninterpreted e && Iter.exists ~f:(is_var_in xs) (Trm.trms e)) || (is_uninterpreted f && Iter.exists ~f:(is_var_in xs) (Trm.trms f)) ) $> fun b -> [%Trace.info "is_valid_eq %a%a=%a = %b" Var.Set.pp_xs xs Trm.pp e Trm.pp f b] (** Partition ∃xs. σ into equivalent ∃xs. τ ∧ ∃ks. ν where ks and ν are maximal where ∃ks. ν is universally valid, xs ⊇ ks and ks ∩ fv(τ) = ∅. *) let partition_valid xs s = [%trace] ~call:(fun {pf} -> pf "@ @[%a@ %a@]" Var.Set.pp_xs xs pp s) ~retn:(fun {pf} (t, ks, u) -> pf "%a@ %a@ %a" pp t Var.Set.pp_xs ks pp u ) @@ fun () -> (* Move equations e=f from s to t when ∃ks.e=f fails to be provably valid. When moving an equation, reduce ks by fv(e=f) to maintain ks ∩ fv(t) = ∅. This reduction may cause equations in s to no longer be valid, so loop until no change. *) let rec partition_valid_ t ks s = let t', ks', s' = Trm.Map.fold s (t, ks, s) ~f:(fun ~key ~data (t, ks, s) -> if is_valid_eq ks key data then (t, ks, s) else let t = Trm.Map.add ~key ~data t and ks = Var.Set.diff ks (Var.Set.union (Trm.fv key) (Trm.fv data)) and s = Trm.Map.remove key s in (t, ks, s) ) in if s' != s then partition_valid_ t' ks' s' else (t', ks', s') in if Var.Set.is_empty xs then (s, Var.Set.empty, empty) else partition_valid_ empty xs s (* direct representation manipulation *) let add = Trm.Map.add let remove = Trm.Map.remove end (** Theory Solver *) (** prefer representative terms that are minimal in the order s.t. Var < Sized < Extract < Concat < others, then using height of sequence nesting, and then using Trm.compare *) let prefer e f = let rank e = match (e : Trm.t) with | Var _ -> 0 | Sized _ -> 1 | Extract _ -> 2 | Concat _ -> 3 | _ -> 4 in let o = compare (rank e) (rank f) in if o <> 0 then o else let o = compare (Trm.height e) (Trm.height f) in if o <> 0 then o else Trm.compare e f (** orient equations based on representative preference *) let orient e f = match Sign.of_int (prefer e f) with | Neg -> Some (e, f) | Zero -> None | Pos -> Some (f, e) let norm (_, _, s) e = Subst.norm s e let compose1 ?f ~var ~rep (us, xs, s) = let s = match f with | Some f when not (f var rep) -> s | _ -> Subst.compose1 ~key:var ~data:rep s in Some (us, xs, s) let fresh name (wrt, xs, s) = let x, wrt = Var.fresh name ~wrt in let xs = Var.Set.add x xs in (Trm.var x, (wrt, xs, s)) let solve_poly ?f p q s = [%trace] ~call:(fun {pf} -> pf "@ %a = %a" Trm.pp p Trm.pp q) ~retn:(fun {pf} -> function | Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s | None -> pf "false" ) @@ fun () -> match Trm.sub p q with | Z z -> if Z.equal Z.zero z then Some s else None | Var _ as var -> compose1 ?f ~var ~rep:Trm.zero s | p_q -> ( match Trm.Arith.solve_zero_eq p_q with | Some (var, rep) -> compose1 ?f ~var:(Trm.arith var) ~rep:(Trm.arith rep) s | None -> compose1 ?f ~var:p_q ~rep:Trm.zero s ) (* α[o,l) = β ==> l = |β| ∧ α = (⟨n,c⟩[0,o) ^ β ^ ⟨n,c⟩[o+l,n-o-l)) where n = |α| and c fresh *) let rec solve_extract ?f a o l b s = [%trace] ~call:(fun {pf} -> pf "@ %a = %a@ %a%a" Trm.pp (Trm.extract ~seq:a ~off:o ~len:l) Trm.pp b Var.Set.pp_xs (snd3 s) Subst.pp (trd3 s) ) ~retn:(fun {pf} -> function | Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s | None -> pf "false" ) @@ fun () -> let n = Trm.seq_size_exn a in let c, s = fresh "c" s in let n_c = Trm.sized ~siz:n ~seq:c in let o_l = Trm.add o l in let n_o_l = Trm.sub n o_l in let c0 = Trm.extract ~seq:n_c ~off:Trm.zero ~len:o in let c1 = Trm.extract ~seq:n_c ~off:o_l ~len:n_o_l in let b, s = match Trm.seq_size b with | None -> (Trm.sized ~siz:l ~seq:b, Some s) | Some m -> (b, solve_ ?f l m s) in s >>= solve_ ?f a (Trm.concat [|c0; b; c1|]) (* α₀^…^αᵢ^αⱼ^…^αᵥ = β ==> |α₀^…^αᵥ| = |β| ∧ … ∧ αⱼ = β[n₀+…+nᵢ,nⱼ) ∧ … where nₓ ≡ |αₓ| and m = |β| *) and solve_concat ?f a0V b m s = [%trace] ~call:(fun {pf} -> pf "@ %a = %a@ %a%a" Trm.pp (Trm.concat a0V) Trm.pp b Var.Set.pp_xs (snd3 s) Subst.pp (trd3 s) ) ~retn:(fun {pf} -> function | Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s | None -> pf "false" ) @@ fun () -> Iter.fold_until (Array.to_iter a0V) (s, Trm.zero) ~f:(fun aJ (s, oI) -> let nJ = Trm.seq_size_exn aJ in let oJ = Trm.add oI nJ in match solve_ ?f aJ (Trm.extract ~seq:b ~off:oI ~len:nJ) s with | Some s -> `Continue (s, oJ) | None -> `Stop None ) ~finish:(fun (s, n0V) -> solve_ ?f n0V m s) and solve_ ?f d e s = [%Trace.call fun {pf} -> pf "@ %a@[%a@ %a@ %a@]" Var.Set.pp_xs (snd3 s) Trm.pp d Trm.pp e Subst.pp (trd3 s)] ; ( match orient (norm s d) (norm s e) with (* e' = f' ==> true when e' ≡ f' *) | None -> Some s (* i = j ==> false when i ≠ j *) | Some (Z _, Z _) | Some (Q _, Q _) -> None (* * Concat *) (* ⟨0,a⟩ = β ==> a = β = ⟨⟩ *) | Some (Sized {siz= n; seq= a}, b) when n == Trm.zero -> s |> solve_ ?f a (Trm.concat [||]) >>= solve_ ?f b (Trm.concat [||]) | Some (b, Sized {siz= n; seq= a}) when n == Trm.zero -> s |> solve_ ?f a (Trm.concat [||]) >>= solve_ ?f b (Trm.concat [||]) (* ⟨n,0⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,0⟩[n₀+…+nᵢ,nⱼ) ∧ … *) | Some ((Sized {siz= n; seq} as b), Concat a0V) when seq == Trm.zero -> solve_concat ?f a0V b n s (* ⟨n,e^⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,e^⟩[n₀+…+nᵢ,nⱼ) ∧ … *) | Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) -> solve_concat ?f a0V b n s | Some ((Var _ as v), (Concat a0V as c)) -> if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then (* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *) compose1 ?f ~var:v ~rep:c s else (* v = α₀^…^αᵥ ==> ⟨|α₀^…^αᵥ|,v⟩ = α₀^…^αᵥ when v ∈ fv(α₀^…^αᵥ) *) let m = Trm.seq_size_exn c in solve_concat ?f a0V (Trm.sized ~siz:m ~seq:v) m s (* α₀^…^αᵥ = β₀^…^βᵤ ==> … ∧ αⱼ = (β₀^…^βᵤ)[n₀+…+nᵢ,nⱼ) ∧ … *) | Some (Concat a0V, (Concat _ as c)) -> solve_concat ?f a0V c (Trm.seq_size_exn c) s (* α[o,l) = α₀^…^αᵥ ==> … ∧ αⱼ = α[o,l)[n₀+…+nᵢ,nⱼ) ∧ … *) | Some ((Extract {len= l} as e), Concat a0V) -> solve_concat ?f a0V e l s (* * Extract *) | Some ((Var _ as v), (Extract {len= l} as e)) -> if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then (* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *) compose1 ?f ~var:v ~rep:e s else (* v = α[o,l) ==> α[o,l) ↦ ⟨l,v⟩ when v ∈ fv(α[o,l)) *) compose1 ?f ~var:e ~rep:(Trm.sized ~siz:l ~seq:v) s (* α[o,l) = β ==> … ∧ α = _^β^_ *) | Some (Extract {seq= a; off= o; len= l}, e) -> solve_extract ?f a o l e s (* * Sized *) (* v = ⟨n,a⟩ ==> v = a *) | Some ((Var _ as v), Sized {seq= a}) -> s |> solve_ ?f v a (* ⟨n,a⟩ = ⟨m,b⟩ ==> n = m ∧ a = β *) | Some (Sized {siz= n; seq= a}, Sized {siz= m; seq= b}) -> s |> solve_ ?f n m >>= solve_ ?f a b (* ⟨n,a⟩ = β ==> n = |β| ∧ a = β *) | Some (Sized {siz= n; seq= a}, b) -> ( match Trm.seq_size b with | None -> Some s | Some m -> solve_ ?f n m s ) >>= solve_ ?f a b (* * Splat *) (* a^ = b^ ==> a = b *) | Some (Splat a, Splat b) -> s |> solve_ ?f a b (* * Arithmetic *) (* p = q ==> p-q = 0 *) | Some (((Arith _ | Z _ | Q _) as p), q | q, ((Arith _ | Z _ | Q _) as p)) -> solve_poly ?f p q s (* * Uninterpreted *) (* r = v ==> v ↦ r *) | Some (rep, var) -> assert (not (is_interpreted var)) ; assert (not (is_interpreted rep)) ; compose1 ?f ~var ~rep s ) |> [%Trace.retn fun {pf} -> function | Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s | None -> pf "false"] let solve ?f ~wrt ~xs d e = [%Trace.call fun {pf} -> pf "@ %a@ %a" Trm.pp d Trm.pp e] ; ( solve_ ?f d e (wrt, xs, Subst.empty) |>= fun (_, xs, s) -> let xs = Var.Set.inter xs (Subst.fv s) in (xs, s) ) |> [%Trace.retn fun {pf} -> function | Some (xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s | None -> pf "false"] (** Equality Relations *) (** see also [invariant] *) type t = { xs: Var.Set.t (** existential variables that did not appear in input equations *) ; sat: bool (** [false] only if constraints are inconsistent *) ; rep: Subst.t (** functional set of oriented equations: map [a] to [a'], indicating that [a = a'] holds, and that [a'] is the 'rep(resentative)' of [a] *) } [@@deriving compare, equal, sexp] let classes r = let add elt rep cls = if Trm.equal elt rep then cls else Trm.Map.add_multi ~key:rep ~data:elt cls in Subst.fold r.rep Trm.Map.empty ~f:(fun ~key:elt ~data:rep cls -> match classify elt with | Interpreted | Atomic -> add elt rep cls | Uninterpreted -> add (Trm.map ~f:(Subst.apply r.rep) elt) rep cls ) let cls_of r e = let e' = Subst.apply r.rep e in Trm.Map.find e' (classes r) |> Option.value ~default:[e'] (** Pretty-printing *) let pp_raw fs {sat; rep} = let pp_alist pp_k pp_v fs alist = let pp_assoc fs (k, v) = Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_k k pp_v (k, v) in Format.fprintf fs "[@[%a@]]" (List.pp ";@ " pp_assoc) alist in let pp_term_v fs (k, v) = if not (Trm.equal k v) then Trm.pp fs v in Format.fprintf fs "@[{@[sat= %b;@ rep= %a@]}@]" sat (pp_alist Trm.pp pp_term_v) (Iter.to_list (Subst.to_iter rep)) let pp_diff fs (r, s) = let pp_sat fs = if not (Bool.equal r.sat s.sat) then Format.fprintf fs "sat= @[-- %b@ ++ %b@];@ " r.sat s.sat in let pp_rep fs = if not (Subst.is_empty r.rep) then Format.fprintf fs "rep= %a" Subst.pp_diff (r.rep, s.rep) in Format.fprintf fs "@[{@[%t%t@]}@]" pp_sat pp_rep let ppx_cls x = List.pp "@ = " (Trm.ppx x) let pp_cls = ppx_cls (fun _ -> None) let pp_diff_cls = List.pp_diff ~cmp:Trm.compare "@ = " Trm.pp let ppx_classes x fs clss = List.pp "@ @<2>∧ " (fun fs (rep, cls) -> if not (List.is_empty cls) then Format.fprintf fs "@[%a@ = %a@]" (Trm.ppx x) rep (ppx_cls x) cls ) fs (Iter.to_list (Trm.Map.to_iter clss)) let pp_classes fs r = ppx_classes (fun _ -> None) fs (classes r) let pp_diff_clss = Trm.Map.pp_diff ~eq:(List.equal Trm.equal) Trm.pp pp_cls pp_diff_cls let pp fs r = let clss = classes r in if Trm.Map.is_empty clss then Format.fprintf fs (if r.sat then "tt" else "ff") else ppx_classes (fun _ -> None) fs clss let ppx var_strength fs clss noneqs = let without_anon_vars = List.filter ~f:(fun e -> match Var.of_trm e with | Some v -> Poly.(var_strength v <> Some `Anonymous) | None -> true ) in let clss = Trm.Map.fold clss Trm.Map.empty ~f:(fun ~key:rep ~data:cls m -> let cls = without_anon_vars cls in if not (List.is_empty cls) then Trm.Map.add ~key:rep ~data:(List.sort ~cmp:Trm.compare cls) m else m ) in let first = Trm.Map.is_empty clss in if not first then Format.fprintf fs " " ; ppx_classes var_strength fs clss ; List.pp ~pre:(if first then "@[ " else "@ @[@<2>∧ ") "@ @<2>∧ " (Fml.ppx var_strength) fs noneqs ~suf:"@]" ; first && List.is_empty noneqs (** Basic queries *) (** test membership in carrier *) let in_car r e = Subst.mem e r.rep (** congruent specialized to assume subterms of [a'] are [Subst.norm]alized wrt [r] (or canonized) *) let semi_congruent r a' b = Trm.equal a' (Trm.map ~f:(Subst.norm r.rep) b) (** terms are congruent if equal after normalizing subterms *) let congruent r a b = semi_congruent r (Trm.map ~f:(Subst.norm r.rep) a) b (** Invariant *) let pre_invariant r = let@ () = Invariant.invariant [%here] r [%sexp_of: t] in Subst.iteri r.rep ~f:(fun ~key:trm ~data:_ -> (* no interpreted terms in carrier *) assert ( (not (is_interpreted trm)) || fail "non-interp %a" Trm.pp trm () ) ; (* carrier is closed under subterms *) Iter.iter (Trm.trms trm) ~f:(fun subtrm -> assert ( is_interpreted subtrm || (match subtrm with Z _ | Q _ -> true | _ -> false) || in_car r subtrm || fail "@[subterm %a@ of %a@ not in carrier of@ %a@]" Trm.pp subtrm Trm.pp trm pp r () ) ) ) let invariant r = let@ () = Invariant.invariant [%here] r [%sexp_of: t] in pre_invariant r ; assert ( (not r.sat) || Subst.for_alli r.rep ~f:(fun ~key:a ~data:a' -> Subst.for_alli r.rep ~f:(fun ~key:b ~data:b' -> Trm.compare a b >= 0 || (not (congruent r a b)) || Trm.equal a' b' || fail "not congruent %a@ %a@ in@ %a" Trm.pp a Trm.pp b pp r () ) ) ) (** Core operations *) let empty = let rep = Subst.empty in (* let rep = Option.get_exn (Subst.extend Trm.true_ rep) in * let rep = Option.get_exn (Subst.extend Trm.false_ rep) in *) {xs= Var.Set.empty; sat= true; rep} |> check invariant let unsat = {empty with sat= false} (** [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep *) let lookup r a = ([%Trace.call fun {pf} -> pf "@ %a" Trm.pp a] ; Iter.find_map (Subst.to_iter r.rep) ~f:(fun (b, b') -> Option.return_if (semi_congruent r a b) b' ) |> Option.value ~default:a) |> [%Trace.retn fun {pf} -> pf "%a" Trm.pp] (** rewrite a term into canonical form using rep and, for non-interpreted terms, congruence composed with rep *) let rec canon r a = [%Trace.call fun {pf} -> pf "@ %a" Trm.pp a] ; ( match classify a with | Atomic -> Subst.apply r.rep a | Interpreted -> Trm.map ~f:(canon r) a | Uninterpreted -> ( let a' = Trm.map ~f:(canon r) a in match classify a' with | Atomic -> Subst.apply r.rep a' | Interpreted -> a' | Uninterpreted -> lookup r a' ) ) |> [%Trace.retn fun {pf} -> pf "%a" Trm.pp] let canon_f r b = [%trace] ~call:(fun {pf} -> pf "@ %a@ %a" Fml.pp b pp_raw r) ~retn:(fun {pf} -> pf "%a" Fml.pp) @@ fun () -> Fml.map_trms ~f:(canon r) b let rec extend_ a r = match (a : Trm.t) with | Z _ | Q _ -> r | _ -> ( if is_interpreted a then Iter.fold ~f:extend_ (Trm.trms a) r else (* add uninterpreted terms *) match Subst.extend a r with (* and their subterms if newly added *) | Some r -> Iter.fold ~f:extend_ (Trm.trms a) r | None -> r ) (** add a term to the carrier *) let extend a r = let rep = extend_ a r.rep in if rep == r.rep then r else {r with rep} |> check pre_invariant let merge ~wrt a b r = [%Trace.call fun {pf} -> pf "@ %a@ %a@ %a" Trm.pp a Trm.pp b pp r] ; ( match solve ~wrt ~xs:r.xs a b with | Some (xs, s) -> {r with xs= Var.Set.union r.xs xs; rep= Subst.compose r.rep s} | None -> {r with sat= false} ) |> [%Trace.retn fun {pf} r' -> pf "%a" pp_diff (r, r') ; pre_invariant r'] (** find an unproved equation between congruent terms *) let find_missing r = Iter.find_map (Subst.to_iter r.rep) ~f:(fun (a, a') -> let a_subnorm = Trm.map ~f:(Subst.norm r.rep) a in Iter.find_map (Subst.to_iter r.rep) ~f:(fun (b, b') -> (* need to equate a' and b'? *) let need_a'_eq_b' = (* optimize: do not consider both a = b and b = a *) Trm.compare a b < 0 (* a and b are not already equal *) && (not (Trm.equal a' b')) (* a and b are congruent *) && semi_congruent r a_subnorm b in Option.return_if need_a'_eq_b' (a', b') ) ) let rec close ~wrt r = if not r.sat then r else match find_missing r with | Some (a', b') -> close ~wrt (merge ~wrt a' b' r) | None -> r let close ~wrt r = [%Trace.call fun {pf} -> pf "@ %a" pp r] ; close ~wrt r |> [%Trace.retn fun {pf} r' -> pf "%a" pp_diff (r, r') ; invariant r'] let and_eq_ ~wrt a b r = if not r.sat then r else let r0 = r in let a' = canon r a in let b' = canon r b in let r = extend a' r in let r = extend b' r in if Trm.equal a' b' then r else let r = merge ~wrt a' b' r in match (a, b) with | (Var _ as v), _ when not (in_car r0 v) -> r | _, (Var _ as v) when not (in_car r0 v) -> r | _ -> close ~wrt r let extract_xs r = (r.xs, {r with xs= Var.Set.empty}) (** Exposed interface *) let is_empty {sat; rep} = sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Trm.equal a a') let is_unsat {sat} = not sat let implies r b = [%Trace.call fun {pf} -> pf "@ %a@ %a" Fml.pp b pp r] ; Fml.equal Fml.tt (canon_f r b) |> [%Trace.retn fun {pf} -> pf "%b"] let refutes r b = Fml.equal Fml.ff (canon_f r b) let normalize r e = Term.map_trms ~f:(canon r) e let class_of r e = match Term.get_trm (normalize r e) with | Some e' -> List.map ~f:Term.of_trm (e' :: Trm.Map.find_multi e' (classes r)) | None -> [] let diff_classes r s = Trm.Map.filter_mapi (classes r) ~f:(fun ~key:rep ~data:cls -> match List.filter cls ~f:(fun exp -> not (implies s (Fml.eq rep exp))) with | [] -> None | cls -> Some cls ) let ppx_diff var_strength fs parent_ctx fml ctx = let fml' = canon_f ctx fml in ppx var_strength fs (diff_classes ctx parent_ctx) (if Fml.(equal tt fml') then [] else [fml']) let fold_uses_of r t s ~f = let rec fold_ e s ~f = let s = Iter.fold (Trm.trms e) s ~f:(fun sub s -> fold_ ~f sub (if Trm.equal t sub then f e s else s) ) in if is_interpreted e then Iter.fold ~f:(fold_ ~f) (Trm.trms e) s else s in Subst.fold r.rep s ~f:(fun ~key:trm ~data:rep s -> fold_ ~f trm (fold_ ~f rep s) ) let iter_uses_of t r ~f = fold_uses_of r t () ~f:(fun use () -> f use) let uses_of t r = Iter.from_labelled_iter (iter_uses_of t r) let apply_subst wrt s r = [%Trace.call fun {pf} -> pf "@ %a@ %a" Subst.pp s pp r] ; ( if Subst.is_empty s then r else Trm.Map.fold (classes r) {r with rep= Subst.empty} ~f:(fun ~key:rep ~data:cls r -> let rep' = Subst.subst_ s rep in List.fold cls r ~f:(fun trm r -> let trm' = Subst.subst_ s trm in and_eq_ ~wrt trm' rep' r ) ) ) |> extract_xs |> [%Trace.retn fun {pf} (xs, r') -> pf "%a%a" Var.Set.pp_xs xs pp_diff (r, r') ; invariant r'] let union wrt r s = [%Trace.call fun {pf} -> pf "@ @[ %a@ @<2>∧ %a@]" pp r pp s] ; ( if not r.sat then r else if not s.sat then s else let s, r = if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s) in Subst.fold s.rep r ~f:(fun ~key:e ~data:e' r -> and_eq_ ~wrt e e' r) ) |> extract_xs |> [%Trace.retn fun {pf} (_, r') -> pf "%a" pp_diff (r, r') ; invariant r'] let inter wrt r s = [%Trace.call fun {pf} -> pf "@ @[ %a@ @<2>∨ %a@]" pp r pp s] ; ( if not s.sat then r else if not r.sat then s else let merge_mems rs r s = Trm.Map.fold (classes s) rs ~f:(fun ~key:rep ~data:cls rs -> List.fold cls ([rep], rs) ~f:(fun exp (reps, rs) -> match List.find ~f:(fun rep -> implies r (Fml.eq exp rep)) reps with | Some rep -> (reps, and_eq_ ~wrt exp rep rs) | None -> (exp :: reps, rs) ) |> snd ) in let rs = empty in let rs = merge_mems rs r s in let rs = merge_mems rs s r in rs ) |> extract_xs |> [%Trace.retn fun {pf} (_, r') -> pf "%a" pp_diff (r, r') ; invariant r'] let interN us rs = match rs with | [] -> (us, unsat) | r :: rs -> List.fold ~f:(fun r (us, s) -> inter us s r) rs (us, r) let rec add_ wrt b r = match (b : Fml.t) with | Tt -> r | Not Tt -> unsat | And {pos; neg} -> Fml.fold_pos_neg ~f:(add_ wrt) ~pos ~neg r | Eq (d, e) -> and_eq_ ~wrt d e r | Eq0 e -> and_eq_ ~wrt Trm.zero e r | Pos _ | Not _ | Or _ | Iff _ | Cond _ | Lit _ -> r let add us b r = [%Trace.call fun {pf} -> pf "@ %a@ %a" Fml.pp b pp r] ; add_ us b r |> extract_xs |> [%Trace.retn fun {pf} (_, r') -> pf "%a" pp_diff (r, r') ; invariant r'] let dnf f = let meet1 a (vs, p, x) = let vs, x = add vs a x in (vs, Fml.and_ p a, x) in let join1 = Iter.cons in let top = (Var.Set.empty, Fml.tt, empty) in let bot = Iter.empty in Fml.fold_dnf ~meet1 ~join1 ~top ~bot f let rename r sub = [%Trace.call fun {pf} -> pf "@ @[%a@]@ %a" Var.Subst.pp sub pp r] ; let rep = Subst.map_entries ~f:(Trm.map_vars ~f:(Var.Subst.apply sub)) r.rep in (if rep == r.rep then r else {r with rep}) |> [%Trace.retn fun {pf} r' -> pf "%a" pp_diff (r, r') ; invariant r'] let trms r = Iter.flat_map ~f:(fun (k, v) -> Iter.doubleton k v) (Subst.to_iter r.rep) let vars r = Iter.flat_map ~f:Trm.vars (trms r) let fv r = Var.Set.of_iter (vars r) (** Existential Witnessing and Elimination *) let subst_invariant us s0 s = assert (s0 == s || not (Subst.equal s0 s)) ; assert ( Subst.iteri s ~f:(fun ~key ~data -> (* dom of new entries not ito us *) assert ( Option.for_all ~f:(Trm.equal data) (Subst.find key s0) || not (Var.Set.subset (Trm.fv key) ~of_:us) ) ; (* rep not ito us implies trm not ito us *) assert ( Var.Set.subset (Trm.fv data) ~of_:us || not (Var.Set.subset (Trm.fv key) ~of_:us) ) ) ; true ) type 'a zom = Zero | One of 'a | Many (** try to solve [p = q] such that [fv (p - q) ⊆ us ∪ xs] and [p - q] has at most one maximal solvable subterm, [kill], where [fv kill ⊈ us]; solve [p = q] for [kill]; extend subst mapping [kill] to the solution *) let solve_poly_eq us p' q' subst = [%Trace.call fun {pf} -> pf "@ %a = %a" Trm.pp p' Trm.pp q'] ; let diff = Trm.sub p' q' in let max_solvables_not_ito_us = fold_max_solvables diff Zero ~f:(fun solvable_subterm -> function | Many -> Many | zom when Var.Set.subset (Trm.fv solvable_subterm) ~of_:us -> zom | One _ -> Many | Zero -> One solvable_subterm ) in ( match max_solvables_not_ito_us with | One kill -> let+ kill, keep = Trm.Arith.solve_zero_eq diff ~for_:kill in Subst.compose1 ~key:(Trm.arith kill) ~data:(Trm.arith keep) subst | Many | Zero -> None ) |> [%Trace.retn fun {pf} subst' -> pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)] let solve_seq_eq ~wrt us e' f' subst = [%Trace.call fun {pf} -> pf "@ %a = %a" Trm.pp e' Trm.pp f'] ; let f x u = (not (Var.Set.subset (Trm.fv x) ~of_:us)) && Var.Set.subset (Trm.fv u) ~of_:us in let solve_concat ms n a = let a, n = match Trm.seq_size a with | Some n -> (a, n) | None -> (Trm.sized ~siz:n ~seq:a, n) in let+ _, xs, s = solve_concat ~f ms a n (wrt, Var.Set.empty, subst) in assert (Var.Set.disjoint xs (Subst.fv s)) ; s in ( match ((e' : Trm.t), (f' : Trm.t)) with | (Concat ms as c), a when f c a -> solve_concat ms (Trm.seq_size_exn c) a | a, (Concat ms as c) when f c a -> solve_concat ms (Trm.seq_size_exn c) a | (Sized {seq= Var _ as v} as m), u when f m u -> Some (Subst.compose1 ~key:v ~data:u subst) | u, (Sized {seq= Var _ as v} as m) when f m u -> Some (Subst.compose1 ~key:v ~data:u subst) | _ -> None ) |> [%Trace.retn fun {pf} subst' -> pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)] let solve_interp_eq ~wrt us e' (cls, subst) = [%Trace.call fun {pf} -> pf "@ trm: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp e' pp_cls cls Subst.pp subst] ; List.find_map cls ~f:(fun f -> let f' = Subst.norm subst f in match solve_seq_eq ~wrt us e' f' subst with | Some subst -> Some subst | None -> solve_poly_eq us e' f' subst ) |> [%Trace.retn fun {pf} subst' -> pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst) ; Option.iter ~f:(subst_invariant us subst) subst'] (** move equations from [cls] to [subst] which are between interpreted terms and can be expressed, after normalizing with [subst], as [x ↦ u] where [us ∪ xs ⊇ fv x ⊈ us] and [fv u ⊆ us] or else [fv u ⊆ us ∪ xs] *) let rec solve_interp_eqs ~wrt us (cls, subst) = [%Trace.call fun {pf} -> pf "@ cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst] ; let rec solve_interp_eqs_ cls' (cls, subst) = match cls with | [] -> (cls', subst) | trm :: cls -> let trm' = Subst.norm subst trm in if is_interpreted trm' then match solve_interp_eq ~wrt us trm' (cls, subst) with | Some subst -> solve_interp_eqs_ cls' (cls, subst) | None -> solve_interp_eqs_ (trm' :: cls') (cls, subst) else solve_interp_eqs_ (trm' :: cls') (cls, subst) in let cls', subst' = solve_interp_eqs_ [] (cls, subst) in ( if subst' != subst then solve_interp_eqs ~wrt us (cls', subst') else (cls', subst') ) |> [%Trace.retn fun {pf} (cls', subst') -> pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff (subst, subst')] type cls_solve_state = { rep_us: Trm.t option (** rep, that is ito us, for class *) ; cls_us: Trm.t list (** cls that is ito us, or interpreted *) ; rep_xs: Trm.t option (** rep, that is *not* ito us, for class *) ; cls_xs: Trm.t list (** cls that is *not* ito us *) } let dom_trm e = match (e : Trm.t) with | Sized {seq= Var _ as v} -> Some v | _ when not (is_interpreted e) -> Some e | _ -> None (** move equations from [cls] (which is assumed to be normalized by [subst]) to [subst] which can be expressed as [x ↦ u] where [x] is non-interpreted [us ∪ xs ⊇ fv x ⊈ us] and [fv u ⊆ us] or else [fv u ⊆ us ∪ xs] *) let solve_uninterp_eqs us (cls, subst) = [%Trace.call fun {pf} -> pf "@ cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst] ; let compare e f = [%compare: kind * Trm.t] (classify e, e) (classify f, f) in let {rep_us; cls_us; rep_xs; cls_xs} = List.fold cls {rep_us= None; cls_us= []; rep_xs= None; cls_xs= []} ~f:(fun trm ({rep_us; cls_us; rep_xs; cls_xs} as s) -> if Var.Set.subset (Trm.fv trm) ~of_:us then match rep_us with | Some rep when compare rep trm <= 0 -> {s with cls_us= trm :: cls_us} | Some rep -> {s with rep_us= Some trm; cls_us= rep :: cls_us} | None -> {s with rep_us= Some trm} else match rep_xs with | Some rep -> ( if compare rep trm <= 0 then match dom_trm trm with | Some trm -> {s with cls_xs= trm :: cls_xs} | None -> {s with cls_us= trm :: cls_us} else match dom_trm rep with | Some rep -> {s with rep_xs= Some trm; cls_xs= rep :: cls_xs} | None -> {s with rep_xs= Some trm; cls_us= rep :: cls_us} ) | None -> {s with rep_xs= Some trm} ) in ( match rep_us with | Some rep_us -> let cls = rep_us :: cls_us in let cls, cls_xs = match rep_xs with | Some rep -> ( match dom_trm rep with | Some rep -> (cls, rep :: cls_xs) | None -> (rep :: cls, cls_xs) ) | None -> (cls, cls_xs) in let subst = List.fold cls_xs subst ~f:(fun trm_xs subst -> let trm_xs = Subst.subst_ subst trm_xs in let rep_us = Subst.subst_ subst rep_us in Subst.compose1 ~key:trm_xs ~data:rep_us subst ) in (cls, subst) | None -> ( match rep_xs with | Some rep_xs -> let cls = rep_xs :: cls_us in let subst = List.fold cls_xs subst ~f:(fun trm_xs subst -> Subst.compose1 ~key:trm_xs ~data:rep_xs subst ) in (cls, subst) | None -> (cls, subst) ) ) |> [%Trace.retn fun {pf} (cls', subst') -> pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff (subst, subst') ; subst_invariant us subst subst'] (** move equations between terms in [rep]'s class [cls] from [classes] to [subst] which can be expressed, after normalizing with [subst], as [x ↦ u] where [us ∪ xs ⊇ fv x ⊈ us] and [fv u ⊆ us] or else [fv u ⊆ us ∪ xs] *) let solve_class ~wrt us us_xs ~key:rep ~data:cls (classes, subst) = let classes0 = classes in [%Trace.call fun {pf} -> pf "@ rep: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Trm.pp rep pp_cls cls Subst.pp subst] ; let cls, cls_not_ito_us_xs = List.partition ~f:(fun e -> Var.Set.subset (Trm.fv e) ~of_:us_xs) (rep :: cls) in let cls, subst = solve_interp_eqs ~wrt us (cls, subst) in let cls, subst = solve_uninterp_eqs us (cls, subst) in let cls = List.rev_append cls_not_ito_us_xs cls in let cls = List.remove ~eq:Trm.equal (Subst.norm subst rep) cls in let classes = if List.is_empty cls then Trm.Map.remove rep classes else Trm.Map.add ~key:rep ~data:cls classes in (classes, subst) |> [%Trace.retn fun {pf} (classes', subst') -> pf "subst: @[%a@]@ classes: @[%a@]" Subst.pp_diff (subst, subst') pp_diff_clss (classes0, classes')] let solve_concat_extracts_eq r x = [%Trace.call fun {pf} -> pf "@ %a@ %a" Trm.pp x pp r] ; let uses = fold_uses_of r x [] ~f:(fun use uses -> match use with | Sized _ as m -> fold_uses_of r m uses ~f:(fun use uses -> match use with Extract _ as e -> e :: uses | _ -> uses ) | _ -> uses ) in let find_extracts_at_off off = List.filter uses ~f:(function | Extract {off= o} -> implies r (Fml.eq o off) | _ -> false ) in let rec find_extracts full_rev_extracts rev_prefix off = List.fold (find_extracts_at_off off) full_rev_extracts ~f:(fun e full_rev_extracts -> match e with | Extract {seq= Sized {siz= n}; off= o; len= l} -> let o_l = Trm.add o l in if implies r (Fml.eq n o_l) then (e :: rev_prefix) :: full_rev_extracts else find_extracts full_rev_extracts (e :: rev_prefix) o_l | _ -> full_rev_extracts ) in find_extracts [] [] Trm.zero |> [%Trace.retn fun {pf} -> pf "@[[%a]@]" (List.pp ";@ " (List.pp ",@ " Trm.pp))] let solve_concat_extracts r us x (classes, subst, us_xs) = match List.filter_map (solve_concat_extracts_eq r x) ~f:(fun rev_extracts -> Iter.fold_opt (Iter.of_list rev_extracts) [] ~f:(fun e suffix -> let+ rep_ito_us = List.fold (cls_of r e) None ~f:(fun trm rep_ito_us -> match rep_ito_us with | Some rep when Trm.compare rep trm <= 0 -> rep_ito_us | _ when Var.Set.subset (Trm.fv trm) ~of_:us -> Some trm | _ -> rep_ito_us ) in Trm.sized ~siz:(Trm.seq_size_exn e) ~seq:rep_ito_us :: suffix ) ) |> Iter.of_list |> Iter.min ~lt:(fun xs ys -> [%compare: Trm.t list] xs ys < 0) with | Some extracts -> let concat = Trm.concat (Array.of_list extracts) in let subst = Subst.compose1 ~key:x ~data:concat subst in (classes, subst, us_xs) | None -> (classes, subst, us_xs) let solve_for_xs r us xs = Var.Set.fold xs ~f:(fun x (classes, subst, us_xs) -> let x = Trm.var x in if Subst.mem x subst then (classes, subst, us_xs) else solve_concat_extracts r us x (classes, subst, us_xs) ) (** move equations from [classes] to [subst] which can be expressed, after normalizing with [subst], as [x ↦ u] where [us ∪ xs ⊇ fv x ⊈ us] and [fv u ⊆ us] or else [fv u ⊆ us ∪ xs]. *) let solve_classes ~wrt r xs (classes, subst, us) = [%Trace.call fun {pf} -> pf "@ us: {@[%a@]}@ xs: {@[%a@]}" Var.Set.pp us Var.Set.pp xs] ; let rec solve_classes_ (classes0, subst0, us_xs) = let classes, subst = Trm.Map.fold ~f:(solve_class ~wrt us us_xs) classes0 (classes0, subst0) in if subst != subst0 then solve_classes_ (classes, subst, us_xs) else (classes, subst, us_xs) in (classes, subst, Var.Set.union us xs) |> solve_classes_ |> solve_for_xs r us xs |> [%Trace.retn fun {pf} (classes', subst', _) -> pf "subst: @[%a@]@ classes: @[%a@]" Subst.pp_diff (subst, subst') pp_diff_clss (classes, classes')] let pp_vss fs vss = Format.fprintf fs "[@[%a@]]" (List.pp ";@ " (fun fs vs -> Format.fprintf fs "{@[%a@]}" Var.Set.pp vs)) vss (** enumerate variable contexts vᵢ in [v₁;…] and accumulate a solution subst with entries [x ↦ u] where [r] entails [x = u] and [⋃ⱼ₌₁ⁱ vⱼ ⊇ fv x ⊈ ⋃ⱼ₌₁ⁱ⁻¹ vⱼ] and [fv u ⊆ ⋃ⱼ₌₁ⁱ⁻¹ vⱼ] if possible and otherwise [fv u ⊆ ⋃ⱼ₌₁ⁱ vⱼ] *) let solve_for_vars vss r = [%Trace.call fun {pf} -> pf "@ %a@ @[%a@]" pp_vss vss pp r ; invariant r] ; let wrt = Var.Set.union_list vss in let us, vss = match vss with us :: vss -> (us, vss) | [] -> (Var.Set.empty, vss) in List.fold ~f:(solve_classes ~wrt r) vss (classes r, Subst.empty, us) |> snd3 |> [%Trace.retn fun {pf} subst -> pf "%a" Subst.pp subst ; Subst.iteri subst ~f:(fun ~key ~data -> assert ( implies r (Fml.eq key data) || fail "@[%a@ = %a@ not entailed by@ @[%a@]@]" Trm.pp key Trm.pp data pp_classes r () ) ; assert ( Iter.fold_until (Iter.of_list vss) us ~f:(fun xs us -> let us_xs = Var.Set.union us xs in let ks = Trm.fv key in let ds = Trm.fv data in if Var.Set.subset ks ~of_:us_xs && Var.Set.subset ds ~of_:us_xs && ( Var.Set.subset ds ~of_:us || not (Var.Set.subset ks ~of_:us) ) then `Stop true else `Continue us_xs ) ~finish:(fun _ -> false) ) )] let trivial vs r = [%trace] ~call:(fun {pf} -> pf "@ %a@ %a" Var.Set.pp_xs vs pp_raw r) ~retn:(fun {pf} ks -> pf "%a" Var.Set.pp_xs ks) @@ fun () -> Var.Set.fold vs Var.Set.empty ~f:(fun v ks -> let x = Trm.var v in match Subst.find x r.rep with | None -> Var.Set.add v ks | Some x' when Trm.equal x x' && Iter.is_empty (uses_of x r) -> Var.Set.add v ks | _ -> ks ) let trim ks r = [%trace] ~call:(fun {pf} -> pf "@ %a@ %a" Var.Set.pp_xs ks pp_raw r) ~retn:(fun {pf} r' -> pf "%a" pp_raw r' ; assert (Var.Set.disjoint ks (fv r')) ) @@ fun () -> let kills = Trm.Set.of_iter (Iter.map ~f:Trm.var (Var.Set.to_iter ks)) in (* compute classes including reps *) let reps = Subst.fold r.rep Trm.Set.empty ~f:(fun ~key:_ ~data:rep reps -> Trm.Set.add rep reps ) in let clss = Trm.Set.fold reps (classes r) ~f:(fun rep clss -> Trm.Map.add_multi ~key:rep ~data:rep clss ) in (* trim classes to those that intersect kills *) let clss = Trm.Map.filter_mapi clss ~f:(fun ~key:_ ~data:cls -> let cls = Trm.Set.of_list cls in if Trm.Set.disjoint kills cls then None else Some cls ) in (* enumerate affected classes and update solution subst *) let rep = Trm.Map.fold clss r.rep ~f:(fun ~key:rep ~data:cls s -> (* remove mappings for non-rep class elements to kill *) let drop = Trm.Set.inter cls kills in let s = Trm.Set.fold ~f:Subst.remove drop s in if not (Trm.Set.mem rep kills) then s else (* if rep is to be removed, choose new one from the keepers *) let keep = Trm.Set.diff cls drop in match Trm.Set.reduce keep ~f:(fun x y -> if prefer x y < 0 then x else y ) with | Some rep' -> (* add mappings from each keeper to the new representative *) Trm.Set.fold keep s ~f:(fun elt s -> Subst.add ~key:elt ~data:rep' s ) | None -> s ) in {r with rep} let apply_and_elim ~wrt xs s r = [%trace] ~call:(fun {pf} -> pf "@ %a%a@ %a" Var.Set.pp_xs xs Subst.pp s pp_raw r) ~retn:(fun {pf} (zs, r', ks) -> pf "%a@ %a@ %a" Var.Set.pp_xs zs pp_raw r' Var.Set.pp_xs ks ; assert (Var.Set.subset ks ~of_:xs) ; assert (Var.Set.disjoint ks (fv r')) ) @@ fun () -> if Subst.is_empty s then (Var.Set.empty, r, Var.Set.empty) else let zs, r = apply_subst wrt s r in if is_unsat r then (Var.Set.empty, unsat, Var.Set.empty) else let ks = trivial xs r in let r = trim ks r in (zs, r, ks) (* * Replay debugging *) type call = | Add of Var.Set.t * Formula.t * t | Union of Var.Set.t * t * t | Inter of Var.Set.t * t * t | InterN of Var.Set.t * t list | Rename of t * Var.Subst.t | Is_unsat of t | Implies of t * Formula.t | Refutes of t * Formula.t | Normalize of t * Term.t | Apply_subst of Var.Set.t * Subst.t * t | Solve_for_vars of Var.Set.t list * t | Apply_and_elim of Var.Set.t * Var.Set.t * Subst.t * t [@@deriving sexp] let replay c = match call_of_sexp (Sexp.of_string c) with | Add (us, e, r) -> add us e r |> ignore | Union (us, r, s) -> union us r s |> ignore | Inter (us, r, s) -> inter us r s |> ignore | InterN (us, rs) -> interN us rs |> ignore | Rename (r, s) -> rename r s |> ignore | Is_unsat r -> is_unsat r |> ignore | Implies (r, f) -> implies r f |> ignore | Refutes (r, f) -> refutes r f |> ignore | Normalize (r, e) -> normalize r e |> ignore | Apply_subst (us, s, r) -> apply_subst us s r |> ignore | Solve_for_vars (vss, r) -> solve_for_vars vss r |> ignore | Apply_and_elim (wrt, xs, s, r) -> apply_and_elim ~wrt xs s r |> ignore (* Debug wrappers *) let report ~name ~elapsed ~aggregate ~count = Format.eprintf "%15s time: %12.3f ms %12.3f ms %12d calls@." name elapsed aggregate count let dump_threshold = ref 1000. let wrap tmr f call = let f () = Timer.start tmr ; let r = f () in Timer.stop_report tmr (fun ~name ~elapsed ~aggregate ~count -> report ~name ~elapsed ~aggregate ~count ; if Float.(elapsed > !dump_threshold) then ( dump_threshold := 2. *. !dump_threshold ; Format.eprintf "@\n%a@\n@." Sexp.pp_hum (sexp_of_call (call ())) ) ) ; r in if not [%debug] then f () else try f () with exn -> let bt = Printexc.get_raw_backtrace () in let exn = Replay (exn, sexp_of_call (call ())) in Printexc.raise_with_backtrace exn bt let add_tmr = Timer.create "add" ~at_exit:report let union_tmr = Timer.create "union" ~at_exit:report let inter_tmr = Timer.create "inter" ~at_exit:report let interN_tmr = Timer.create "interN" ~at_exit:report let rename_tmr = Timer.create "rename" ~at_exit:report let is_unsat_tmr = Timer.create "is_unsat" ~at_exit:report let implies_tmr = Timer.create "implies" ~at_exit:report let refutes_tmr = Timer.create "refutes" ~at_exit:report let normalize_tmr = Timer.create "normalize" ~at_exit:report let apply_subst_tmr = Timer.create "apply_subst" ~at_exit:report let solve_for_vars_tmr = Timer.create "solve_for_vars" ~at_exit:report let apply_and_elim_tmr = Timer.create "apply_and_elim" ~at_exit:report let add us e r = wrap add_tmr (fun () -> add us e r) (fun () -> Add (us, e, r)) let union us r s = wrap union_tmr (fun () -> union us r s) (fun () -> Union (us, r, s)) let inter us r s = wrap inter_tmr (fun () -> inter us r s) (fun () -> Inter (us, r, s)) let interN us rs = wrap interN_tmr (fun () -> interN us rs) (fun () -> InterN (us, rs)) let rename r s = wrap rename_tmr (fun () -> rename r s) (fun () -> Rename (r, s)) let is_unsat r = wrap is_unsat_tmr (fun () -> is_unsat r) (fun () -> Is_unsat r) let implies r f = wrap implies_tmr (fun () -> implies r f) (fun () -> Implies (r, f)) let refutes r f = wrap refutes_tmr (fun () -> refutes r f) (fun () -> Refutes (r, f)) let normalize r e = wrap normalize_tmr (fun () -> normalize r e) (fun () -> Normalize (r, e)) let apply_subst us s r = wrap apply_subst_tmr (fun () -> apply_subst us s r) (fun () -> Apply_subst (us, s, r)) let solve_for_vars vss r = wrap solve_for_vars_tmr (fun () -> solve_for_vars vss r) (fun () -> Solve_for_vars (vss, r)) let apply_and_elim ~wrt xs s r = wrap apply_and_elim_tmr (fun () -> apply_and_elim ~wrt xs s r) (fun () -> Apply_and_elim (wrt, xs, s, r))