[sledge] Rework first-order solver to use a list of pending equations

Summary:
Currently when solving an equation requires solving one or more other
equations, they are handled recursively. This diff reworks the solver
to instead queue pending equations in a list rather than eagerly
solving them eagerly. The pending list will be needed when changing
the Context.t representation since some equations will need to be
delayed until representation invariants are re-established.

Additionally, the solution is accumulated as an association list
rather than being eagerly incorporated into the context
representation. This enables the uses in equality solving and
quantifier elimination to use different representations.

Reviewed By: jvillard

Differential Revision: D25883727

fbshipit-source-id: 2f6b5efa3
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 90974b92e4
commit caae7515e8

@ -9,7 +9,7 @@
open Exp open Exp
(** Classification of Terms by Theory *) (* Classification of Terms ================================================*)
type kind = Interpreted | Atomic | Uninterpreted type kind = Interpreted | Atomic | Uninterpreted
[@@deriving compare, equal] [@@deriving compare, equal]
@ -34,7 +34,8 @@ let rec max_solvables e =
let fold_max_solvables e s ~f = Iter.fold ~f (max_solvables e) s let fold_max_solvables e s ~f = Iter.fold ~f (max_solvables e) s
(** Solution Substitutions *) (* Solution Substitutions =================================================*)
module Subst : sig module Subst : sig
type t [@@deriving compare, equal, sexp] type t [@@deriving compare, equal, sexp]
@ -58,7 +59,6 @@ module Subst : sig
val extend : Trm.t -> t -> t option val extend : Trm.t -> t -> t option
val map_entries : f:(Trm.t -> Trm.t) -> t -> t val map_entries : f:(Trm.t -> Trm.t) -> t -> t
val to_iter : t -> (Trm.t * Trm.t) iter val to_iter : t -> (Trm.t * Trm.t) iter
val fv : t -> Var.Set.t
val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t val partition_valid : Var.Set.t -> t -> t * Var.Set.t * t
(* direct representation manipulation *) (* direct representation manipulation *)
@ -84,14 +84,6 @@ end = struct
let for_alli = Trm.Map.for_alli let for_alli = Trm.Map.for_alli
let to_iter = Trm.Map.to_iter let to_iter = Trm.Map.to_iter
let vars s =
s
|> to_iter
|> Iter.flat_map ~f:(fun (k, v) ->
Iter.append (Trm.vars k) (Trm.vars v) )
let fv s = Var.Set.of_iter (vars s)
(** look up a term in a substitution *) (** look up a term in a substitution *)
let apply s a = Trm.Map.find a s |> Option.value ~default:a let apply s a = Trm.Map.find a s |> Option.value ~default:a
@ -126,7 +118,7 @@ end = struct
| _ when Trm.equal key data -> r | _ when Trm.equal key data -> r
| _ -> | _ ->
assert ( assert (
(not (Trm.Map.mem key r)) Option.for_all ~f:(Trm.equal key) (Trm.Map.find key r)
|| fail "domains intersect: %a" Trm.pp key () ) ; || fail "domains intersect: %a" Trm.pp key () ) ;
let s = Trm.Map.singleton key data in let s = Trm.Map.singleton key data in
let r' = Trm.Map.map_endo ~f:(norm s) r in let r' = Trm.Map.map_endo ~f:(norm s) r in
@ -208,7 +200,7 @@ end = struct
let remove = Trm.Map.remove let remove = Trm.Map.remove
end end
(** Theory Solver *) (* Theory Solver ==========================================================*)
(** prefer representative terms that are minimal in the order s.t. Var < (** prefer representative terms that are minimal in the order s.t. Var <
Sized < Extract < Concat < others, then using height of sequence Sized < Extract < Concat < others, then using height of sequence
@ -235,155 +227,163 @@ let orient e f =
| Zero -> None | Zero -> None
| Pos -> Some (f, e) | Pos -> Some (f, e)
let norm (_, _, s) e = Subst.norm s e type solve_state =
{ wrt: Var.Set.t
let compose1 ~var ~rep (us, xs, s) = ; no_fresh: bool
let s = Subst.compose1 ~key:var ~data:rep s in ; fresh: Var.Set.t
Some (us, xs, s) ; solved: (Trm.t * Trm.t) list option
; pending: (Trm.t * Trm.t) list }
let fresh name (wrt, xs, s) =
let x, wrt = Var.fresh name ~wrt in let pp_solve_state ppf = function
let xs = Var.Set.add x xs in | {solved= None} -> Format.fprintf ppf "unsat"
(Trm.var x, (wrt, xs, s)) | {solved= Some solved; fresh; pending} ->
Format.fprintf ppf "%a%a : %a" Var.Set.pp_xs fresh
(List.pp ";@ " (fun ppf (var, rep) ->
Format.fprintf ppf "@[%a ↦ %a@]" Trm.pp var Trm.pp rep ))
solved
(List.pp ";@ " (fun ppf (a, b) ->
Format.fprintf ppf "@[%a = %a@]" Trm.pp a Trm.pp b ))
pending
let add_solved ~var ~rep s =
match s with
| {solved= None} -> s
| {solved= Some solved} -> {s with solved= Some ((var, rep) :: solved)}
let add_pending a b s = {s with pending= (a, b) :: s.pending}
let fresh name s =
if s.no_fresh then None
else
let x, wrt = Var.fresh name ~wrt:s.wrt in
let fresh = Var.Set.add x s.fresh in
Some (Trm.var x, {s with wrt; fresh})
let solve_poly p q s = let solve_poly p q s =
[%trace] [%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp p Trm.pp q) ~call:(fun {pf} -> pf "@ %a = %a" Trm.pp p Trm.pp q)
~retn:(fun {pf} -> function ~retn:(fun {pf} -> pf "%a" pp_solve_state)
| Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s
| None -> pf "false" )
@@ fun () -> @@ fun () ->
match Trm.sub p q with match Trm.sub p q with
| Z z -> if Z.equal Z.zero z then Some s else None | Z z -> if Z.equal Z.zero z then s else {s with solved= None}
| Var _ as var -> compose1 ~var ~rep:Trm.zero s | Var _ as var -> add_solved ~var ~rep:Trm.zero s
| p_q -> ( | p_q -> (
match Trm.Arith.solve_zero_eq p_q with match Trm.Arith.solve_zero_eq p_q with
| Some (var, rep) -> | Some (var, rep) ->
compose1 ~var:(Trm.arith var) ~rep:(Trm.arith rep) s add_solved ~var:(Trm.arith var) ~rep:(Trm.arith rep) s
| None -> compose1 ~var:p_q ~rep:Trm.zero s ) | None -> add_solved ~var:p_q ~rep:Trm.zero s )
(* α[o,l) = β ==> l = |β| ∧ α = (⟨n,c⟩[0,o) ^ β ^ ⟨n,c⟩[o+l,n-o-l)) where n (* α[o,l) = β ==> l = |β| ∧ α = (⟨n,c⟩[0,o) ^ β ^ ⟨n,c⟩[o+l,n-o-l)) where n
= |α| and c fresh *) = |α| and c fresh *)
let rec solve_extract ?(no_fresh = false) a o l b s = let solve_extract a o l b s =
[%trace] [%trace]
~call:(fun {pf} -> ~call:(fun {pf} ->
pf "@ %a = %a@ %a%a" Trm.pp pf "@ %a = %a" Trm.pp (Trm.extract ~seq:a ~off:o ~len:l) Trm.pp b )
(Trm.extract ~seq:a ~off:o ~len:l) ~retn:(fun {pf} -> pf "%a" pp_solve_state)
Trm.pp b Var.Set.pp_xs (snd3 s) Subst.pp (trd3 s) )
~retn:(fun {pf} -> function
| Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s
| None -> pf "false" )
@@ fun () -> @@ fun () ->
if no_fresh then Some s match fresh "c" s with
else | None -> s
let n = Trm.seq_size_exn a in | Some (c, s) ->
let c, s = fresh "c" s in let n = Trm.seq_size_exn a in
let n_c = Trm.sized ~siz:n ~seq:c in let n_c = Trm.sized ~siz:n ~seq:c in
let o_l = Trm.add o l in let o_l = Trm.add o l in
let n_o_l = Trm.sub n o_l in let n_o_l = Trm.sub n o_l in
let c0 = Trm.extract ~seq:n_c ~off:Trm.zero ~len:o in let c0 = Trm.extract ~seq:n_c ~off:Trm.zero ~len:o in
let c1 = Trm.extract ~seq:n_c ~off:o_l ~len:n_o_l in let c1 = Trm.extract ~seq:n_c ~off:o_l ~len:n_o_l in
let b, s = let b, s =
match Trm.seq_size b with match Trm.seq_size b with
| None -> (Trm.sized ~siz:l ~seq:b, Some s) | None -> (Trm.sized ~siz:l ~seq:b, s)
| Some m -> (b, solve_ l m s) | Some m -> (b, add_pending l m s)
in in
s >>= solve_ a (Trm.concat [|c0; b; c1|]) add_pending a (Trm.concat [|c0; b; c1|]) s
(* α₀^…^αᵢ^αⱼ^…^αᵥ = β ==> |α₀^…^αᵥ| = |β| ∧ … ∧ αⱼ = β[n₀+…+nᵢ,nⱼ) ∧ … (* α₀^…^αᵢ^αⱼ^…^αᵥ = β ==> |α₀^…^αᵥ| = |β| ∧ … ∧ αⱼ = β[n₀+…+nᵢ,nⱼ) ∧ …
where n |α| and m = |β| *) where n |α| and m = |β| *)
and solve_concat ?no_fresh a0V b m s = let solve_concat a0V b m s =
[%trace] [%trace]
~call:(fun {pf} -> ~call:(fun {pf} -> pf "@ %a = %a" Trm.pp (Trm.concat a0V) Trm.pp b)
pf "@ %a = %a@ %a%a" Trm.pp (Trm.concat a0V) Trm.pp b Var.Set.pp_xs ~retn:(fun {pf} -> pf "%a" pp_solve_state)
(snd3 s) Subst.pp (trd3 s) )
~retn:(fun {pf} -> function
| Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s
| None -> pf "false" )
@@ fun () -> @@ fun () ->
Iter.fold_until (Array.to_iter a0V) (s, Trm.zero) let s, n0V =
~f:(fun aJ (s, oI) -> Iter.fold (Array.to_iter a0V) (s, Trm.zero) ~f:(fun aJ (s, oI) ->
let nJ = Trm.seq_size_exn aJ in let nJ = Trm.seq_size_exn aJ in
let oJ = Trm.add oI nJ in let oJ = Trm.add oI nJ in
match solve_ ?no_fresh aJ (Trm.extract ~seq:b ~off:oI ~len:nJ) s with let s = add_pending aJ (Trm.extract ~seq:b ~off:oI ~len:nJ) s in
| Some s -> `Continue (s, oJ) (s, oJ) )
| None -> `Stop None ) in
~finish:(fun (s, n0V) -> solve_ ?no_fresh n0V m s) add_pending n0V m s
and solve_ ?no_fresh d e s = let solve_ d e s =
[%Trace.call fun {pf} -> [%trace]
pf "@ %a@[%a@ %a@ %a@]" Var.Set.pp_xs (snd3 s) Trm.pp d Trm.pp e ~call:(fun {pf} -> pf "@ %a = %a" Trm.pp d Trm.pp e)
Subst.pp (trd3 s)] ~retn:(fun {pf} -> pf "%a" pp_solve_state)
; @@ fun () ->
( match orient (norm s d) (norm s e) with match orient d e with
(* e' = f' ==> true when e' ≡ f' *) (* e' = f' ==> true when e' ≡ f' *)
| None -> Some s | None -> s
(* i = j ==> false when i ≠ j *) (* i = j ==> false when i ≠ j *)
| Some (Z _, Z _) | Some (Q _, Q _) -> None | Some (Z _, Z _) | Some (Q _, Q _) -> {s with solved= None}
(* (*
* Concat * Concat
*) *)
(* ⟨0,a⟩ = β ==> a = β = ⟨⟩ *) (* ⟨0,a⟩ = β ==> a = β = ⟨⟩ *)
| Some (Sized {siz= n; seq= a}, b) when n == Trm.zero -> | Some (Sized {siz= n; seq= a}, b) when n == Trm.zero ->
s s
|> solve_ ?no_fresh a (Trm.concat [||]) |> add_pending a (Trm.concat [||])
>>= solve_ ?no_fresh b (Trm.concat [||]) |> add_pending b (Trm.concat [||])
| Some (b, Sized {siz= n; seq= a}) when n == Trm.zero -> | Some (b, Sized {siz= n; seq= a}) when n == Trm.zero ->
s s
|> solve_ ?no_fresh a (Trm.concat [||]) |> add_pending a (Trm.concat [||])
>>= solve_ ?no_fresh b (Trm.concat [||]) |> add_pending b (Trm.concat [||])
(* ⟨n,0⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,0⟩[n₀+…+nᵢ,nⱼ) ∧ … *) (* ⟨n,0⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,0⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Sized {siz= n; seq} as b), Concat a0V) when seq == Trm.zero -> | Some ((Sized {siz= n; seq} as b), Concat a0V) when seq == Trm.zero ->
solve_concat ?no_fresh a0V b n s solve_concat a0V b n s
(* ⟨n,e^⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,e^⟩[n₀+…+nᵢ,nⱼ) ∧ … *) (* ⟨n,e^⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,e^⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) -> | Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) ->
solve_concat ?no_fresh a0V b n s solve_concat a0V b n s
| Some ((Var _ as v), (Concat a0V as c)) -> | Some ((Var _ as v), (Concat a0V as c)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then
(* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *) (* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *)
compose1 ~var:v ~rep:c s add_solved ~var:v ~rep:c s
else else
(* v = α₀^…^αᵥ ==> ⟨|α₀^…^αᵥ|,v⟩ = α₀^…^αᵥ when v ∈ fv(α₀^…^αᵥ) *) (* v = α₀^…^αᵥ ==> ⟨|α₀^…^αᵥ|,v⟩ = α₀^…^αᵥ when v ∈ fv(α₀^…^αᵥ) *)
let m = Trm.seq_size_exn c in let m = Trm.seq_size_exn c in
solve_concat ?no_fresh a0V (Trm.sized ~siz:m ~seq:v) m s solve_concat a0V (Trm.sized ~siz:m ~seq:v) m s
(* α₀^…^αᵥ = β₀^…^βᵤ ==> … ∧ αⱼ = (β₀^…^βᵤ)[n₀+…+nᵢ,nⱼ) ∧ … *) (* α₀^…^αᵥ = β₀^…^βᵤ ==> … ∧ αⱼ = (β₀^…^βᵤ)[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some (Concat a0V, (Concat _ as c)) -> | Some (Concat a0V, (Concat _ as c)) ->
solve_concat ?no_fresh a0V c (Trm.seq_size_exn c) s solve_concat a0V c (Trm.seq_size_exn c) s
(* α[o,l) = α₀^…^αᵥ ==> … ∧ αⱼ = α[o,l)[n₀+…+nᵢ,nⱼ) ∧ … *) (* α[o,l) = α₀^…^αᵥ ==> … ∧ αⱼ = α[o,l)[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Extract {len= l} as e), Concat a0V) -> | Some ((Extract {len= l} as e), Concat a0V) -> solve_concat a0V e l s
solve_concat ?no_fresh a0V e l s
(* (*
* Extract * Extract
*) *)
| Some ((Var _ as v), (Extract {len= l} as e)) -> | Some ((Var _ as v), (Extract {len= l} as e)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then
(* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *) (* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *)
compose1 ~var:v ~rep:e s add_solved ~var:v ~rep:e s
else else
(* v = α[o,l) ==> α[o,l) ↦ ⟨l,v⟩ when v ∈ fv(α[o,l)) *) (* v = α[o,l) ==> α[o,l) ↦ ⟨l,v⟩ when v ∈ fv(α[o,l)) *)
compose1 ~var:e ~rep:(Trm.sized ~siz:l ~seq:v) s add_solved ~var:e ~rep:(Trm.sized ~siz:l ~seq:v) s
(* α[o,l) = β ==> … ∧ α = _^β^_ *) (* α[o,l) = β ==> … ∧ α = _^β^_ *)
| Some (Extract {seq= a; off= o; len= l}, e) -> | Some (Extract {seq= a; off= o; len= l}, e) -> solve_extract a o l e s
solve_extract ?no_fresh a o l e s
(* (*
* Sized * Sized
*) *)
(* v = ⟨n,a⟩ ==> v = a *) (* v = ⟨n,a⟩ ==> v = a *)
| Some ((Var _ as v), Sized {seq= a}) -> s |> solve_ ?no_fresh v a | Some ((Var _ as v), Sized {seq= a}) -> s |> add_pending v a
(* ⟨n,a⟩ = ⟨m,b⟩ ==> n = m ∧ a = β *) (* ⟨n,a⟩ = ⟨m,b⟩ ==> n = m ∧ a = β *)
| Some (Sized {siz= n; seq= a}, Sized {siz= m; seq= b}) -> | Some (Sized {siz= n; seq= a}, Sized {siz= m; seq= b}) ->
s |> solve_ ?no_fresh n m >>= solve_ ?no_fresh a b s |> add_pending n m |> add_pending a b
(* ⟨n,a⟩ = β ==> n = |β| ∧ a = β *) (* ⟨n,a⟩ = β ==> n = |β| ∧ a = β *)
| Some (Sized {siz= n; seq= a}, b) -> | Some (Sized {siz= n; seq= a}, b) ->
( match Trm.seq_size b with s
| None -> Some s |> Option.fold ~f:(add_pending n) (Trm.seq_size b)
| Some m -> solve_ ?no_fresh n m s ) |> add_pending a b
>>= solve_ ?no_fresh a b
(* (*
* Splat * Splat
*) *)
(* a^ = b^ ==> a = b *) (* a^ = b^ ==> a = b *)
| Some (Splat a, Splat b) -> s |> solve_ ?no_fresh a b | Some (Splat a, Splat b) -> s |> add_pending a b
(* (*
* Arithmetic * Arithmetic
*) *)
@ -398,27 +398,16 @@ and solve_ ?no_fresh d e s =
| Some (rep, var) -> | Some (rep, var) ->
assert (not (is_interpreted var)) ; assert (not (is_interpreted var)) ;
assert (not (is_interpreted rep)) ; assert (not (is_interpreted rep)) ;
compose1 ~var ~rep s ) add_solved ~var ~rep s
|>
[%Trace.retn fun {pf} ->
function
| Some (_, xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s
| None -> pf "false"]
let solve ~wrt ~xs d e = let solve ~wrt ~xs d e pending =
[%Trace.call fun {pf} -> pf "@ %a@ %a" Trm.pp d Trm.pp e] [%trace]
; ~call:(fun {pf} -> pf "@ %a@ %a" Trm.pp d Trm.pp e)
( solve_ d e (wrt, xs, Subst.empty) ~retn:(fun {pf} -> pf "%a" pp_solve_state)
|>= fun (_, xs, s) -> @@ fun () ->
let xs = Var.Set.inter xs (Subst.fv s) in solve_ d e {wrt; no_fresh= false; fresh= xs; solved= Some []; pending}
(xs, s) )
|>
[%Trace.retn fun {pf} ->
function
| Some (xs, s) -> pf "%a%a" Var.Set.pp_xs xs Subst.pp s
| None -> pf "false"]
(** Equality classes *) (* Equality classes =======================================================*)
module Cls : sig module Cls : sig
type t [@@deriving equal] type t [@@deriving equal]
@ -460,7 +449,7 @@ end = struct
let pp_diff = List.pp_diff ~cmp:Trm.compare "@ = " Trm.pp let pp_diff = List.pp_diff ~cmp:Trm.compare "@ = " Trm.pp
end end
(** Equality Relations *) (* Conjunctions of atomic formula assumptions =============================*)
(** see also [invariant] *) (** see also [invariant] *)
type t = type t =
@ -470,7 +459,10 @@ type t =
; rep: Subst.t ; rep: Subst.t
(** functional set of oriented equations: map [a] to [a'], (** functional set of oriented equations: map [a] to [a'],
indicating that [a = a'] holds, and that [a'] is the indicating that [a = a'] holds, and that [a'] is the
'rep(resentative)' of [a] *) } 'rep(resentative)' of [a] *)
; pnd: (Trm.t * Trm.t) list
(** pending equations to add (once invariants are reestablished) *)
}
[@@deriving compare, equal, sexp] [@@deriving compare, equal, sexp]
let classes r = let classes r =
@ -484,9 +476,12 @@ let cls_of r e =
let e' = Subst.apply r.rep e in let e' = Subst.apply r.rep e in
Trm.Map.find e' (classes r) |> Option.value ~default:(Cls.of_ e') Trm.Map.find e' (classes r) |> Option.value ~default:(Cls.of_ e')
(** Pretty-printing *) (* Pretty-printing ========================================================*)
let pp_eq fs (e, f) = Format.fprintf fs "@[%a = %a@]" Trm.pp e Trm.pp f
let pp_pnd = List.pp ";@ " pp_eq
let pp_raw fs {sat; rep} = let pp_raw fs {sat; rep; pnd} =
let pp_alist pp_k pp_v fs alist = let pp_alist pp_k pp_v fs alist =
let pp_assoc fs (k, v) = let pp_assoc fs (k, v) =
Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_k k pp_v (k, v) Format.fprintf fs "[@[%a@ @<2>↦ %a@]]" pp_k k pp_v (k, v)
@ -494,9 +489,14 @@ let pp_raw fs {sat; rep} =
Format.fprintf fs "[@[<hv>%a@]]" (List.pp ";@ " pp_assoc) alist Format.fprintf fs "[@[<hv>%a@]]" (List.pp ";@ " pp_assoc) alist
in in
let pp_term_v fs (k, v) = if not (Trm.equal k v) then Trm.pp fs v in let pp_term_v fs (k, v) = if not (Trm.equal k v) then Trm.pp fs v in
Format.fprintf fs "@[{@[<hv>sat= %b;@ rep= %a@]}@]" sat let pp_pnd fs pnd =
if not (List.is_empty pnd) then
Format.fprintf fs ";@ pnd= @[%a@]" pp_pnd pnd
in
Format.fprintf fs "@[{@[<hv>sat= %b;@ rep= %a%a@]}@]" sat
(pp_alist Trm.pp pp_term_v) (pp_alist Trm.pp pp_term_v)
(Iter.to_list (Subst.to_iter rep)) (Iter.to_list (Subst.to_iter rep))
pp_pnd pnd
let pp_diff fs (r, s) = let pp_diff fs (r, s) =
let pp_sat fs = let pp_sat fs =
@ -505,9 +505,14 @@ let pp_diff fs (r, s) =
in in
let pp_rep fs = let pp_rep fs =
if not (Subst.is_empty r.rep) then if not (Subst.is_empty r.rep) then
Format.fprintf fs "rep= %a" Subst.pp_diff (r.rep, s.rep) Format.fprintf fs "rep= %a;@ " Subst.pp_diff (r.rep, s.rep)
in in
Format.fprintf fs "@[{@[<hv>%t%t@]}@]" pp_sat pp_rep let pp_pnd fs =
Format.fprintf fs "pnd= @[%a@]"
(List.pp_diff ~cmp:[%compare: Trm.t * Trm.t] ";@ " pp_eq)
(r.pnd, s.pnd)
in
Format.fprintf fs "@[{@[<hv>%t%t%t@]}@]" pp_sat pp_rep pp_pnd
let ppx_classes x fs clss = let ppx_classes x fs clss =
List.pp "@ @<2>∧ " List.pp "@ @<2>∧ "
@ -548,7 +553,7 @@ let ppx var_strength fs clss noneqs =
"@ @<2>∧ " (Fml.ppx var_strength) fs noneqs ~suf:"@]" ; "@ @<2>∧ " (Fml.ppx var_strength) fs noneqs ~suf:"@]" ;
first && List.is_empty noneqs first && List.is_empty noneqs
(** Basic queries *) (* Basic queries ==========================================================*)
(** test membership in carrier *) (** test membership in carrier *)
let in_car r e = Subst.mem e r.rep let in_car r e = Subst.mem e r.rep
@ -560,7 +565,7 @@ let semi_congruent r a' b = Trm.equal a' (Trm.map ~f:(Subst.norm r.rep) b)
(** terms are congruent if equal after normalizing subterms *) (** terms are congruent if equal after normalizing subterms *)
let congruent r a b = semi_congruent r (Trm.map ~f:(Subst.norm r.rep) a) b let congruent r a b = semi_congruent r (Trm.map ~f:(Subst.norm r.rep) a) b
(** Invariant *) (* Invariant ==============================================================*)
let pre_invariant r = let pre_invariant r =
let@ () = Invariant.invariant [%here] r [%sexp_of: t] in let@ () = Invariant.invariant [%here] r [%sexp_of: t] in
@ -586,6 +591,7 @@ let pre_invariant r =
let invariant r = let invariant r =
let@ () = Invariant.invariant [%here] r [%sexp_of: t] in let@ () = Invariant.invariant [%here] r [%sexp_of: t] in
pre_invariant r ; pre_invariant r ;
assert (List.is_empty r.pnd) ;
assert ( assert (
(not r.sat) (not r.sat)
|| Subst.for_alli r.rep ~f:(fun ~key:a ~data:a' -> || Subst.for_alli r.rep ~f:(fun ~key:a ~data:a' ->
@ -596,13 +602,44 @@ let invariant r =
|| fail "not congruent %a@ %a@ in@ %a" Trm.pp a Trm.pp b pp r || fail "not congruent %a@ %a@ in@ %a" Trm.pp a Trm.pp b pp r
() ) ) ) () ) ) )
(** Core operations *) (* Representation helpers =================================================*)
let add_to_pnd a a' x =
if Trm.equal a a' then x else {x with pnd= (a, a') :: x.pnd}
(* Propagation ============================================================*)
let propagate1 (trm, rep) x =
[%trace]
~call:(fun {pf} ->
pf "@ @[%a ↦ %a@]@ %a" Trm.pp trm Trm.pp rep pp_raw x )
~retn:(fun {pf} -> pf "%a" pp_raw)
@@ fun () ->
let rep = Subst.compose1 ~key:trm ~data:rep x.rep in
{x with rep}
let rec propagate ~wrt x =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp_raw x)
~retn:(fun {pf} -> pf "%a" pp_raw)
@@ fun () ->
match x.pnd with
| (a, b) :: pnd -> (
let a' = Subst.norm x.rep a in
let b' = Subst.norm x.rep b in
match solve ~wrt ~xs:x.xs a' b' pnd with
| {solved= Some solved; wrt; fresh; pending} ->
let xs = Var.Set.union x.xs fresh in
let x = {x with xs; pnd= pending} in
propagate ~wrt (List.fold ~f:propagate1 solved x)
| {solved= None} -> {x with sat= false; pnd= []} )
| [] -> x
(* Core operations ========================================================*)
let empty = let empty =
let rep = Subst.empty in let rep = Subst.empty in
(* let rep = Option.get_exn (Subst.extend Trm.true_ rep) in {xs= Var.Set.empty; sat= true; rep; pnd= []} |> check invariant
* let rep = Option.get_exn (Subst.extend Trm.false_ rep) in *)
{xs= Var.Set.empty; sat= true; rep} |> check invariant
let unsat = {empty with sat= false} let unsat = {empty with sat= false}
@ -656,17 +693,13 @@ let extend a r =
let rep = extend_ a r.rep in let rep = extend_ a r.rep in
if rep == r.rep then r else {r with rep} |> check pre_invariant if rep == r.rep then r else {r with rep} |> check pre_invariant
let merge ~wrt a b r = let merge ~wrt a b x =
[%Trace.call fun {pf} -> pf "@ %a@ %a@ %a" Trm.pp a Trm.pp b pp r] [%trace]
; ~call:(fun {pf} -> pf "@ %a@ %a@ %a" Trm.pp a Trm.pp b pp x)
( match solve ~wrt ~xs:r.xs a b with ~retn:(fun {pf} x' ->
| Some (xs, s) -> pf "%a" pp_diff (x, x') ;
{r with xs= Var.Set.union r.xs xs; rep= Subst.compose r.rep s} pre_invariant x' )
| None -> {r with sat= false} ) @@ fun () -> propagate ~wrt (add_to_pnd a b x)
|>
[%Trace.retn fun {pf} r' ->
pf "%a" pp_diff (r, r') ;
pre_invariant r']
(** find an unproved equation between congruent terms *) (** find an unproved equation between congruent terms *)
let find_missing r = let find_missing r =
@ -684,12 +717,12 @@ let find_missing r =
in in
Option.return_if need_a'_eq_b' (a', b') ) ) Option.return_if need_a'_eq_b' (a', b') ) )
let rec close ~wrt r = let rec close ~wrt x =
if not r.sat then r if not x.sat then x
else else
match find_missing r with match find_missing x with
| Some (a', b') -> close ~wrt (merge ~wrt a' b' r) | Some (a', b') -> close ~wrt (merge ~wrt a' b' x)
| None -> r | None -> x
let close ~wrt r = let close ~wrt r =
[%Trace.call fun {pf} -> pf "@ %a" pp r] [%Trace.call fun {pf} -> pf "@ %a" pp r]
@ -718,7 +751,7 @@ let and_eq_ ~wrt a b r =
let extract_xs r = (r.xs, {r with xs= Var.Set.empty}) let extract_xs r = (r.xs, {r with xs= Var.Set.empty})
(** Exposed interface *) (* Exposed interface ======================================================*)
let is_empty {sat; rep} = let is_empty {sat; rep} =
sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Trm.equal a a') sat && Subst.for_alli rep ~f:(fun ~key:a ~data:a' -> Trm.equal a a')
@ -880,7 +913,7 @@ let trms r =
let vars r = Iter.flat_map ~f:Trm.vars (trms r) let vars r = Iter.flat_map ~f:Trm.vars (trms r)
let fv r = Var.Set.of_iter (vars r) let fv r = Var.Set.of_iter (vars r)
(** Existential Witnessing and Elimination *) (* Existential Witnessing and Elimination =================================*)
let subst_invariant us s0 s = let subst_invariant us s0 s =
assert (s0 == s || not (Subst.equal s0 s)) ; assert (s0 == s || not (Subst.equal s0 s)) ;
@ -922,6 +955,19 @@ let solve_poly_eq us p' q' subst =
[%Trace.retn fun {pf} subst' -> [%Trace.retn fun {pf} subst' ->
pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)] pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)]
let rec solve_pending s soln =
match s.pending with
| (a, b) :: pending -> (
let a' = Subst.norm soln a in
let b' = Subst.norm soln b in
match solve_ a' b' {s with pending} with
| {solved= Some solved} as s ->
solve_pending {s with solved= Some []}
(List.fold solved soln ~f:(fun (trm, rep) soln ->
Subst.compose1 ~key:trm ~data:rep soln ))
| {solved= None} -> None )
| [] -> Some soln
let solve_seq_eq ~wrt us e' f' subst = let solve_seq_eq ~wrt us e' f' subst =
[%Trace.call fun {pf} -> pf "@ %a = %a" Trm.pp e' Trm.pp f'] [%Trace.call fun {pf} -> pf "@ %a = %a" Trm.pp e' Trm.pp f']
; ;
@ -935,11 +981,14 @@ let solve_seq_eq ~wrt us e' f' subst =
| Some n -> (a, n) | Some n -> (a, n)
| None -> (Trm.sized ~siz:n ~seq:a, n) | None -> (Trm.sized ~siz:n ~seq:a, n)
in in
let+ _, xs, s = solve_pending
solve_concat ~no_fresh:true ms a n (wrt, Var.Set.empty, subst) (solve_concat ms a n
in { wrt
assert (Var.Set.is_empty xs) ; ; no_fresh= true
s ; fresh= Var.Set.empty
; solved= Some []
; pending= [] })
subst
in in
( match ((e' : Trm.t), (f' : Trm.t)) with ( match ((e' : Trm.t), (f' : Trm.t)) with
| (Concat ms as c), a when x_ito_us c a -> | (Concat ms as c), a when x_ito_us c a ->
@ -1309,9 +1358,7 @@ let apply_and_elim ~wrt xs s r =
let r = trim ks r in let r = trim ks r in
(zs, r, ks) (zs, r, ks)
(* (* Replay debugging =======================================================*)
* Replay debugging
*)
type call = type call =
| Add of Var.Set.t * Formula.t * t | Add of Var.Set.t * Formula.t * t

@ -91,23 +91,24 @@ let%test_module _ =
let%expect_test _ = let%expect_test _ =
replay replay
{|(Solve_for_vars {|(Solve_for_vars
(((Var (id 0) (name 2)) (Var (id 0) (name 6)) (Var (id 0) (name 8))) (((Var (id 0) (name 2)) (Var (id 0) (name 6)) (Var (id 0) (name 8)))
((Var (id 5) (name a0)) (Var (id 6) (name b)) (Var (id 7) (name m)) ((Var (id 5) (name a0)) (Var (id 6) (name b)) (Var (id 7) (name m))
(Var (id 8) (name a)) (Var (id 9) (name a0)))) (Var (id 8) (name a)) (Var (id 9) (name a0))))
((xs ()) (sat true) ((xs ()) (sat true)
(rep (rep
(((Var (id 9) (name a0)) (Var (id 5) (name a0))) (((Var (id 9) (name a0)) (Var (id 5) (name a0)))
((Var (id 8) (name a)) ((Var (id 8) (name a))
(Concat (Concat
((Sized (seq (Var (id 5) (name a0))) (siz (Z 4))) ((Sized (seq (Var (id 5) (name a0))) (siz (Z 4)))
(Sized (seq (Z 0)) (siz (Z 4)))))) (Sized (seq (Z 0)) (siz (Z 4))))))
((Var (id 7) (name m)) (Z 8)) ((Var (id 7) (name m)) (Z 8))
((Var (id 6) (name b)) (Var (id 0) (name 2))) ((Var (id 6) (name b)) (Var (id 0) (name 2)))
((Var (id 5) (name a0)) (Var (id 5) (name a0))) ((Var (id 5) (name a0)) (Var (id 5) (name a0)))
((Var (id 0) (name 6)) ((Var (id 0) (name 6))
(Concat (Concat
((Sized (seq (Var (id 5) (name a0))) (siz (Z 4))) ((Sized (seq (Var (id 5) (name a0))) (siz (Z 4)))
(Sized (seq (Z 0)) (siz (Z 4)))))) (Sized (seq (Z 0)) (siz (Z 4))))))
((Var (id 0) (name 2)) (Var (id 0) (name 2)))))))|} ; ((Var (id 0) (name 2)) (Var (id 0) (name 2)))))
(pnd ())))|} ;
[%expect {| |}] [%expect {| |}]
end ) end )

@ -78,11 +78,11 @@ let%test_module _ =
pp (star p q) ; pp (star p q) ;
[%expect [%expect
{| {|
%x_7 . emp %x_7 . emp
0 = %x_7 emp 0 = %x_7 emp
0 = %x_7 emp |}] 0 = %x_7 emp |}]
let%expect_test _ = let%expect_test _ =
let q = let q =

Loading…
Cancel
Save