(*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 *)
(** Equality over uninterpreted functions and linear rational arithmetic *)
(** Classification of Terms by Theory *)
(** Classification of a term by how the solver treats its top-level
    constructor; see [classify] for the mapping from terms to kinds. The
    constructor order is significant: it determines the derived
    [compare_kind], which [solve_uninterp_eqs] uses (via
    [\[%compare: kind * Term.t\]]) to order candidate representatives. *)
type kind = Interpreted | Simplified | Atomic | Uninterpreted
[@@deriving compare, equal]
(** [classify e] determines which theory handles the top-level constructor
    of [e]:
    - arithmetic ([Add], [Mul]) and aggregate ([Memory], [Extract],
      [Concat]) terms are [Interpreted];
    - equalities and disequalities are [Simplified];
    - every other application is [Uninterpreted];
    - variables, literals, labels, nondeterministic values and recursive
      records are [Atomic]. *)
let classify e =
  match (e : Term.t) with
  | Add _ | Mul _ -> Interpreted
  | Ap2 (Memory, _, _) | Ap3 (Extract, _, _, _) | ApN (Concat, _) ->
      Interpreted
  | Ap2 ((Eq | Dq), _, _) -> Simplified
  | Ap1 _ | Ap2 _ | Ap3 _ | ApN _ -> Uninterpreted
  | RecN _ | Var _ | Integer _ | Rational _ | Float _ | Nondet _ | Label _
    ->
      Atomic
(** [interpreted e] holds exactly when [e] is classified [Interpreted]. *)
let interpreted e =
  match classify e with
  | Interpreted -> true
  | Simplified | Atomic | Uninterpreted -> false

(** [non_interpreted e] is the negation of [interpreted e]. *)
let non_interpreted e =
  match classify e with
  | Interpreted -> false
  | Simplified | Atomic | Uninterpreted -> true

(** [uninterpreted e] holds exactly when [e] is classified
    [Uninterpreted]. *)
let uninterpreted e =
  match classify e with
  | Uninterpreted -> true
  | Interpreted | Simplified | Atomic -> false

(** Fold [f] over the maximal non-interpreted subterms of [e]: interpreted
    terms are descended into, and anything else is passed to [f] whole. *)
let rec fold_max_solvables e ~init ~f =
  match classify e with
  | Interpreted ->
      Term.fold e ~init ~f:(fun sub acc ->
          fold_max_solvables sub ~init:acc ~f )
  | Simplified | Atomic | Uninterpreted -> f e init

(** Iterate [f] over the maximal non-interpreted subterms of [e]. *)
let rec iter_max_solvables e ~f =
  match classify e with
  | Interpreted -> Term.iter e ~f:(fun sub -> iter_max_solvables sub ~f)
  | Simplified | Atomic | Uninterpreted -> f e
(** Solution Substitutions *)
module Subst : sig
  (** A substitution: a finite map from terms to terms, where an entry
      [a ↦ a'] records the oriented equation [a = a']. *)
  type t [@@deriving compare, equal, sexp]

  val pp : t pp
  val pp_diff : (t * t) pp
  val empty : t
  val is_empty : t -> bool
  val length : t -> int
  val mem : t -> Term.t -> bool
  val find : t -> Term.t -> Term.t option
  val fold : t -> init:'a -> f:(key:Term.t -> data:Term.t -> 'a -> 'a) -> 'a
  val iteri : t -> f:(key:Term.t -> data:Term.t -> unit) -> unit
  val for_alli : t -> f:(key:Term.t -> data:Term.t -> bool) -> bool

  (** [apply s a] is the image of [a] under [s], or [a] itself when [a] is
      not in the domain of [s]. *)
  val apply : t -> Term.t -> Term.t

  (** [subst s a] applies [s] bottom-up to every subterm of [a]. *)
  val subst : t -> Term.t -> Term.t

  (** [norm s a] applies [s] to the maximal non-interpreted subterms of
      [a]. *)
  val norm : t -> Term.t -> Term.t

  (** [compose r s] normalizes the range of [r] with [s] and adds the
      entries of [s]. Asserts that no key in both domains is mapped to
      different terms. *)
  val compose : t -> t -> t

  (** [compose1 ~key ~data s] composes [s] with the single mapping
      [key ↦ data]; a no-op when [key] and [data] are equal. *)
  val compose1 : key:Term.t -> data:Term.t -> t -> t

  (** [extend e s] adds the identity entry [e ↦ e], or is [None] if [e] is
      already in the domain of [s]. *)
  val extend : Term.t -> t -> t option

  (** [remove xs s] drops the entries whose key is a variable in [xs]. *)
  val remove : Var.Set.t -> t -> t

  (** [map_entries ~f s] applies [f] to both domain and range of [s].
      Requires [f] to be injective and, for any set of terms [E], [f\[E\]]
      to be disjoint from [E]. *)
  val map_entries : f:(Term.t -> Term.t) -> t -> t

  (** the entries in increasing order of keys *)
  val to_alist : t -> (Term.t * Term.t) list

  (** [partition_valid xs s] partitions [∃xs. s] into an equivalent
      [∃xs. t ∧ ∃ks. v], returned as [(t, ks, v)], where [∃ks. v] is
      universally valid, [xs ⊇ ks] and [ks ∩ fv(t) = ∅]. *)
  val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t
