[sledge] Refactor theory solver into separate module

Summary:
The solver for single equations no longer needs access to the internal
representation of the context.

Reviewed By: jvillard

Differential Revision: D25883722

fbshipit-source-id: 4ccd97674
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 757e44ca50
commit aeba96a3c7

@ -9,31 +9,6 @@
open Exp
(* Classification of Terms ================================================*)
type kind = Interpreted | Atomic | Uninterpreted
[@@deriving compare, equal]
let classify e =
match (e : Trm.t) with
| Var _ | Z _ | Q _ | Concat [||] | Apply (_, [||]) -> Atomic
| Arith a -> (
match Trm.Arith.classify a with
| Trm _ | Const _ -> violates Trm.invariant e
| Interpreted -> Interpreted
| Uninterpreted -> Uninterpreted )
| Splat _ | Sized _ | Extract _ | Concat _ -> Interpreted
| Apply _ -> Uninterpreted
let is_interpreted e = equal_kind (classify e) Interpreted
let is_uninterpreted e = equal_kind (classify e) Uninterpreted
let rec max_solvables e =
if not (is_interpreted e) then Iter.return e
else Iter.flat_map ~f:max_solvables (Trm.trms e)
let fold_max_solvables e s ~f = Iter.fold ~f (max_solvables e) s
(* Solution Substitutions =================================================*)
module Subst : sig
@ -96,7 +71,7 @@ end = struct
~call:(fun {pf} -> pf "@ %a" Trm.pp a)
~retn:(fun {pf} -> pf "%a" Trm.pp)
@@ fun () ->
if is_interpreted a then Trm.map ~f:(norm s) a else apply s a
if Theory.is_interpreted a then Trm.map ~f:(norm s) a else apply s a
(** compose two substitutions *)
let compose r s =
@ -159,8 +134,10 @@ end = struct
in
( is_var_in xs e
|| is_var_in xs f
|| (is_uninterpreted e && Iter.exists ~f:(is_var_in xs) (Trm.trms e))
|| (is_uninterpreted f && Iter.exists ~f:(is_var_in xs) (Trm.trms f)) )
|| Theory.is_uninterpreted e
&& Iter.exists ~f:(is_var_in xs) (Trm.trms e)
|| Theory.is_uninterpreted f
&& Iter.exists ~f:(is_var_in xs) (Trm.trms f) )
$> fun b ->
[%Trace.info
"is_valid_eq %a%a=%a = %b" Var.Set.pp_xs xs Trm.pp e Trm.pp f b]
@ -200,213 +177,6 @@ end = struct
let remove = Trm.Map.remove
end
(* Theory Solver ==========================================================*)
(** prefer representative terms that are minimal in the order s.t. Var <
Sized < Extract < Concat < others, then using height of sequence
nesting, and then using Trm.compare *)
let prefer e f =
let rank e =
match (e : Trm.t) with
| Var _ -> 0
| Sized _ -> 1
| Extract _ -> 2
| Concat _ -> 3
| _ -> 4
in
let o = compare (rank e) (rank f) in
if o <> 0 then o
else
let o = compare (Trm.height e) (Trm.height f) in
if o <> 0 then o else Trm.compare e f
(** orient equations based on representative preference *)
let orient e f =
match Sign.of_int (prefer e f) with
| Neg -> Some (e, f)
| Zero -> None
| Pos -> Some (f, e)
type solve_state =
{ wrt: Var.Set.t
; no_fresh: bool
; fresh: Var.Set.t
; solved: (Trm.t * Trm.t) list option
; pending: (Trm.t * Trm.t) list }
let pp_solve_state ppf = function
| {solved= None} -> Format.fprintf ppf "unsat"
| {solved= Some solved; fresh; pending} ->
Format.fprintf ppf "%a%a : %a" Var.Set.pp_xs fresh
(List.pp ";@ " (fun ppf (var, rep) ->
Format.fprintf ppf "@[%a ↦ %a@]" Trm.pp var Trm.pp rep ))
solved
(List.pp ";@ " (fun ppf (a, b) ->
Format.fprintf ppf "@[%a = %a@]" Trm.pp a Trm.pp b ))
pending
let add_solved ~var ~rep s =
match s with
| {solved= None} -> s
| {solved= Some solved} -> {s with solved= Some ((var, rep) :: solved)}
let add_pending a b s = {s with pending= (a, b) :: s.pending}
let fresh name s =
if s.no_fresh then None
else
let x, wrt = Var.fresh name ~wrt:s.wrt in
let fresh = Var.Set.add x s.fresh in
Some (Trm.var x, {s with wrt; fresh})
let solve_poly p q s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp p Trm.pp q)
~retn:(fun {pf} -> pf "%a" pp_solve_state)
@@ fun () ->
match Trm.sub p q with
| Z z -> if Z.equal Z.zero z then s else {s with solved= None}
| Var _ as var -> add_solved ~var ~rep:Trm.zero s
| p_q -> (
match Trm.Arith.solve_zero_eq p_q with
| Some (var, rep) ->
add_solved ~var:(Trm.arith var) ~rep:(Trm.arith rep) s
| None -> add_solved ~var:p_q ~rep:Trm.zero s )
(* α[o,l) = β ==> l = |β| ∧ α = (⟨n,c⟩[0,o) ^ β ^ ⟨n,c⟩[o+l,n-o-l)) where n
= |α| and c fresh *)
let solve_extract a o l b s =
[%trace]
~call:(fun {pf} ->
pf "@ %a = %a" Trm.pp (Trm.extract ~seq:a ~off:o ~len:l) Trm.pp b )
~retn:(fun {pf} -> pf "%a" pp_solve_state)
@@ fun () ->
match fresh "c" s with
| None -> s
| Some (c, s) ->
let n = Trm.seq_size_exn a in
let n_c = Trm.sized ~siz:n ~seq:c in
let o_l = Trm.add o l in
let n_o_l = Trm.sub n o_l in
let c0 = Trm.extract ~seq:n_c ~off:Trm.zero ~len:o in
let c1 = Trm.extract ~seq:n_c ~off:o_l ~len:n_o_l in
let b, s =
match Trm.seq_size b with
| None -> (Trm.sized ~siz:l ~seq:b, s)
| Some m -> (b, add_pending l m s)
in
add_pending a (Trm.concat [|c0; b; c1|]) s
(* α₀^…^αᵢ^αⱼ^…^αᵥ = β ==> |α₀^…^αᵥ| = |β| ∧ … ∧ αⱼ = β[n₀+…+nᵢ,nⱼ) ∧ …
where n |α| and m = |β| *)
let solve_concat a0V b m s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp (Trm.concat a0V) Trm.pp b)
~retn:(fun {pf} -> pf "%a" pp_solve_state)
@@ fun () ->
let s, n0V =
Iter.fold (Array.to_iter a0V) (s, Trm.zero) ~f:(fun aJ (s, oI) ->
let nJ = Trm.seq_size_exn aJ in
let oJ = Trm.add oI nJ in
let s = add_pending aJ (Trm.extract ~seq:b ~off:oI ~len:nJ) s in
(s, oJ) )
in
add_pending n0V m s
let solve_ d e s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp d Trm.pp e)
~retn:(fun {pf} -> pf "%a" pp_solve_state)
@@ fun () ->
match orient d e with
(* e' = f' ==> true when e' ≡ f' *)
| None -> s
(* i = j ==> false when i ≠ j *)
| Some (Z _, Z _) | Some (Q _, Q _) -> {s with solved= None}
(*
* Concat
*)
(* ⟨0,a⟩ = β ==> a = β = ⟨⟩ *)
| Some (Sized {siz= n; seq= a}, b) when n == Trm.zero ->
s
|> add_pending a (Trm.concat [||])
|> add_pending b (Trm.concat [||])
| Some (b, Sized {siz= n; seq= a}) when n == Trm.zero ->
s
|> add_pending a (Trm.concat [||])
|> add_pending b (Trm.concat [||])
(* ⟨n,0⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,0⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Sized {siz= n; seq} as b), Concat a0V) when seq == Trm.zero ->
solve_concat a0V b n s
(* ⟨n,e^⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,e^⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) ->
solve_concat a0V b n s
| Some ((Var _ as v), (Concat a0V as c)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then
(* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *)
add_solved ~var:v ~rep:c s
else
(* v = α₀^…^αᵥ ==> ⟨|α₀^…^αᵥ|,v⟩ = α₀^…^αᵥ when v ∈ fv(α₀^…^αᵥ) *)
let m = Trm.seq_size_exn c in
solve_concat a0V (Trm.sized ~siz:m ~seq:v) m s
(* α₀^…^αᵥ = β₀^…^βᵤ ==> … ∧ αⱼ = (β₀^…^βᵤ)[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some (Concat a0V, (Concat _ as c)) ->
solve_concat a0V c (Trm.seq_size_exn c) s
(* α[o,l) = α₀^…^αᵥ ==> … ∧ αⱼ = α[o,l)[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Extract {len= l} as e), Concat a0V) -> solve_concat a0V e l s
(*
* Extract
*)
| Some ((Var _ as v), (Extract {len= l} as e)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then
(* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *)
add_solved ~var:v ~rep:e s
else
(* v = α[o,l) ==> α[o,l) ↦ ⟨l,v⟩ when v ∈ fv(α[o,l)) *)
add_solved ~var:e ~rep:(Trm.sized ~siz:l ~seq:v) s
(* α[o,l) = β ==> … ∧ α = _^β^_ *)
| Some (Extract {seq= a; off= o; len= l}, e) -> solve_extract a o l e s
(*
* Sized
*)
(* v = ⟨n,a⟩ ==> v = a *)
| Some ((Var _ as v), Sized {seq= a}) -> s |> add_pending v a
(* ⟨n,a⟩ = ⟨m,b⟩ ==> n = m ∧ a = β *)
| Some (Sized {siz= n; seq= a}, Sized {siz= m; seq= b}) ->
s |> add_pending n m |> add_pending a b
(* ⟨n,a⟩ = β ==> n = |β| ∧ a = β *)
| Some (Sized {siz= n; seq= a}, b) ->
s
|> Option.fold ~f:(add_pending n) (Trm.seq_size b)
|> add_pending a b
(*
* Splat
*)
(* a^ = b^ ==> a = b *)
| Some (Splat a, Splat b) -> s |> add_pending a b
(*
* Arithmetic
*)
(* p = q ==> p-q = 0 *)
| Some (((Arith _ | Z _ | Q _) as p), q | q, ((Arith _ | Z _ | Q _) as p))
->
solve_poly p q s
(*
* Uninterpreted
*)
(* r = v ==> v ↦ r *)
| Some (rep, var) ->
assert (not (is_interpreted var)) ;
assert (not (is_interpreted rep)) ;
add_solved ~var ~rep s
let solve ~wrt ~xs d e pending =
[%trace]
~call:(fun {pf} -> pf "@ %a@ %a" Trm.pp d Trm.pp e)
~retn:(fun {pf} -> pf "%a" pp_solve_state)
@@ fun () ->
solve_ d e {wrt; no_fresh= false; fresh= xs; solved= Some []; pending}
(* Equality classes =======================================================*)
module Cls : sig
@ -572,11 +342,12 @@ let pre_invariant r =
Subst.iteri r.rep ~f:(fun ~key:trm ~data:rep ->
(* no interpreted terms in carrier *)
assert (
(not (is_interpreted trm)) || fail "non-interp %a" Trm.pp trm () ) ;
(not (Theory.is_interpreted trm))
|| fail "non-interp %a" Trm.pp trm () ) ;
(* carrier is closed under subterms *)
Iter.iter (Trm.trms trm) ~f:(fun subtrm ->
assert (
is_interpreted subtrm
Theory.is_interpreted subtrm
|| (match subtrm with Z _ | Q _ -> true | _ -> false)
|| in_car r subtrm
|| fail "@[subterm %a@ of %a@ not in carrier of@ %a@]" Trm.pp
@ -618,6 +389,14 @@ let propagate1 (trm, rep) x =
let rep = Subst.compose1 ~key:trm ~data:rep x.rep in
{x with rep}
let solve ~wrt ~xs d e pending =
[%trace]
~call:(fun {pf} -> pf "@ %a@ %a" Trm.pp d Trm.pp e)
~retn:(fun {pf} -> pf "%a" Theory.pp)
@@ fun () ->
Theory.solve d e
{wrt; no_fresh= false; fresh= xs; solved= Some []; pending}
let rec propagate ~wrt x =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp_raw x)
@ -658,12 +437,12 @@ let lookup r a =
let rec canon r a =
[%Trace.call fun {pf} -> pf "@ %a" Trm.pp a]
;
( match classify a with
( match Theory.classify a with
| Atomic -> Subst.apply r.rep a
| Interpreted -> Trm.map ~f:(canon r) a
| Uninterpreted -> (
let a' = Trm.map ~f:(canon r) a in
match classify a' with
match Theory.classify a' with
| Atomic -> Subst.apply r.rep a'
| Interpreted -> a'
| Uninterpreted -> lookup r a' ) )
@ -680,7 +459,7 @@ let rec extend_ a r =
match (a : Trm.t) with
| Z _ | Q _ -> r
| _ -> (
if is_interpreted a then Iter.fold ~f:extend_ (Trm.trms a) r
if Theory.is_interpreted a then Iter.fold ~f:extend_ (Trm.trms a) r
else
(* add uninterpreted terms *)
match Subst.extend a r with
@ -793,7 +572,8 @@ let fold_uses_of r t s ~f =
Iter.fold (Trm.trms e) s ~f:(fun sub s ->
fold_ ~f sub (if Trm.equal t sub then f e s else s) )
in
if is_interpreted e then Iter.fold ~f:(fold_ ~f) (Trm.trms e) s else s
if Theory.is_interpreted e then Iter.fold ~f:(fold_ ~f) (Trm.trms e) s
else s
in
Subst.fold r.rep s ~f:(fun ~key:trm ~data:rep s ->
fold_ ~f trm (fold_ ~f rep s) )
@ -915,6 +695,12 @@ let fv r = Var.Set.of_iter (vars r)
(* Existential Witnessing and Elimination =================================*)
let rec max_solvables e =
if not (Theory.is_interpreted e) then Iter.return e
else Iter.flat_map ~f:max_solvables (Trm.trms e)
let fold_max_solvables e s ~f = Iter.fold ~f (max_solvables e) s
let subst_invariant us s0 s =
assert (s0 == s || not (Subst.equal s0 s)) ;
assert (
@ -955,12 +741,12 @@ let solve_poly_eq us p' q' subst =
[%Trace.retn fun {pf} subst' ->
pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)]
let rec solve_pending s soln =
let rec solve_pending (s : Theory.t) soln =
match s.pending with
| (a, b) :: pending -> (
let a' = Subst.norm soln a in
let b' = Subst.norm soln b in
match solve_ a' b' {s with pending} with
match Theory.solve a' b' {s with pending} with
| {solved= Some solved} as s ->
solve_pending {s with solved= Some []}
(List.fold solved soln ~f:(fun (trm, rep) soln ->
@ -982,7 +768,7 @@ let solve_seq_eq us e' f' subst =
| None -> (Trm.sized ~siz:n ~seq:a, n)
in
solve_pending
(solve_concat ms a n
(Theory.solve_concat ms a n
{ wrt= Var.Set.empty
; no_fresh= true
; fresh= Var.Set.empty
@ -1032,7 +818,7 @@ let rec solve_interp_eqs us (cls, subst) =
| None -> (cls', subst)
| Some (trm, cls) ->
let trm' = Subst.norm subst trm in
if is_interpreted trm' then
if Theory.is_interpreted trm' then
match solve_interp_eq us trm' (cls, subst) with
| Some subst -> solve_interp_eqs_ cls' (cls, subst)
| None -> solve_interp_eqs_ (Cls.add trm' cls') (cls, subst)
@ -1055,7 +841,7 @@ type cls_solve_state =
let dom_trm e =
match (e : Trm.t) with
| Sized {seq= Var _ as v} -> Some v
| _ when not (is_interpreted e) -> Some e
| _ when not (Theory.is_interpreted e) -> Some e
| _ -> None
(** move equations from [cls] (which is assumed to be normalized by [subst])
@ -1067,7 +853,9 @@ let solve_uninterp_eqs us (cls, subst) =
pf "@ cls: @[%a@]@ subst: @[%a@]" Cls.pp cls Subst.pp subst]
;
let compare e f =
[%compare: kind * Trm.t] (classify e, e) (classify f, f)
[%compare: Theory.kind * Trm.t]
(Theory.classify e, e)
(Theory.classify f, f)
in
let {rep_us; cls_us; rep_xs; cls_xs} =
Cls.fold cls
@ -1329,7 +1117,7 @@ let trim ks r =
let keep = Trm.Set.diff cls drop in
match
Trm.Set.reduce keep ~f:(fun x y ->
if prefer x y < 0 then x else y )
if Theory.prefer x y < 0 then x else y )
with
| Some rep' ->
(* add mappings from each keeper to the new representative *)

@ -11,7 +11,9 @@
Functions that return contexts that might be stronger than their
argument contexts accept and return a set of variables. The input set is
the variables with which any generated variables must be chosen fresh,
and the output set is the variables that have been generated. *)
and the output set is the variables that have been generated. If the
empty set is given, then no fresh variables are generated and equations
that cannot be solved without generating fresh variables are dropped. *)
open Exp

@ -0,0 +1,229 @@
(*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*)
(** Theory Solver *)
(* Theory equation solver state ===========================================*)
type t =
{ wrt: Var.Set.t
; no_fresh: bool
; fresh: Var.Set.t
; solved: (Trm.t * Trm.t) list option
; pending: (Trm.t * Trm.t) list }
let pp ppf = function
| {solved= None} -> Format.fprintf ppf "unsat"
| {solved= Some solved; fresh; pending} ->
Format.fprintf ppf "%a%a : %a" Var.Set.pp_xs fresh
(List.pp ";@ " (fun ppf (var, rep) ->
Format.fprintf ppf "@[%a ↦ %a@]" Trm.pp var Trm.pp rep ))
solved
(List.pp ";@ " (fun ppf (a, b) ->
Format.fprintf ppf "@[%a = %a@]" Trm.pp a Trm.pp b ))
pending
(* Classification of terms ================================================*)
type kind = Interpreted | Atomic | Uninterpreted
[@@deriving compare, equal]
let classify e =
match (e : Trm.t) with
| Var _ | Z _ | Q _ | Concat [||] | Apply (_, [||]) -> Atomic
| Arith a -> (
match Trm.Arith.classify a with
| Trm _ | Const _ -> violates Trm.invariant e
| Interpreted -> Interpreted
| Uninterpreted -> Uninterpreted )
| Splat _ | Sized _ | Extract _ | Concat _ -> Interpreted
| Apply _ -> Uninterpreted
let is_interpreted e = equal_kind (classify e) Interpreted
let is_uninterpreted e = equal_kind (classify e) Uninterpreted
(* Solving equations ======================================================*)
(** prefer representative terms that are minimal in the order s.t. Var <
Sized < Extract < Concat < others, then using height of sequence
nesting, and then using Trm.compare *)
let prefer e f =
let rank e =
match (e : Trm.t) with
| Var _ -> 0
| Sized _ -> 1
| Extract _ -> 2
| Concat _ -> 3
| _ -> 4
in
let o = compare (rank e) (rank f) in
if o <> 0 then o
else
let o = compare (Trm.height e) (Trm.height f) in
if o <> 0 then o else Trm.compare e f
(** orient equations based on representative preference *)
let orient e f =
match Sign.of_int (prefer e f) with
| Neg -> Some (e, f)
| Zero -> None
| Pos -> Some (f, e)
let add_solved ~var ~rep s =
match s with
| {solved= None} -> s
| {solved= Some solved} -> {s with solved= Some ((var, rep) :: solved)}
let add_pending a b s = {s with pending= (a, b) :: s.pending}
let fresh name s =
if s.no_fresh then None
else
let x, wrt = Var.fresh name ~wrt:s.wrt in
let fresh = Var.Set.add x s.fresh in
Some (Trm.var x, {s with wrt; fresh})
let solve_poly p q s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp p Trm.pp q)
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
match Trm.sub p q with
| Z z -> if Z.equal Z.zero z then s else {s with solved= None}
| Var _ as var -> add_solved ~var ~rep:Trm.zero s
| p_q -> (
match Trm.Arith.solve_zero_eq p_q with
| Some (var, rep) ->
add_solved ~var:(Trm.arith var) ~rep:(Trm.arith rep) s
| None -> add_solved ~var:p_q ~rep:Trm.zero s )
(* α[o,l) = β ==> l = |β| ∧ α = (⟨n,c⟩[0,o) ^ β ^ ⟨n,c⟩[o+l,n-o-l)) where n
= |α| and c fresh *)
let solve_extract a o l b s =
[%trace]
~call:(fun {pf} ->
pf "@ %a = %a" Trm.pp (Trm.extract ~seq:a ~off:o ~len:l) Trm.pp b )
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
match fresh "c" s with
| None -> s
| Some (c, s) ->
let n = Trm.seq_size_exn a in
let n_c = Trm.sized ~siz:n ~seq:c in
let o_l = Trm.add o l in
let n_o_l = Trm.sub n o_l in
let c0 = Trm.extract ~seq:n_c ~off:Trm.zero ~len:o in
let c1 = Trm.extract ~seq:n_c ~off:o_l ~len:n_o_l in
let b, s =
match Trm.seq_size b with
| None -> (Trm.sized ~siz:l ~seq:b, s)
| Some m -> (b, add_pending l m s)
in
add_pending a (Trm.concat [|c0; b; c1|]) s
(* α₀^…^αᵢ^αⱼ^…^αᵥ = β ==> |α₀^…^αᵥ| = |β| ∧ … ∧ αⱼ = β[n₀+…+nᵢ,nⱼ) ∧ …
where n |α| and m = |β| *)
let solve_concat a0V b m s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp (Trm.concat a0V) Trm.pp b)
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
let s, n0V =
Iter.fold (Array.to_iter a0V) (s, Trm.zero) ~f:(fun aJ (s, oI) ->
let nJ = Trm.seq_size_exn aJ in
let oJ = Trm.add oI nJ in
let s = add_pending aJ (Trm.extract ~seq:b ~off:oI ~len:nJ) s in
(s, oJ) )
in
add_pending n0V m s
let solve d e s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp d Trm.pp e)
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
match orient d e with
(* e' = f' ==> true when e' ≡ f' *)
| None -> s
(* i = j ==> false when i ≠ j *)
| Some (Z _, Z _) | Some (Q _, Q _) -> {s with solved= None}
(*
* Concat
*)
(* ⟨0,a⟩ = β ==> a = β = ⟨⟩ *)
| Some (Sized {siz= n; seq= a}, b) when n == Trm.zero ->
s
|> add_pending a (Trm.concat [||])
|> add_pending b (Trm.concat [||])
| Some (b, Sized {siz= n; seq= a}) when n == Trm.zero ->
s
|> add_pending a (Trm.concat [||])
|> add_pending b (Trm.concat [||])
(* ⟨n,0⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,0⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Sized {siz= n; seq} as b), Concat a0V) when seq == Trm.zero ->
solve_concat a0V b n s
(* ⟨n,e^⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,e^⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) ->
solve_concat a0V b n s
| Some ((Var _ as v), (Concat a0V as c)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then
(* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *)
add_solved ~var:v ~rep:c s
else
(* v = α₀^…^αᵥ ==> ⟨|α₀^…^αᵥ|,v⟩ = α₀^…^αᵥ when v ∈ fv(α₀^…^αᵥ) *)
let m = Trm.seq_size_exn c in
solve_concat a0V (Trm.sized ~siz:m ~seq:v) m s
(* α₀^…^αᵥ = β₀^…^βᵤ ==> … ∧ αⱼ = (β₀^…^βᵤ)[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some (Concat a0V, (Concat _ as c)) ->
solve_concat a0V c (Trm.seq_size_exn c) s
(* α[o,l) = α₀^…^αᵥ ==> … ∧ αⱼ = α[o,l)[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Extract {len= l} as e), Concat a0V) -> solve_concat a0V e l s
(*
* Extract
*)
| Some ((Var _ as v), (Extract {len= l} as e)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then
(* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *)
add_solved ~var:v ~rep:e s
else
(* v = α[o,l) ==> α[o,l) ↦ ⟨l,v⟩ when v ∈ fv(α[o,l)) *)
add_solved ~var:e ~rep:(Trm.sized ~siz:l ~seq:v) s
(* α[o,l) = β ==> … ∧ α = _^β^_ *)
| Some (Extract {seq= a; off= o; len= l}, e) -> solve_extract a o l e s
(*
* Sized
*)
(* v = ⟨n,a⟩ ==> v = a *)
| Some ((Var _ as v), Sized {seq= a}) -> s |> add_pending v a
(* ⟨n,a⟩ = ⟨m,b⟩ ==> n = m ∧ a = β *)
| Some (Sized {siz= n; seq= a}, Sized {siz= m; seq= b}) ->
s |> add_pending n m |> add_pending a b
(* ⟨n,a⟩ = β ==> n = |β| ∧ a = β *)
| Some (Sized {siz= n; seq= a}, b) ->
s
|> Option.fold ~f:(add_pending n) (Trm.seq_size b)
|> add_pending a b
(*
* Splat
*)
(* a^ = b^ ==> a = b *)
| Some (Splat a, Splat b) -> s |> add_pending a b
(*
* Arithmetic
*)
(* p = q ==> p-q = 0 *)
| Some (((Arith _ | Z _ | Q _) as p), q | q, ((Arith _ | Z _ | Q _) as p))
->
solve_poly p q s
(*
* Uninterpreted
*)
(* r = v ==> v ↦ r *)
| Some (rep, var) ->
assert (not (is_interpreted var)) ;
assert (not (is_interpreted rep)) ;
add_solved ~var ~rep s

@ -0,0 +1,27 @@
(*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*)
(** Theory Solver *)
type t =
{ wrt: Var.Set.t
; no_fresh: bool
; fresh: Var.Set.t
; solved: (Trm.t * Trm.t) list option
; pending: (Trm.t * Trm.t) list }
val pp : t pp
type kind = Interpreted | Atomic | Uninterpreted
[@@deriving compare, equal]
val classify : Trm.t -> kind
val is_interpreted : Trm.t -> bool
val is_uninterpreted : Trm.t -> bool
val prefer : Trm.t -> Trm.t -> int
val solve_concat : Trm.t array -> Trm.t -> Trm.t -> t -> t
val solve : Trm.t -> Trm.t -> t -> t
Loading…
Cancel
Save