end = struct
  type t = Term.t Term.Map.t [@@deriving compare, equal, sexp_of]

  let t_of_sexp = Term.Map.t_of_sexp Term.t_of_sexp
  let pp = Term.Map.pp Term.pp Term.pp

  let pp_diff =
    Term.Map.pp_diff ~data_equal:Term.equal Term.pp Term.pp Term.pp_diff

  let empty = Term.Map.empty
  let is_empty = Term.Map.is_empty
  let length = Term.Map.length
  let mem = Term.Map.mem
  let find = Term.Map.find
  let fold = Term.Map.fold
  let iteri = Term.Map.iteri
  let for_alli = Term.Map.for_alli
  let to_alist = Term.Map.to_alist ~key_order:`Increasing

  (** look up a term in a substitution *)
  let apply s a = Term.Map.find s a |> Option.value ~default:a

  let rec subst s a = apply s (Term.map ~f:(subst s) a)

  (** apply a substitution to maximal non-interpreted subterms *)
  let rec norm s a =
    match classify a with
    | Interpreted -> Term.map ~f:(norm s) a
    | Simplified -> apply s (Term.map ~f:(norm s) a)
    | Atomic | Uninterpreted -> apply s a

  (** compose two substitutions *)
  let compose r s =
    [%Trace.call fun {pf} -> pf "%a@ %a" pp r pp s]
    ;
    let r' = Term.Map.map_endo ~f:(norm s) r in
    Term.Map.merge_endo r' s ~f:(fun ~key -> function
      | `Both (data_r, data_s) ->
          assert (
            Term.equal data_s data_r
            || fail "domains intersect: %a" Term.pp key () ) ;
          Some data_r
      | `Left data | `Right data -> Some data )
    |>
    [%Trace.retn fun {pf} r' ->
      pf "%a" pp_diff (r, r') ;
      assert (r' != r ==> not (equal r' r))]

  (** compose a substitution with a mapping *)
  let compose1 ~key ~data s =
    if Term.equal key data then s
    else compose s (Term.Map.singleton key data)

  (** add an identity entry if the term is not already present *)
  let extend e s =
    let exception Found in
    match
      Term.Map.update s e ~f:(function
        | Some _ -> raise_notrace Found
        | None -> e )
    with
    | exception Found -> None
    | s -> Some s

  (** remove entries for vars *)
  let remove xs s =
    Var.Set.fold ~f:(fun s x -> Term.Map.remove s (Term.var x)) ~init:s xs

  (** map over a subst, applying [f] to both domain and range, requires that
      [f] is injective and for any set of terms [E], [f\[E\]] is disjoint
      from [E] *)
  let map_entries ~f s =
    Term.Map.fold s ~init:s ~f:(fun ~key ~data s ->
        let key' = f key in
        let data' = f data in
        if Term.equal key' key then
          if Term.equal data' data then s
          else Term.Map.set s ~key ~data:data'
        else
          Term.Map.remove s key |> Term.Map.add_exn ~key:key' ~data:data' )

  (** Holds only if [true ⊢ ∃xs. e=f]. Clients assume
      [not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for
      [ys ⊆ xs]. *)
  let is_valid_eq xs e f =
    let is_var_in xs e =
      Option.exists ~f:(Var.Set.mem xs) (Var.of_term e)
    in
    ( is_var_in xs e || is_var_in xs f
    || (uninterpreted e && Term.exists ~f:(is_var_in xs) e)
    || (uninterpreted f && Term.exists ~f:(is_var_in xs) f) )
    $> fun b ->
    [%Trace.info
      "is_valid_eq %a%a=%a = %b" Var.Set.pp_xs xs Term.pp e Term.pp f b]

  (** Partition ∃xs. σ into equivalent ∃xs. τ ∧ ∃ks. ν where ks
      and ν are maximal where ∃ks. ν is universally valid, xs ⊇ ks and
      ks ∩ fv(τ) = ∅. *)
  let partition_valid xs s =
    (* Move equations e=f from s to t when ∃ks.e=f fails to be provably
       valid. When moving an equation, reduce ks by fv(e=f) to maintain ks ∩
       fv(t) = ∅. This reduction may cause equations in s to no longer be
       valid, so loop until no change. *)
    let rec partition_valid_ t ks s =
      let t', ks', s' =
        Term.Map.fold s ~init:(t, ks, s) ~f:(fun ~key ~data (t, ks, s) ->
            if is_valid_eq ks key data then (t, ks, s)
            else
              let t = Term.Map.set ~key ~data t
              and ks =
                Var.Set.diff ks (Var.Set.union (Term.fv key) (Term.fv data))
              and s = Term.Map.remove s key in
              (t, ks, s) )
      in
      if s' != s then partition_valid_ t' ks' s' else (t', ks', s')
    in
    partition_valid_ empty xs s
end
(** Theory Solver *)
(** prefer representative terms that are minimal in the order s.t. Var <
    Memory < Extract < Concat < others, then using height of aggregate
    nesting, and then using Term.compare.

    The result is negative when [e] is preferred over [f], zero when they
    are the same term, and positive otherwise. *)
let prefer e f =
  let rank e =
    match (e : Term.t) with
    | Var _ -> 0
    | Ap2 (Memory, _, _) -> 1
    | Ap3 (Extract, _, _, _) -> 2
    | ApN (Concat, _) -> 3
    | _ -> 4
  in
  let o = compare (rank e) (rank f) in
  if o <> 0 then o
  else
    let o = compare (Term.height e) (Term.height f) in
    if o <> 0 then o else Term.compare e f
(** orient an equation [e = f] by representative preference: the preferred
    term comes first in the returned pair, and [None] means the two sides
    are the same term *)
let orient e f =
  let ord = prefer e f in
  if ord < 0 then Some (e, f)
  else if ord > 0 then Some (f, e)
  else None
(** normalize [e] with the substitution component of a solver state
    [(us, xs, s)] *)
let norm (_, _, s) e = Subst.norm s e

(** extend the solution of a solver state [(us, xs, s)] with [var ↦ rep].
    If a filter [f] is given and rejects [var] and [rep], the state is
    left unchanged. Always returns [Some]. *)
let compose1 ?f ~var ~rep (us, xs, s) =
  let s =
    match f with
    | Some f when not (f var rep) -> s
    | _ -> Subst.compose1 ~key:var ~data:rep s
  in
  Some (us, xs, s)

(** create a variable from [name] using [Var.fresh ~wrt:us] (presumably
    fresh with respect to [us]), record it in the existentials [xs], and
    return it as a term together with the updated state *)
let fresh name (us, xs, s) =
  let x, us = Var.fresh name ~wrt:us in
  let xs = Var.Set.add xs x in
  (Term.var x, (us, xs, s))
(** solve the polynomial equation [p = q] through the difference [p - q]:
    a constant integer difference is consistent only when it is zero; a
    variable difference is mapped to [0]; otherwise [p - q = 0] is solved
    with [Term.solve_zero_eq], falling back to mapping the whole difference
    to [0] *)
let solve_poly ?f p q s =
  let diff = Term.sub p q in
  match diff with
  | Integer {data} -> if Z.equal Z.zero data then Some s else None
  | Var _ -> compose1 ?f ~var:diff ~rep:Term.zero s
  | _ -> (
      match Term.solve_zero_eq diff with
      | None -> compose1 ?f ~var:diff ~rep:Term.zero s
      | Some (var, rep) -> compose1 ?f ~var ~rep s )
(* α[o,l) = β ==> l = |β| ∧ α = (⟨n,c⟩[0,o) ^ β ^ ⟨n,c⟩[o+l,n-o-l)) where n
   = |α| and c fresh *)
let rec solve_extract ?f a o l b s =
  let n = Term.agg_size_exn a in
  let c, s = fresh "c" s in
  let n_c = Term.memory ~siz:n ~arr:c in
  let o_l = Term.add o l in
  let n_o_l = Term.sub n o_l in
  let c0 = Term.extract ~agg:n_c ~off:Term.zero ~len:o in
  let c1 = Term.extract ~agg:n_c ~off:o_l ~len:n_o_l in
  let b, s =
    match Term.agg_size b with
    (* β has no aggregate size: view it as the memory ⟨l,β⟩ *)
    | None -> (Term.memory ~siz:l ~arr:b, Some s)
    | Some m -> (b, solve_ ?f l m s)
  in
  s >>= solve_ ?f a (Term.concat [|c0; b; c1|])

(* α₀^…^αᵢ^αⱼ^…^αᵥ = β ==> |α₀^…^αᵥ| = |β| ∧ … ∧ αⱼ = β[n₀+…+nᵢ,nⱼ) ∧ …
   where nₓ ≡ |αₓ| and m = |β| *)
and solve_concat ?f a0V b m s =
  IArray.fold_until a0V ~init:(s, Term.zero)
    ~f:(fun (s, oI) aJ ->
      let nJ = Term.agg_size_exn aJ in
      let oJ = Term.add oI nJ in
      match solve_ ?f aJ (Term.extract ~agg:b ~off:oI ~len:nJ) s with
      | Some s -> Continue (s, oJ)
      | None -> Stop None )
    ~finish:(fun (s, n0V) -> solve_ ?f n0V m s)

(* [solve_ ?f d e s] extends the solver state [s = (us, xs, subst)] with a
   solution of [d = e], or returns [None] when a contradiction is found.
   [f], if given, is passed to [compose1] to filter the entries added. *)
and solve_ ?f d e s =
  [%Trace.call fun {pf} ->
    pf "%a@[%a@ %a@ %a@]" Var.Set.pp_xs (snd3 s) Term.pp d Term.pp e
      Subst.pp (trd3 s)]
  ;
  ( match orient (norm s d) (norm s e) with
  (* e' = f' ==> true when e' ≡ f' *)
  | None -> Some s
  (* i = j ==> false when i ≠ j *)
  | Some (Integer _, Integer _) | Some (Rational _, Rational _) -> None
  (* ⟨0,a⟩ = β ==> a = β = ⟨⟩ *)
  | Some (Ap2 (Memory, n, a), b) when Term.equal n Term.zero ->
      s |> solve_ ?f a (Term.concat [||]) >>= solve_ ?f b (Term.concat [||])
  | Some (b, Ap2 (Memory, n, a)) when Term.equal n Term.zero ->
      s |> solve_ ?f a (Term.concat [||]) >>= solve_ ?f b (Term.concat [||])
  (* v = ⟨n,a⟩ ==> v = a *)
  | Some ((Var _ as v), Ap2 (Memory, _, a)) -> s |> solve_ ?f v a
  (* ⟨n,a⟩ = ⟨m,b⟩ ==> n = m ∧ a = b *)
  | Some (Ap2 (Memory, n, a), Ap2 (Memory, m, b)) ->
      s |> solve_ ?f n m >>= solve_ ?f a b
  (* ⟨n,a⟩ = β ==> n = |β| ∧ a = β *)
  | Some (Ap2 (Memory, n, a), b) ->
      ( match Term.agg_size b with
      | None -> Some s
      | Some m -> solve_ ?f n m s )
      >>= solve_ ?f a b
  | Some ((Var _ as v), (Ap3 (Extract, _, _, l) as e)) ->
      if not (Var.Set.mem (Term.fv e) (Var.of_ v)) then
        (* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *)
        compose1 ?f ~var:v ~rep:e s
      else
        (* v = α[o,l) ==> α[o,l) ↦ ⟨l,v⟩ when v ∈ fv(α[o,l)) *)
        compose1 ?f ~var:e ~rep:(Term.memory ~siz:l ~arr:v) s
  | Some ((Var _ as v), (ApN (Concat, a0V) as c)) ->
      if not (Var.Set.mem (Term.fv c) (Var.of_ v)) then
        (* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *)
        compose1 ?f ~var:v ~rep:c s
      else
        (* v = α₀^…^αᵥ ==> ⟨|α₀^…^αᵥ|,v⟩ = α₀^…^αᵥ when v ∈ fv(α₀^…^αᵥ) *)
        let m = Term.agg_size_exn c in
        solve_concat ?f a0V (Term.memory ~siz:m ~arr:v) m s
  | Some ((Ap3 (Extract, _, _, l) as e), ApN (Concat, a0V)) ->
      solve_concat ?f a0V e l s
  | Some (ApN (Concat, a0V), (ApN (Concat, _) as c)) ->
      solve_concat ?f a0V c (Term.agg_size_exn c) s
  | Some (Ap3 (Extract, a, o, l), e) -> solve_extract ?f a o l e s
  (* p = q ==> p-q = 0 *)
  | Some
      ( ((Add _ | Mul _ | Integer _ | Rational _) as p), q
      | q, ((Add _ | Mul _ | Integer _ | Rational _) as p) ) ->
      solve_poly ?f p q s
  | Some (rep, var) ->
      assert (non_interpreted var) ;
      assert (non_interpreted rep) ;
      compose1 ?f ~var ~rep s )
  |>
  [%Trace.retn fun {pf} -> function
    | Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s
    | None -> pf "false"]
(** [solve ?f ~us ~xs d e] solves [d = e] starting from an empty
    substitution. It returns the existentials [xs], extended with any fresh
    variables introduced, together with the solution. The result is [None]
    if a contradiction is found. [f], if given, filters which entries are
    added. *)
let solve ?f ~us ~xs d e =
  [%Trace.call fun {pf} -> pf "%a@ %a" Term.pp d Term.pp e]
  ;
  (solve_ ?f d e (us, xs, Subst.empty) >>| fun (_, xs, s) -> (xs, s))
  |>
  [%Trace.retn fun {pf} -> function
    | Some (xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s
    | None -> pf "false"]
(** Equality Relations *)
(** An equality relation: a substitution of oriented equations mapping each
    term in the carrier to its representative, a satisfiability flag, and
    the existential variables introduced while solving (e.g. by [fresh]).
    See also [invariant]. *)
type t =
{ xs: Var.Set.t
(** existential variables that did not appear in input equations *)
; sat: bool (** [false] only if constraints are inconsistent *)
; rep: Subst.t
(** functional set of oriented equations: map [a] to [a'],
indicating that [a = a'] holds, and that [a'] is the
'rep(resentative)' of [a] *) }
[@@deriving compare, equal, sexp]
(** [classes r] maps each representative to the list of the other terms
    whose representative it is. Identity entries are skipped. Simplified
    and uninterpreted keys are recorded with their immediate subterms
    replaced by their representatives. *)
let classes r =
  let add key data cls =
    if Term.equal key data then cls
    else Term.Map.add_multi cls ~key:data ~data:key
  in
  Subst.fold r.rep ~init:Term.Map.empty ~f:(fun ~key ~data cls ->
      match classify key with
      | Interpreted | Atomic -> add key data cls
      | Simplified | Uninterpreted ->
          add (Term.map ~f:(Subst.apply r.rep) key) data cls )
(** the terms recorded in [classes r] under the representative of [e], or
    the singleton list of that representative when there are none *)
let cls_of r e =
  let rep = Subst.apply r.rep e in
  match Term.Map.find (classes r) rep with
  | Some cls -> cls
  | None -> [rep]
(** Pretty-printing *)
(** print a relation as its satisfiability flag and its carrier, showing
    each term's representative only when it differs from the term *)
let pp fs {sat; rep} =
  let pp_alist pp_k pp_v fs alist =
    let pp_assoc fs (k, v) =
      Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_k k pp_v (k, v)
    in
    Format.fprintf fs "[@[<hv>%a@]]" (List.pp ";@ " pp_assoc) alist
  in
  let pp_term_v fs (k, v) = if not (Term.equal k v) then Term.pp fs v in
  Format.fprintf fs "@[{@[<hv>sat= %b;@ rep= %a@]}@]" sat
    (pp_alist Term.pp pp_term_v)
    (Subst.to_alist rep)
(** print the difference between relations [r] and [s]: the satisfiability
    flags when they differ, and the diff of the representative maps when
    they differ *)
let pp_diff fs (r, s) =
  let pp_sat fs =
    if not (Bool.equal r.sat s.sat) then
      Format.fprintf fs "sat= @[-- %b@ ++ %b@];@ " r.sat s.sat
  in
  let pp_rep fs =
    (* compare the two maps rather than testing [r.rep] for emptiness, so
       that entries added to an initially empty relation are shown *)
    if not (Subst.equal r.rep s.rep) then
      Format.fprintf fs "rep= %a" Subst.pp_diff (r.rep, s.rep)
  in
  Format.fprintf fs "@[{@[<hv>%t%t@]}@]" pp_sat pp_rep
(** print a class as its members separated by [=], printing variables
    according to [x] *)
let ppx_cls x fs cls = List.pp "@ = " (Term.ppx x) fs cls

(** print a class with the default variable printing *)
let pp_cls fs cls = ppx_cls (fun _ -> None) fs cls

(** print the difference between two classes *)
let pp_diff_cls = List.pp_diff ~compare:Term.compare "@ = " Term.pp

(** print a map from representatives to classes as a conjunction of
    [rep = members] equations, with each class's members sorted *)
let ppx_clss x fs cs =
  let pp_entry fs (rep, cls) =
    let sorted = List.sort ~compare:Term.compare cls in
    Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) rep (ppx_cls x) sorted
  in
  List.pp "@ @<2>∧ " pp_entry fs (Term.Map.to_alist cs)

(** [ppx_clss] with the default variable printing *)
let pp_clss fs cs = ppx_clss (fun _ -> None) fs cs

(** print the difference between two maps of classes *)
let pp_diff_clss =
  Term.Map.pp_diff ~data_equal:(List.equal Term.equal) Term.pp pp_cls
(** Basic queries *)
(** test membership in carrier *)
let in_car r e = Subst.mem r.rep e

(** terms are congruent if equal after normalizing subterms *)
let congruent r a b =
  Term.equal
    (Term.map ~f:(Subst.norm r.rep) a)
    (Term.map ~f:(Subst.norm r.rep) b)
(** Invariant *)
(** [pre_invariant r] checks the carrier invariants: no key of [r.rep] is
    interpreted, and every maximal solvable subterm of a key is itself in
    the carrier *)
let pre_invariant r =
  let check_entry ~key:trm ~data:_ =
    (* no interpreted terms in carrier *)
    assert (non_interpreted trm || fail "non-interp %a" Term.pp trm ()) ;
    (* carrier is closed under subterms *)
    iter_max_solvables trm ~f:(fun subtrm ->
        assert (
          in_car r subtrm
          || fail "@[subterm %a of %a not in carrier of@ %a@]" Term.pp
               subtrm Term.pp trm pp r () ) )
  in
  Invariant.invariant [%here] r [%sexp_of: t] (fun () ->
      Subst.iteri r.rep ~f:check_entry )

(** [invariant r] checks [pre_invariant r] and, when [r] is satisfiable,
    that any two congruent carrier terms have the same representative *)
let invariant r =
  let congruence_closed () =
    (not r.sat)
    || Subst.for_alli r.rep ~f:(fun ~key:a ~data:a' ->
           Subst.for_alli r.rep ~f:(fun ~key:b ~data:b' ->
               Term.compare a b >= 0
               || congruent r a b ==> Term.equal a' b'
               || fail "not congruent %a@ %a@ in@ %a" Term.pp a Term.pp b pp
                    r () ) )
  in
  Invariant.invariant [%here] r [%sexp_of: t] (fun () ->
      pre_invariant r ;
      assert (congruence_closed ()) )
(** Core operations *)
(** the satisfiable relation with an empty carrier and no existentials *)
let true_ =
  let empty = {xs= Var.Set.empty; sat= true; rep= Subst.empty} in
  check invariant empty

(** the unsatisfiable relation *)
let false_ = {xs= true_.xs; sat= false; rep= true_.rep}
(** [lookup r a] is [b'] if [a ~ b = b'] for some equation [b = b'] in rep,
    and [a] itself if there is no such equation. Assumes [a] has already
    been canonized. *)
let lookup r a =
  [%Trace.call fun {pf} -> pf "%a" Term.pp a]
  ;
  ( with_return
  @@ fun {return} ->
  (* congruent specialized to assume [a] canonized and [b] non-interpreted *)
  let semi_congruent r a b =
    Term.equal a (Term.map ~f:(Subst.apply r.rep) b)
  in
  Subst.iteri r.rep ~f:(fun ~key ~data ->
      if semi_congruent r a key then return data ) ;
  a )
  |>
  [%Trace.retn fun {pf} -> pf "%a" Term.pp]
(** rewrite a term into canonical form using rep and, for non-interpreted
    terms, congruence composed with rep *)
let rec canon r a =
  [%Trace.call fun {pf} -> pf "%a" Term.pp a]
  ;
  ( match classify a with
  | Atomic -> Subst.apply r.rep a
  | Interpreted -> Term.map ~f:(canon r) a
  | Simplified | Uninterpreted -> (
      (* canonize the subterms first; the result may be classified
         differently than [a] *)
      let a' = Term.map ~f:(canon r) a in
      match classify a' with
      | Atomic -> Subst.apply r.rep a'
      | Interpreted -> Term.map ~f:(canon r) a'
      | Simplified | Uninterpreted -> lookup r a' ) )
  |>
  [%Trace.retn fun {pf} -> pf "%a" Term.pp]
(** add [a] and its subterms to the carrier [rep]. An uninterpreted term
    that is not yet present gets an identity entry and is then descended
    into. Interpreted and simplified terms are descended into without being
    added. Atomic and already-present uninterpreted terms are left as is. *)
let rec extend_ a rep =
  let descend rep = Term.fold a ~init:rep ~f:extend_ in
  match classify a with
  | Atomic -> rep
  | Interpreted | Simplified -> descend rep
  | Uninterpreted -> (
      match Subst.extend a rep with
      | None -> rep
      | Some rep' -> descend rep' )

(** add a term to the carrier *)
let extend a r =
  let rep' = extend_ a r.rep in
  if rep' != r.rep then check pre_invariant {r with rep= rep'} else r
(** [merge us a b r] solves [a = b] and composes the solution into the
    representatives of [r], adding any new existentials. If solving finds a
    contradiction, [r] is marked unsatisfiable. *)
let merge us a b r =
  [%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r]
  ;
  ( match solve ~us ~xs:r.xs a b with
  | Some (xs, s) ->
      {r with xs= Var.Set.union r.xs xs; rep= Subst.compose r.rep s}
  | None -> {r with sat= false} )
  |>
  [%Trace.retn fun {pf} r' ->
    pf "%a" pp_diff (r, r') ;
    pre_invariant r']
(** find an unproved equation between congruent terms: a pair of
    representatives [(a', b')] of distinct congruent carrier terms [a] and
    [b] with [a' ≠ b'], or [None] if there is none *)
let find_missing r =
  with_return
  @@ fun {return} ->
  Subst.iteri r.rep ~f:(fun ~key:a ~data:a' ->
      Subst.iteri r.rep ~f:(fun ~key:b ~data:b' ->
          if
            Term.compare a b < 0
            && (not (Term.equal a' b'))
            && congruent r a b
          then return (Some (a', b')) ) ) ;
  None

(** repeatedly merge missing equations between congruent terms until none
    remain or [r] becomes unsatisfiable *)
let rec close us r =
  if not r.sat then r
  else
    match find_missing r with
    | Some (a', b') -> close us (merge us a' b' r)
    | None -> r

let close us r =
  [%Trace.call fun {pf} -> pf "%a" pp r]
  ;
  close us r
  |>
  [%Trace.retn fun {pf} r' ->
    pf "%a" pp_diff (r, r') ;
    invariant r']
(** add the equation [a = b] to [r]. Both sides are canonized and added to
    the carrier; if they differ, they are merged and the result is closed
    under congruence. An unsatisfiable [r] is returned unchanged. *)
let and_eq_ us a b r =
  if not r.sat then r
  else
    let a' = canon r a in
    let b' = canon r b in
    let r = extend a' r in
    let r = extend b' r in
    if Term.equal a' b' then r else close us (merge us a' b' r)

(** split off the existentials of [r], returning them with [r] emptied of
    existentials *)
let extract_xs r = (r.xs, {r with xs= Var.Set.empty})
(** Exposed interface *)
(** [r] is satisfiable and every entry of its carrier is an identity *)
let is_true {sat; rep} =
  sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Term.equal a a')

(** [r] is unsatisfiable *)
let is_false {sat} = not sat

(** [entails_eq r d e] holds if the equation [d = e] canonizes to true
    under [r] *)
let entails_eq r d e =
  [%Trace.call fun {pf} -> pf "%a = %a@ %a" Term.pp d Term.pp e pp r]
  ;
  Term.is_true (canon r (Term.eq d e))
  |>
  [%Trace.retn fun {pf} -> pf "%b"]

(** [r] entails every equation recorded in [s] *)
let entails r s =
  Subst.for_alli s.rep ~f:(fun ~key:e ~data:e' -> entails_eq r e e')

let normalize = canon

(** the normal form of [e] under [r], followed by the other terms of its
    class *)
let class_of r e =
  let e' = normalize r e in
  e' :: Term.Map.find_multi (classes r) e'
(** fold [f] over the terms of [r] that have [t] as an immediate subterm.
    The terms searched are the keys and representatives of [r.rep] and,
    recursively, the subterms of interpreted ones. [f] is called once per
    occurrence of [t]. *)
let fold_uses_of r t ~init ~f =
  let rec fold_ e ~init:s ~f =
    let s =
      Term.fold e ~init:s ~f:(fun sub s ->
          if Term.equal t sub then f s e else s )
    in
    if interpreted e then
      Term.fold e ~init:s ~f:(fun d s -> fold_ ~f d ~init:s)
    else s
  in
  Subst.fold r.rep ~init ~f:(fun ~key:trm ~data:rep s ->
      let f trm s = fold_ trm ~init:s ~f in
      f trm (f rep s) )
(** [difference r a b] is [Some k] when, after canonizing under [r], the
    difference [a - b] normalizes to the integer [k], and [None]
    otherwise *)
let difference r a b =
  [%Trace.call fun {pf} -> pf "%a@ %a@ %a" Term.pp a Term.pp b pp r]
  ;
  let a = canon r a in
  let b = canon r b in
  ( if Term.equal a b then Some Z.zero
  else
    match normalize r (Term.sub a b) with
    | Integer {data} -> Some data
    | _ -> None )
  |>
  [%Trace.retn fun {pf} ->
    function Some d -> pf "%a" Z.pp_print d | None -> pf ""]
(** [apply_subst us s r] rebuilds [r], starting from [true_], after applying
    the substitution [s] to every term of every class. Returns the
    existentials introduced together with the resulting relation. *)
let apply_subst us s r =
  [%Trace.call fun {pf} -> pf "%a@ %a" Subst.pp s pp r]
  ;
  Term.Map.fold (classes r) ~init:true_ ~f:(fun ~key:rep ~data:cls r ->
      let rep' = Subst.subst s rep in
      List.fold cls ~init:r ~f:(fun r trm ->
          let trm' = Subst.subst s trm in
          and_eq_ us trm' rep' r ) )
  |> extract_xs
  |>
  [%Trace.retn fun {pf} (xs, r') ->
    pf "%a%a" Var.Set.pp_xs xs pp_diff (r, r') ;
    invariant r']
(** conjunction of [r] and [s]: the equations of the relation with the
    smaller carrier are added to the other. Returns the existentials
    introduced together with the result. *)
let and_ us r s =
  [%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2>∧ %a@]" pp r pp s]
  ;
  ( if not r.sat then r
  else if not s.sat then s
  else
    let s, r =
      if Subst.length s.rep <= Subst.length r.rep then (s, r) else (r, s)
    in
    Subst.fold s.rep ~init:r ~f:(fun ~key:e ~data:e' r -> and_eq_ us e e' r)
  )
  |> extract_xs
  |>
  [%Trace.retn fun {pf} (_, r') ->
    pf "%a" pp_diff (r, r') ;
    invariant r']
(** disjunction of [r] and [s]. If either side is unsatisfiable, the other
    is returned. Otherwise, for each class of one relation, the members
    that the other relation entails equal are merged into a fresh relation.
    Returns the existentials introduced together with the result. *)
let or_ us r s =
  [%Trace.call fun {pf} -> pf "@[<hv 1> %a@ @<2>∨ %a@]" pp r pp s]
  ;
  ( if not s.sat then r
  else if not r.sat then s
  else
    let merge_mems rs r s =
      Term.Map.fold (classes s) ~init:rs ~f:(fun ~key:rep ~data:cls rs ->
          List.fold cls
            ~init:([rep], rs)
            ~f:(fun (reps, rs) exp ->
              match List.find ~f:(entails_eq r exp) reps with
              | Some rep -> (reps, and_eq_ us exp rep rs)
              | None -> (exp :: reps, rs) )
          |> snd )
    in
    let rs = true_ in
    let rs = merge_mems rs r s in
    let rs = merge_mems rs s r in
    rs )
  |> extract_xs
  |>
  [%Trace.retn fun {pf} (_, r') ->
    pf "%a" pp_diff (r, r') ;
    invariant r']
(** disjunction of a list of relations, folding [or_] from the left; the
    empty disjunction is [false_] *)
let orN us rs =
  match rs with
  | [] -> (us, false_)
  | r :: rs ->
      (* NOTE(review): the first component returned by [or_] (the
         existentials it introduced) is passed back in as the [us] argument
         of the next [or_]; confirm this is intended *)
      List.fold rs ~init:(us, r) ~f:(fun (us, acc) r -> or_ us acc r)
(** conjoin the formula [e] to [r]. Integer literals are read as booleans.
    Conjunctions and equations are added. [true xor a] adds [a = false].
    Any other formula is ignored. *)
let rec and_term_ us e r =
  match (e : Term.t) with
  | Integer {data} when Z.is_false data -> false_
  | Integer _ -> r
  | Ap2 (And, a, b) ->
      let r = and_term_ us b r in
      and_term_ us a r
  | Ap2 (Eq, a, b) -> and_eq_ us a b r
  | Ap2 (Xor, Integer {data}, a) when Z.is_true data ->
      and_eq_ us a Term.false_ r
  | Ap2 (Xor, a, Integer {data}) when Z.is_true data ->
      and_eq_ us a Term.false_ r
  | _ -> r
(** conjoin the formula [e] to [r], returning the existentials introduced
    together with the result *)
let and_term us e r =
  [%Trace.call fun {pf} -> pf "%a@ %a" Term.pp e pp r]
  ;
  and_term_ us e r |> extract_xs
  |>
  [%Trace.retn fun {pf} (_, r') ->
    pf "%a" pp_diff (r, r') ;
    invariant r']

(** conjoin the equation [a = b] to [r], returning the existentials
    introduced together with the result *)
let and_eq us a b r =
  [%Trace.call fun {pf} -> pf "%a = %a@ %a" Term.pp a Term.pp b pp r]
  ;
  and_eq_ us a b r |> extract_xs
  |>
  [%Trace.retn fun {pf} (_, r') ->
    pf "%a" pp_diff (r, r') ;
    invariant r']

(** rename the variables of [r] by [sub], in both the domain and range of
    the representatives *)
let rename r sub =
  [%Trace.call fun {pf} -> pf "%a" pp r]
  ;
  let rep = Subst.map_entries ~f:(Term.rename sub) r.rep in
  (if rep == r.rep then r else {r with rep})
  |>
  [%Trace.retn fun {pf} r' ->
    pf "%a" pp_diff (r, r') ;
    invariant r']
(** fold [f] over the carrier of [r]: for each entry [key ↦ data], [f]
    sees [data] and then [key] *)
let fold_terms r ~init ~f =
  Subst.fold r.rep ~init ~f:(fun ~key ~data acc ->
      let acc = f acc data in
      f acc key )

(** fold [f] over the variables of the terms in the carrier of [r] *)
let fold_vars r ~init ~f =
  fold_terms r ~init ~f:(fun acc -> Term.fold_vars ~f ~init:acc)

(** the set of variables occurring in the carrier of [r] *)
let fv r = fold_vars r ~init:Var.Set.empty ~f:Var.Set.add

(** print the classes of [r] *)
let pp_classes fs r = classes r |> pp_clss fs

(** print the classes of [r], printing variables according to [x] *)
let ppx_classes x fs r = classes r |> ppx_clss x fs
(** print the classes of [s], keeping only the members that [r] does not
    already entail equal to their representative and omitting classes left
    empty *)
let ppx_classes_diff x fs (r, s) =
  let clss = classes s in
  let clss =
    Term.Map.filter_mapi clss ~f:(fun ~key:rep ~data:cls ->
        match
          List.filter cls ~f:(fun exp -> not (entails_eq r rep exp))
        with
        | [] -> None
        | cls -> Some cls )
  in
  List.pp "@ @<2>∧ "
    (fun fs (rep, cls) ->
      Format.fprintf fs "@[%a@ = %a@]" (Term.ppx x) rep
        (List.pp "@ = " (Term.ppx x))
        (List.dedup_and_sort ~compare:Term.compare cls) )
    fs (Term.Map.to_alist clss)
(** Existential Witnessing and Elimination *)
(** sanity checks on a solver step from [s0] to [s], where "ito us" means
    "all free variables are in [us]":
    - [s] is physically [s0] or structurally different from it;
    - any entry whose key is ito us is already in [s0] with the same data;
    - any key that is ito us maps to data that is ito us. *)
let subst_invariant us s0 s =
  let ito_us e = Var.Set.is_subset (Term.fv e) ~of_:us in
  let check_entry ~key ~data =
    (* dom of new entries not ito us *)
    assert (
      Option.for_all (Subst.find s0 key) ~f:(Term.equal data)
      || not (ito_us key) ) ;
    (* rep not ito us implies trm not ito us *)
    assert (ito_us data || not (ito_us key))
  in
  assert (s0 == s || not (Subst.equal s0 s)) ;
  assert (
    Subst.iteri s ~f:check_entry ;
    true )
(** Saturating count of items seen: none, exactly one (which is kept), or
    more than one. *)
type 'a zom = Zero | One of 'a | Many
(** try to solve [p = q] such that [fv (p - q) ⊆ us ∪ xs] and [p - q]
    has at most one maximal solvable subterm, [kill], where
    [fv kill ⊈ us]; solve [p = q] for [kill]; extend subst mapping [kill]
    to the solution. Returns [None] when there is not exactly one such
    subterm or when solving for it fails. *)
let solve_poly_eq us p' q' subst =
  [%Trace.call fun {pf} -> pf "%a = %a" Term.pp p' Term.pp q']
  ;
  let diff = Term.sub p' q' in
  let max_solvables_not_ito_us =
    fold_max_solvables diff ~init:Zero ~f:(fun solvable_subterm -> function
      | Many -> Many
      | zom when Var.Set.is_subset (Term.fv solvable_subterm) ~of_:us -> zom
      | One _ -> Many
      | Zero -> One solvable_subterm )
  in
  ( match max_solvables_not_ito_us with
  | One kill ->
      let+ kill, keep = Term.solve_zero_eq diff ~for_:kill in
      Subst.compose1 ~key:kill ~data:keep subst
  | Many | Zero -> None )
  |>
  [%Trace.retn fun {pf} subst' ->
    pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)]
(** try to solve the aggregate equation [e' = f'] for a side that is not
    ito [us] in terms of a side that is:
    - a concatenation is solved piecewise with [solve_concat], accepting
      only entries [x ↦ u] with [x] not ito [us] and [u] ito [us];
    - a memory [⟨n,v⟩] with variable contents [v] yields [v ↦ u].
    Returns [None] if neither applies. *)
let solve_memory_eq us e' f' subst =
  [%Trace.call fun {pf} -> pf "%a = %a" Term.pp e' Term.pp f']
  ;
  let f x u =
    (not (Var.Set.is_subset (Term.fv x) ~of_:us))
    && Var.Set.is_subset (Term.fv u) ~of_:us
  in
  let solve_concat ms n a =
    let a, n =
      match Term.agg_size a with
      | Some n -> (a, n)
      | None -> (Term.memory ~siz:n ~arr:a, n)
    in
    let+ _, xs, s = solve_concat ~f ms a n (us, Var.Set.empty, subst) in
    assert (Var.Set.is_empty xs) ;
    s
  in
  ( match ((e' : Term.t), (f' : Term.t)) with
  | (ApN (Concat, ms) as c), a when f c a ->
      solve_concat ms (Term.agg_size_exn c) a
  | a, (ApN (Concat, ms) as c) when f c a ->
      solve_concat ms (Term.agg_size_exn c) a
  | (Ap2 (Memory, _, (Var _ as v)) as m), u when f m u ->
      Some (Subst.compose1 ~key:v ~data:u subst)
  | u, (Ap2 (Memory, _, (Var _ as v)) as m) when f m u ->
      Some (Subst.compose1 ~key:v ~data:u subst)
  | _ -> None )
  |>
  [%Trace.retn fun {pf} subst' ->
    pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)]
(** try to solve [e'] equal to some member of [cls] (normalized by
    [subst]), first as an aggregate equation and then as a polynomial one.
    Returns the extended substitution for the first member that succeeds. *)
let solve_interp_eq us e' (cls, subst) =
  [%Trace.call fun {pf} ->
    pf "trm: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Term.pp e' pp_cls cls
      Subst.pp subst]
  ;
  List.find_map cls ~f:(fun f ->
      let f' = Subst.norm subst f in
      match solve_memory_eq us e' f' subst with
      | Some subst -> Some subst
      | None -> solve_poly_eq us e' f' subst )
  |>
  [%Trace.retn fun {pf} subst' ->
    pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst) ;
    Option.iter ~f:(subst_invariant us subst) subst']
(** move equations from [cls] to [subst] which are between interpreted terms
    and can be expressed, after normalizing with [subst], as [x ↦ u] where
    [us ∪ xs ⊇ fv x ⊈ us] and [fv u ⊆ us] or else
    [fv u ⊆ us ∪ xs]. Iterates until [subst] no longer changes. *)
let rec solve_interp_eqs us (cls, subst) =
  [%Trace.call fun {pf} ->
    pf "cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst]
  ;
  let rec solve_interp_eqs_ cls' (cls, subst) =
    match cls with
    | [] -> (cls', subst)
    | trm :: cls ->
        let trm' = Subst.norm subst trm in
        if interpreted trm' then (
          match solve_interp_eq us trm' (cls, subst) with
          | Some subst -> solve_interp_eqs_ cls' (cls, subst)
          | None -> solve_interp_eqs_ (trm' :: cls') (cls, subst) )
        else solve_interp_eqs_ (trm' :: cls') (cls, subst)
  in
  let cls', subst' = solve_interp_eqs_ [] (cls, subst) in
  ( if subst' != subst then solve_interp_eqs us (cls', subst')
  else (cls', subst') )
  |>
  [%Trace.retn fun {pf} (cls', subst') ->
    pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff
      (subst, subst')]
(** Accumulator used by [solve_uninterp_eqs] to split a class into terms
    that are ito [us] (all free variables in [us]) and terms that are not,
    tracking the least term of each kind as candidate representative. *)
type cls_solve_state =
{ rep_us: Term.t option (** rep, that is ito us, for class *)
; cls_us: Term.t list (** cls that is ito us, or interpreted *)
; rep_xs: Term.t option (** rep, that is *not* ito us, for class *)
; cls_xs: Term.t list (** cls that is *not* ito us *) }
(** the term that may serve as the key of a solution entry for [e]: the
    contents [v] of a memory [⟨n,v⟩] whose contents are a variable, [e]
    itself when it is non-interpreted, and [None] otherwise *)
let dom_trm e =
  match (e : Term.t) with
  | Ap2 (Memory, _, (Var _ as v)) -> Some v
  | _ -> if non_interpreted e then Some e else None
(** move equations from [cls] (which is assumed to be normalized by [subst])
    to [subst] which can be expressed as [x ↦ u] where [x] is
    non-interpreted [us ∪ xs ⊇ fv x ⊈ us] and [fv u ⊆ us] or else
    [fv u ⊆ us ∪ xs] *)
let solve_uninterp_eqs us (cls, subst) =
  [%Trace.call fun {pf} ->
    pf "cls: @[%a@]@ subst: @[%a@]" pp_cls cls Subst.pp subst]
  ;
  (* order terms by kind first, then structurally *)
  let compare e f =
    [%compare: kind * Term.t] (classify e, e) (classify f, f)
  in
  let {rep_us; cls_us; rep_xs; cls_xs} =
    List.fold cls ~init:{rep_us= None; cls_us= []; rep_xs= None; cls_xs= []}
      ~f:(fun ({rep_us; cls_us; rep_xs; cls_xs} as s) trm ->
        if Var.Set.is_subset (Term.fv trm) ~of_:us then (
          match rep_us with
          | Some rep when compare rep trm <= 0 ->
              {s with cls_us= trm :: cls_us}
          | Some rep -> {s with rep_us= Some trm; cls_us= rep :: cls_us}
          | None -> {s with rep_us= Some trm} )
        else
          match rep_xs with
          | Some rep -> (
              if compare rep trm <= 0 then (
                match dom_trm trm with
                | Some trm -> {s with cls_xs= trm :: cls_xs}
                | None -> {s with cls_us= trm :: cls_us} )
              else
                match dom_trm rep with
                | Some rep ->
                    {s with rep_xs= Some trm; cls_xs= rep :: cls_xs}
                | None -> {s with rep_xs= Some trm; cls_us= rep :: cls_us} )
          | None -> {s with rep_xs= Some trm} )
  in
  ( match rep_us with
  | Some rep_us ->
      (* solve every term not ito us for the least term ito us *)
      let cls = rep_us :: cls_us in
      let cls, cls_xs =
        match rep_xs with
        | Some rep -> (
            match dom_trm rep with
            | Some rep -> (cls, rep :: cls_xs)
            | None -> (rep :: cls, cls_xs) )
        | None -> (cls, cls_xs)
      in
      let subst =
        List.fold cls_xs ~init:subst ~f:(fun subst trm_xs ->
            Subst.compose1 ~key:trm_xs ~data:rep_us subst )
      in
      (cls, subst)
  | None -> (
      (* no term ito us: solve for the least term not ito us *)
      match rep_xs with
      | Some rep_xs ->
          let cls = rep_xs :: cls_us in
          let subst =
            List.fold cls_xs ~init:subst ~f:(fun subst trm_xs ->
                Subst.compose1 ~key:trm_xs ~data:rep_xs subst )
          in
          (cls, subst)
      | None -> (cls, subst) ) )
  |>
  [%Trace.retn fun {pf} (cls', subst') ->
    pf "cls: @[%a@]@ subst: @[%a@]" pp_diff_cls (cls, cls') Subst.pp_diff
      (subst, subst') ;
    subst_invariant us subst subst']
(** move equations between terms in [rep]'s class [cls] from [classes] to
    [subst] which can be expressed, after normalizing with [subst], as
    [x ↦ u] where [us ∪ xs ⊇ fv x ⊈ us] and [fv u ⊆ us] or else
    [fv u ⊆ us ∪ xs] *)
let solve_class us us_xs ~key:rep ~data:cls (classes, subst) =
  (* keep the inputs for the trace: both names are shadowed below *)
  let classes0 = classes in
  let subst0 = subst in
  [%Trace.call fun {pf} ->
    pf "rep: @[%a@]@ cls: @[%a@]@ subst: @[%a@]" Term.pp rep pp_cls cls
      Subst.pp subst]
  ;
  let cls, cls_not_ito_us_xs =
    List.partition_tf
      ~f:(fun e -> Var.Set.is_subset (Term.fv e) ~of_:us_xs)
      (rep :: cls)
  in
  let cls, subst = solve_interp_eqs us (cls, subst) in
  let cls, subst = solve_uninterp_eqs us (cls, subst) in
  let cls = List.rev_append cls_not_ito_us_xs cls in
  let cls =
    List.remove ~equal:Term.equal cls (Subst.norm subst rep)
    |> Option.value ~default:cls
  in
  let classes =
    if List.is_empty cls then Term.Map.remove classes rep
    else Term.Map.set classes ~key:rep ~data:cls
  in
  (classes, subst)
  |>
  [%Trace.retn fun {pf} (classes', subst') ->
    pf "subst: @[%a@]@ classes: @[%a@]" Subst.pp_diff (subst0, subst')
      pp_diff_clss (classes0, classes')]
(** list, each in reverse order, the sequences of extracts from memory
    terms using [x] that [r] shows to be contiguous. Each sequence starts
    at offset [0], and its last extract [⟨n,_⟩[o,l)] satisfies [n = o+l]
    per [r]. *)
let solve_concat_extracts_eq r x =
  [%Trace.call fun {pf} -> pf "%a@ %a" Term.pp x pp r]
  ;
  (* extracts of memory terms that use [x] *)
  let uses =
    fold_uses_of r x ~init:[] ~f:(fun uses -> function
      | Ap2 (Memory, _, _) as m ->
          fold_uses_of r m ~init:uses ~f:(fun uses -> function
            | Ap3 (Extract, _, _, _) as e -> e :: uses | _ -> uses )
      | _ -> uses )
  in
  let find_extracts_at_off off =
    List.filter uses ~f:(fun use ->
        match (use : Term.t) with
        | Ap3 (Extract, _, o, _) -> entails_eq r o off
        | _ -> false )
  in
  let rec find_extracts full_rev_extracts rev_prefix off =
    List.fold (find_extracts_at_off off) ~init:full_rev_extracts
      ~f:(fun full_rev_extracts e ->
        match e with
        | Ap3 (Extract, Ap2 (Memory, n, _), o, l) ->
            let o_l = Term.add o l in
            if entails_eq r n o_l then
              (e :: rev_prefix) :: full_rev_extracts
            else find_extracts full_rev_extracts (e :: rev_prefix) o_l
        | _ -> full_rev_extracts )
  in
  find_extracts [] [] Term.zero
  |>
  [%Trace.retn fun {pf} ->
    pf "@[[%a]@]" (List.pp ";@ " (List.pp ",@ " Term.pp))]
(** try to solve [x] as a concatenation. Each extract sequence found by
    [solve_concat_extracts_eq] is rewritten, replacing every extract [e]
    with a memory [⟨|e|,u⟩] where [u] is the least term of [e]'s class that
    is ito [us]. Sequences with an extract lacking such a term are dropped.
    If any remain, [x ↦] the least resulting concatenation is added to
    [subst]. *)
let solve_concat_extracts r us x (classes, subst, us_xs) =
  match
    List.filter_map (solve_concat_extracts_eq r x) ~f:(fun rev_extracts ->
        List.fold_option rev_extracts ~init:[] ~f:(fun suffix e ->
            let+ rep_ito_us =
              List.fold (cls_of r e) ~init:None ~f:(fun rep_ito_us trm ->
                  match rep_ito_us with
                  | Some rep when Term.compare rep trm <= 0 -> rep_ito_us
                  | _ when Var.Set.is_subset (Term.fv trm) ~of_:us ->
                      Some trm
                  | _ -> rep_ito_us )
            in
            Term.memory ~siz:(Term.agg_size_exn e) ~arr:rep_ito_us :: suffix
        ) )
    |> List.min_elt ~compare:[%compare: Term.t list]
  with
  | Some extracts ->
      let concat = Term.concat (Array.of_list extracts) in
      let subst = Subst.compose1 ~key:x ~data:concat subst in
      (classes, subst, us_xs)
  | None -> (classes, subst, us_xs)
(** for each variable of [xs] not yet in the domain of the solution, try to
    solve it as a concatenation of extracts *)
let solve_for_xs r us xs (classes, subst, us_xs) =
  let solve_one ((_, subst, _) as state) x =
    let x = Term.var x in
    if Subst.mem subst x then state
    else solve_concat_extracts r us x state
  in
  Var.Set.fold xs ~init:(classes, subst, us_xs) ~f:solve_one
(** move equations from [classes] to [subst] which can be expressed, after
    normalizing with [subst], as [x ↦ u] where [us ∪ xs ⊇ fv x ⊈ us]
    and [fv u ⊆ us] or else [fv u ⊆ us ∪ xs]. Classes are solved
    repeatedly until [subst] stops changing. Then the variables of [xs]
    are solved with [solve_for_xs]. The returned variable set is
    [us ∪ xs]. *)
let solve_classes r (classes, subst, us) xs =
  [%Trace.call fun {pf} ->
    pf "us: {@[%a@]}@ xs: {@[%a@]}" Var.Set.pp us Var.Set.pp xs]
  ;
  let rec solve_classes_ (classes0, subst0, us_xs) =
    let classes, subst =
      Term.Map.fold ~f:(solve_class us us_xs) classes0
        ~init:(classes0, subst0)
    in
    if subst != subst0 then solve_classes_ (classes, subst, us_xs)
    else (classes, subst, us_xs)
  in
  (classes, subst, Var.Set.union us xs)
  |> solve_classes_ |> solve_for_xs r us xs
  |>
  [%Trace.retn fun {pf} (classes', subst', _) ->
    pf "subst: @[%a@]@ classes: @[%a@]" Subst.pp_diff (subst, subst')
      pp_diff_clss (classes, classes')]
(** print a list of variable sets as [\[{…}; {…}\]] *)
let pp_vss fs vss =
  Format.fprintf fs "[@[%a@]]"
    (List.pp ";@ " (fun fs vs -> Format.fprintf fs "{@[%a@]}" Var.Set.pp vs))
    vss
(** enumerate variable contexts vᵢ in [v₁;…] and accumulate a solution
    subst with entries [x ↦ u] where [r] entails [x = u] and
    [⋃ⱼ₌₁ⁱ vⱼ ⊇ fv x ⊈ ⋃ⱼ₌₁ⁱ⁻¹ vⱼ] and
    [fv u ⊆ ⋃ⱼ₌₁ⁱ⁻¹ vⱼ] if possible and otherwise
    [fv u ⊆ ⋃ⱼ₌₁ⁱ vⱼ] *)
let solve_for_vars vss r =
  [%Trace.call fun {pf} ->
    pf "%a@ @[%a@]@ @[%a@]" pp_vss vss pp_classes r pp r ;
    invariant r]
  ;
  (* the first context is the initial set of "known" variables *)
  let us, vss =
    match vss with us :: vss -> (us, vss) | [] -> (Var.Set.empty, vss)
  in
  List.fold ~f:(solve_classes r) ~init:(classes r, Subst.empty, us) vss
  |> snd3
  |>
  [%Trace.retn fun {pf} subst ->
    pf "%a" Subst.pp subst ;
    Subst.iteri subst ~f:(fun ~key ~data ->
        assert (
          entails_eq r key data
          || fail "@[%a@ = %a@ not entailed by@ @[%a@]@]" Term.pp key
               Term.pp data pp_classes r () ) ;
        assert (
          List.fold_until vss ~init:us
            ~f:(fun us xs ->
              let us_xs = Var.Set.union us xs in
              let ks = Term.fv key in
              let ds = Term.fv data in
              if
                Var.Set.is_subset ks ~of_:us_xs
                && Var.Set.is_subset ds ~of_:us_xs
                && ( Var.Set.is_subset ds ~of_:us
                   || not (Var.Set.is_subset ks ~of_:us) )
              then Stop true
              else Continue us_xs )
            ~finish:(fun _ -> false) ) )]
(** drop from [r] the carrier entries of the variables [xs] *)
let elim xs r =
  let rep = Subst.remove xs r.rep in
  {r with rep}
(* Replay debugging *)
(** A recorded call to one of the exported operations together with its
    arguments. Calls can be converted to and from sexps so that [wrap] can
    dump them and [replay] can re-execute them. *)
type call =
| Normalize of t * Term.t
| And_eq of Var.Set.t * Term.t * Term.t * t
| And_term of Var.Set.t * Term.t * t
| And_ of Var.Set.t * t * t
| Or_ of Var.Set.t * t * t
| OrN of Var.Set.t * t list
| Rename of t * Var.Subst.t
| Apply_subst of Var.Set.t * Subst.t * t
| Solve_for_vars of Var.Set.t list * t
[@@deriving sexp]
(** parse [c] as the sexp of a [call] and re-execute that call, discarding
    its result *)
let replay c =
  let call = call_of_sexp (Sexp.of_string c) in
  match call with
  | Normalize (r, e) -> ignore (normalize r e)
  | And_eq (us, a, b, r) -> ignore (and_eq us a b r)
  | And_term (us, e, r) -> ignore (and_term us e r)
  | And_ (us, r, s) -> ignore (and_ us r s)
  | Or_ (us, r, s) -> ignore (or_ us r s)
  | OrN (us, rs) -> ignore (orN us rs)
  | Rename (r, s) -> ignore (rename r s)
  | Apply_subst (us, s, r) -> ignore (apply_subst us s r)
  | Solve_for_vars (vss, r) -> ignore (solve_for_vars vss r)
(* Debug wrappers *)
(** print one line of timing statistics for the timer [name] to stderr *)
let report ~name ~elapsed ~aggregate ~count =
  Format.fprintf Format.err_formatter
    "%15s time: %12.3f ms %12.3f ms %12d calls@." name elapsed aggregate
    count

(** elapsed time, in the units [report] prints, above which [wrap] dumps
    the current call; doubled after each dump *)
let dump_threshold = ref 1000.
(** [wrap tmr f call] runs [f ()] under the timer [tmr] and reports the
    timing. When the elapsed time exceeds [!dump_threshold], the call
    described by [call ()] is printed as a sexp and the threshold is
    doubled. In debug builds, an exception raised by [f] is re-raised
    together with the sexp of the call. *)
let wrap tmr f call =
  let f () =
    Timer.start tmr ;
    let r = f () in
    Timer.stop_report tmr (fun ~name ~elapsed ~aggregate ~count ->
        report ~name ~elapsed ~aggregate ~count ;
        if Float.(elapsed > !dump_threshold) then (
          dump_threshold := 2. *. !dump_threshold ;
          Format.eprintf "@\n%a@\n@." Sexp.pp_hum (sexp_of_call (call ())) ) ) ;
    r
  in
  if not [%debug] then f ()
  else try f () with exn -> raise_s ([%sexp_of: exn * call] (exn, call ()))
(* One timer per exported operation, each reporting with [report] at exit.
   Creation order is preserved. *)
let normalize_tmr = Timer.create ~at_exit:report "normalize"
let and_eq_tmr = Timer.create ~at_exit:report "and_eq"
let and_term_tmr = Timer.create ~at_exit:report "and_term"
let and_tmr = Timer.create ~at_exit:report "and_"
let or_tmr = Timer.create ~at_exit:report "or_"
let orN_tmr = Timer.create ~at_exit:report "orN"
let rename_tmr = Timer.create ~at_exit:report "rename"
let apply_subst_tmr = Timer.create ~at_exit:report "apply_subst"
let solve_for_vars_tmr = Timer.create ~at_exit:report "solve_for_vars"

(* Each wrapper shadows the operation of the same name, running it under its
   timer and describing the call for dumping and replay. *)

let normalize r e =
  let run () = normalize r e in
  let describe () = Normalize (r, e) in
  wrap normalize_tmr run describe

let and_eq us a b r =
  let run () = and_eq us a b r in
  let describe () = And_eq (us, a, b, r) in
  wrap and_eq_tmr run describe

let and_term us e r =
  let run () = and_term us e r in
  let describe () = And_term (us, e, r) in
  wrap and_term_tmr run describe

let and_ us r s =
  let run () = and_ us r s in
  let describe () = And_ (us, r, s) in
  wrap and_tmr run describe

let or_ us r s =
  let run () = or_ us r s in
  let describe () = Or_ (us, r, s) in
  wrap or_tmr run describe

let orN us rs =
  let run () = orN us rs in
  let describe () = OrN (us, rs) in
  wrap orN_tmr run describe

let rename r s =
  let run () = rename r s in
  let describe () = Rename (r, s) in
  wrap rename_tmr run describe

let apply_subst us s r =
  let run () = apply_subst us s r in
  let describe () = Apply_subst (us, s, r) in
  wrap apply_subst_tmr run describe

let solve_for_vars vss r =
  let run () = solve_for_vars vss r in
  let describe () = Solve_for_vars (vss, r) in
  wrap solve_for_vars_tmr run describe