[sledge] Refactor theory solver into separate module

Summary:
The solver for single equations no longer needs access to the internal
representation of the context.

Reviewed By: jvillard

Differential Revision: D25883722

fbshipit-source-id: 4ccd97674
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 757e44ca50
commit aeba96a3c7

@ -9,31 +9,6 @@
open Exp open Exp
(* Classification of Terms ================================================*)
type kind = Interpreted | Atomic | Uninterpreted
[@@deriving compare, equal]
let classify e =
match (e : Trm.t) with
| Var _ | Z _ | Q _ | Concat [||] | Apply (_, [||]) -> Atomic
| Arith a -> (
match Trm.Arith.classify a with
| Trm _ | Const _ -> violates Trm.invariant e
| Interpreted -> Interpreted
| Uninterpreted -> Uninterpreted )
| Splat _ | Sized _ | Extract _ | Concat _ -> Interpreted
| Apply _ -> Uninterpreted
let is_interpreted e = equal_kind (classify e) Interpreted
let is_uninterpreted e = equal_kind (classify e) Uninterpreted
let rec max_solvables e =
if not (is_interpreted e) then Iter.return e
else Iter.flat_map ~f:max_solvables (Trm.trms e)
let fold_max_solvables e s ~f = Iter.fold ~f (max_solvables e) s
(* Solution Substitutions =================================================*) (* Solution Substitutions =================================================*)
module Subst : sig module Subst : sig
@ -96,7 +71,7 @@ end = struct
~call:(fun {pf} -> pf "@ %a" Trm.pp a) ~call:(fun {pf} -> pf "@ %a" Trm.pp a)
~retn:(fun {pf} -> pf "%a" Trm.pp) ~retn:(fun {pf} -> pf "%a" Trm.pp)
@@ fun () -> @@ fun () ->
if is_interpreted a then Trm.map ~f:(norm s) a else apply s a if Theory.is_interpreted a then Trm.map ~f:(norm s) a else apply s a
(** compose two substitutions *) (** compose two substitutions *)
let compose r s = let compose r s =
@ -159,8 +134,10 @@ end = struct
in in
( is_var_in xs e ( is_var_in xs e
|| is_var_in xs f || is_var_in xs f
|| (is_uninterpreted e && Iter.exists ~f:(is_var_in xs) (Trm.trms e)) || Theory.is_uninterpreted e
|| (is_uninterpreted f && Iter.exists ~f:(is_var_in xs) (Trm.trms f)) ) && Iter.exists ~f:(is_var_in xs) (Trm.trms e)
|| Theory.is_uninterpreted f
&& Iter.exists ~f:(is_var_in xs) (Trm.trms f) )
$> fun b -> $> fun b ->
[%Trace.info [%Trace.info
"is_valid_eq %a%a=%a = %b" Var.Set.pp_xs xs Trm.pp e Trm.pp f b] "is_valid_eq %a%a=%a = %b" Var.Set.pp_xs xs Trm.pp e Trm.pp f b]
@ -200,213 +177,6 @@ end = struct
let remove = Trm.Map.remove let remove = Trm.Map.remove
end end
(* Theory Solver ==========================================================*)
(** prefer representative terms that are minimal in the order s.t. Var <
Sized < Extract < Concat < others, then using height of sequence
nesting, and then using Trm.compare *)
let prefer e f =
let rank e =
match (e : Trm.t) with
| Var _ -> 0
| Sized _ -> 1
| Extract _ -> 2
| Concat _ -> 3
| _ -> 4
in
let o = compare (rank e) (rank f) in
if o <> 0 then o
else
let o = compare (Trm.height e) (Trm.height f) in
if o <> 0 then o else Trm.compare e f
(** orient equations based on representative preference *)
let orient e f =
match Sign.of_int (prefer e f) with
| Neg -> Some (e, f)
| Zero -> None
| Pos -> Some (f, e)
type solve_state =
{ wrt: Var.Set.t
; no_fresh: bool
; fresh: Var.Set.t
; solved: (Trm.t * Trm.t) list option
; pending: (Trm.t * Trm.t) list }
let pp_solve_state ppf = function
| {solved= None} -> Format.fprintf ppf "unsat"
| {solved= Some solved; fresh; pending} ->
Format.fprintf ppf "%a%a : %a" Var.Set.pp_xs fresh
(List.pp ";@ " (fun ppf (var, rep) ->
Format.fprintf ppf "@[%a ↦ %a@]" Trm.pp var Trm.pp rep ))
solved
(List.pp ";@ " (fun ppf (a, b) ->
Format.fprintf ppf "@[%a = %a@]" Trm.pp a Trm.pp b ))
pending
let add_solved ~var ~rep s =
match s with
| {solved= None} -> s
| {solved= Some solved} -> {s with solved= Some ((var, rep) :: solved)}
let add_pending a b s = {s with pending= (a, b) :: s.pending}
let fresh name s =
if s.no_fresh then None
else
let x, wrt = Var.fresh name ~wrt:s.wrt in
let fresh = Var.Set.add x s.fresh in
Some (Trm.var x, {s with wrt; fresh})
let solve_poly p q s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp p Trm.pp q)
~retn:(fun {pf} -> pf "%a" pp_solve_state)
@@ fun () ->
match Trm.sub p q with
| Z z -> if Z.equal Z.zero z then s else {s with solved= None}
| Var _ as var -> add_solved ~var ~rep:Trm.zero s
| p_q -> (
match Trm.Arith.solve_zero_eq p_q with
| Some (var, rep) ->
add_solved ~var:(Trm.arith var) ~rep:(Trm.arith rep) s
| None -> add_solved ~var:p_q ~rep:Trm.zero s )
(* α[o,l) = β ==> l = |β| ∧ α = (⟨n,c⟩[0,o) ^ β ^ ⟨n,c⟩[o+l,n-o-l)) where n
= |α| and c fresh *)
let solve_extract a o l b s =
[%trace]
~call:(fun {pf} ->
pf "@ %a = %a" Trm.pp (Trm.extract ~seq:a ~off:o ~len:l) Trm.pp b )
~retn:(fun {pf} -> pf "%a" pp_solve_state)
@@ fun () ->
match fresh "c" s with
| None -> s
| Some (c, s) ->
let n = Trm.seq_size_exn a in
let n_c = Trm.sized ~siz:n ~seq:c in
let o_l = Trm.add o l in
let n_o_l = Trm.sub n o_l in
let c0 = Trm.extract ~seq:n_c ~off:Trm.zero ~len:o in
let c1 = Trm.extract ~seq:n_c ~off:o_l ~len:n_o_l in
let b, s =
match Trm.seq_size b with
| None -> (Trm.sized ~siz:l ~seq:b, s)
| Some m -> (b, add_pending l m s)
in
add_pending a (Trm.concat [|c0; b; c1|]) s
(* α₀^…^αᵢ^αⱼ^…^αᵥ = β ==> |α₀^…^αᵥ| = |β| ∧ … ∧ αⱼ = β[n₀+…+nᵢ,nⱼ) ∧ …
where n |α| and m = |β| *)
let solve_concat a0V b m s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp (Trm.concat a0V) Trm.pp b)
~retn:(fun {pf} -> pf "%a" pp_solve_state)
@@ fun () ->
let s, n0V =
Iter.fold (Array.to_iter a0V) (s, Trm.zero) ~f:(fun aJ (s, oI) ->
let nJ = Trm.seq_size_exn aJ in
let oJ = Trm.add oI nJ in
let s = add_pending aJ (Trm.extract ~seq:b ~off:oI ~len:nJ) s in
(s, oJ) )
in
add_pending n0V m s
let solve_ d e s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp d Trm.pp e)
~retn:(fun {pf} -> pf "%a" pp_solve_state)
@@ fun () ->
match orient d e with
(* e' = f' ==> true when e' ≡ f' *)
| None -> s
(* i = j ==> false when i ≠ j *)
| Some (Z _, Z _) | Some (Q _, Q _) -> {s with solved= None}
(*
* Concat
*)
(* ⟨0,a⟩ = β ==> a = β = ⟨⟩ *)
| Some (Sized {siz= n; seq= a}, b) when n == Trm.zero ->
s
|> add_pending a (Trm.concat [||])
|> add_pending b (Trm.concat [||])
| Some (b, Sized {siz= n; seq= a}) when n == Trm.zero ->
s
|> add_pending a (Trm.concat [||])
|> add_pending b (Trm.concat [||])
(* ⟨n,0⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,0⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Sized {siz= n; seq} as b), Concat a0V) when seq == Trm.zero ->
solve_concat a0V b n s
(* ⟨n,e^⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,e^⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) ->
solve_concat a0V b n s
| Some ((Var _ as v), (Concat a0V as c)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then
(* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *)
add_solved ~var:v ~rep:c s
else
(* v = α₀^…^αᵥ ==> ⟨|α₀^…^αᵥ|,v⟩ = α₀^…^αᵥ when v ∈ fv(α₀^…^αᵥ) *)
let m = Trm.seq_size_exn c in
solve_concat a0V (Trm.sized ~siz:m ~seq:v) m s
(* α₀^…^αᵥ = β₀^…^βᵤ ==> … ∧ αⱼ = (β₀^…^βᵤ)[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some (Concat a0V, (Concat _ as c)) ->
solve_concat a0V c (Trm.seq_size_exn c) s
(* α[o,l) = α₀^…^αᵥ ==> … ∧ αⱼ = α[o,l)[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Extract {len= l} as e), Concat a0V) -> solve_concat a0V e l s
(*
* Extract
*)
| Some ((Var _ as v), (Extract {len= l} as e)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then
(* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *)
add_solved ~var:v ~rep:e s
else
(* v = α[o,l) ==> α[o,l) ↦ ⟨l,v⟩ when v ∈ fv(α[o,l)) *)
add_solved ~var:e ~rep:(Trm.sized ~siz:l ~seq:v) s
(* α[o,l) = β ==> … ∧ α = _^β^_ *)
| Some (Extract {seq= a; off= o; len= l}, e) -> solve_extract a o l e s
(*
* Sized
*)
(* v = ⟨n,a⟩ ==> v = a *)
| Some ((Var _ as v), Sized {seq= a}) -> s |> add_pending v a
(* ⟨n,a⟩ = ⟨m,b⟩ ==> n = m ∧ a = β *)
| Some (Sized {siz= n; seq= a}, Sized {siz= m; seq= b}) ->
s |> add_pending n m |> add_pending a b
(* ⟨n,a⟩ = β ==> n = |β| ∧ a = β *)
| Some (Sized {siz= n; seq= a}, b) ->
s
|> Option.fold ~f:(add_pending n) (Trm.seq_size b)
|> add_pending a b
(*
* Splat
*)
(* a^ = b^ ==> a = b *)
| Some (Splat a, Splat b) -> s |> add_pending a b
(*
* Arithmetic
*)
(* p = q ==> p-q = 0 *)
| Some (((Arith _ | Z _ | Q _) as p), q | q, ((Arith _ | Z _ | Q _) as p))
->
solve_poly p q s
(*
* Uninterpreted
*)
(* r = v ==> v ↦ r *)
| Some (rep, var) ->
assert (not (is_interpreted var)) ;
assert (not (is_interpreted rep)) ;
add_solved ~var ~rep s
let solve ~wrt ~xs d e pending =
[%trace]
~call:(fun {pf} -> pf "@ %a@ %a" Trm.pp d Trm.pp e)
~retn:(fun {pf} -> pf "%a" pp_solve_state)
@@ fun () ->
solve_ d e {wrt; no_fresh= false; fresh= xs; solved= Some []; pending}
(* Equality classes =======================================================*) (* Equality classes =======================================================*)
module Cls : sig module Cls : sig
@ -572,11 +342,12 @@ let pre_invariant r =
Subst.iteri r.rep ~f:(fun ~key:trm ~data:rep -> Subst.iteri r.rep ~f:(fun ~key:trm ~data:rep ->
(* no interpreted terms in carrier *) (* no interpreted terms in carrier *)
assert ( assert (
(not (is_interpreted trm)) || fail "non-interp %a" Trm.pp trm () ) ; (not (Theory.is_interpreted trm))
|| fail "non-interp %a" Trm.pp trm () ) ;
(* carrier is closed under subterms *) (* carrier is closed under subterms *)
Iter.iter (Trm.trms trm) ~f:(fun subtrm -> Iter.iter (Trm.trms trm) ~f:(fun subtrm ->
assert ( assert (
is_interpreted subtrm Theory.is_interpreted subtrm
|| (match subtrm with Z _ | Q _ -> true | _ -> false) || (match subtrm with Z _ | Q _ -> true | _ -> false)
|| in_car r subtrm || in_car r subtrm
|| fail "@[subterm %a@ of %a@ not in carrier of@ %a@]" Trm.pp || fail "@[subterm %a@ of %a@ not in carrier of@ %a@]" Trm.pp
@ -618,6 +389,14 @@ let propagate1 (trm, rep) x =
let rep = Subst.compose1 ~key:trm ~data:rep x.rep in let rep = Subst.compose1 ~key:trm ~data:rep x.rep in
{x with rep} {x with rep}
let solve ~wrt ~xs d e pending =
[%trace]
~call:(fun {pf} -> pf "@ %a@ %a" Trm.pp d Trm.pp e)
~retn:(fun {pf} -> pf "%a" Theory.pp)
@@ fun () ->
Theory.solve d e
{wrt; no_fresh= false; fresh= xs; solved= Some []; pending}
let rec propagate ~wrt x = let rec propagate ~wrt x =
[%trace] [%trace]
~call:(fun {pf} -> pf "@ %a" pp_raw x) ~call:(fun {pf} -> pf "@ %a" pp_raw x)
@ -658,12 +437,12 @@ let lookup r a =
let rec canon r a = let rec canon r a =
[%Trace.call fun {pf} -> pf "@ %a" Trm.pp a] [%Trace.call fun {pf} -> pf "@ %a" Trm.pp a]
; ;
( match classify a with ( match Theory.classify a with
| Atomic -> Subst.apply r.rep a | Atomic -> Subst.apply r.rep a
| Interpreted -> Trm.map ~f:(canon r) a | Interpreted -> Trm.map ~f:(canon r) a
| Uninterpreted -> ( | Uninterpreted -> (
let a' = Trm.map ~f:(canon r) a in let a' = Trm.map ~f:(canon r) a in
match classify a' with match Theory.classify a' with
| Atomic -> Subst.apply r.rep a' | Atomic -> Subst.apply r.rep a'
| Interpreted -> a' | Interpreted -> a'
| Uninterpreted -> lookup r a' ) ) | Uninterpreted -> lookup r a' ) )
@ -680,7 +459,7 @@ let rec extend_ a r =
match (a : Trm.t) with match (a : Trm.t) with
| Z _ | Q _ -> r | Z _ | Q _ -> r
| _ -> ( | _ -> (
if is_interpreted a then Iter.fold ~f:extend_ (Trm.trms a) r if Theory.is_interpreted a then Iter.fold ~f:extend_ (Trm.trms a) r
else else
(* add uninterpreted terms *) (* add uninterpreted terms *)
match Subst.extend a r with match Subst.extend a r with
@ -793,7 +572,8 @@ let fold_uses_of r t s ~f =
Iter.fold (Trm.trms e) s ~f:(fun sub s -> Iter.fold (Trm.trms e) s ~f:(fun sub s ->
fold_ ~f sub (if Trm.equal t sub then f e s else s) ) fold_ ~f sub (if Trm.equal t sub then f e s else s) )
in in
if is_interpreted e then Iter.fold ~f:(fold_ ~f) (Trm.trms e) s else s if Theory.is_interpreted e then Iter.fold ~f:(fold_ ~f) (Trm.trms e) s
else s
in in
Subst.fold r.rep s ~f:(fun ~key:trm ~data:rep s -> Subst.fold r.rep s ~f:(fun ~key:trm ~data:rep s ->
fold_ ~f trm (fold_ ~f rep s) ) fold_ ~f trm (fold_ ~f rep s) )
@ -915,6 +695,12 @@ let fv r = Var.Set.of_iter (vars r)
(* Existential Witnessing and Elimination =================================*) (* Existential Witnessing and Elimination =================================*)
let rec max_solvables e =
if not (Theory.is_interpreted e) then Iter.return e
else Iter.flat_map ~f:max_solvables (Trm.trms e)
let fold_max_solvables e s ~f = Iter.fold ~f (max_solvables e) s
let subst_invariant us s0 s = let subst_invariant us s0 s =
assert (s0 == s || not (Subst.equal s0 s)) ; assert (s0 == s || not (Subst.equal s0 s)) ;
assert ( assert (
@ -955,12 +741,12 @@ let solve_poly_eq us p' q' subst =
[%Trace.retn fun {pf} subst' -> [%Trace.retn fun {pf} subst' ->
pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)] pf "@[%a@]" Subst.pp_diff (subst, Option.value subst' ~default:subst)]
let rec solve_pending s soln = let rec solve_pending (s : Theory.t) soln =
match s.pending with match s.pending with
| (a, b) :: pending -> ( | (a, b) :: pending -> (
let a' = Subst.norm soln a in let a' = Subst.norm soln a in
let b' = Subst.norm soln b in let b' = Subst.norm soln b in
match solve_ a' b' {s with pending} with match Theory.solve a' b' {s with pending} with
| {solved= Some solved} as s -> | {solved= Some solved} as s ->
solve_pending {s with solved= Some []} solve_pending {s with solved= Some []}
(List.fold solved soln ~f:(fun (trm, rep) soln -> (List.fold solved soln ~f:(fun (trm, rep) soln ->
@ -982,7 +768,7 @@ let solve_seq_eq us e' f' subst =
| None -> (Trm.sized ~siz:n ~seq:a, n) | None -> (Trm.sized ~siz:n ~seq:a, n)
in in
solve_pending solve_pending
(solve_concat ms a n (Theory.solve_concat ms a n
{ wrt= Var.Set.empty { wrt= Var.Set.empty
; no_fresh= true ; no_fresh= true
; fresh= Var.Set.empty ; fresh= Var.Set.empty
@ -1032,7 +818,7 @@ let rec solve_interp_eqs us (cls, subst) =
| None -> (cls', subst) | None -> (cls', subst)
| Some (trm, cls) -> | Some (trm, cls) ->
let trm' = Subst.norm subst trm in let trm' = Subst.norm subst trm in
if is_interpreted trm' then if Theory.is_interpreted trm' then
match solve_interp_eq us trm' (cls, subst) with match solve_interp_eq us trm' (cls, subst) with
| Some subst -> solve_interp_eqs_ cls' (cls, subst) | Some subst -> solve_interp_eqs_ cls' (cls, subst)
| None -> solve_interp_eqs_ (Cls.add trm' cls') (cls, subst) | None -> solve_interp_eqs_ (Cls.add trm' cls') (cls, subst)
@ -1055,7 +841,7 @@ type cls_solve_state =
let dom_trm e = let dom_trm e =
match (e : Trm.t) with match (e : Trm.t) with
| Sized {seq= Var _ as v} -> Some v | Sized {seq= Var _ as v} -> Some v
| _ when not (is_interpreted e) -> Some e | _ when not (Theory.is_interpreted e) -> Some e
| _ -> None | _ -> None
(** move equations from [cls] (which is assumed to be normalized by [subst]) (** move equations from [cls] (which is assumed to be normalized by [subst])
@ -1067,7 +853,9 @@ let solve_uninterp_eqs us (cls, subst) =
pf "@ cls: @[%a@]@ subst: @[%a@]" Cls.pp cls Subst.pp subst] pf "@ cls: @[%a@]@ subst: @[%a@]" Cls.pp cls Subst.pp subst]
; ;
let compare e f = let compare e f =
[%compare: kind * Trm.t] (classify e, e) (classify f, f) [%compare: Theory.kind * Trm.t]
(Theory.classify e, e)
(Theory.classify f, f)
in in
let {rep_us; cls_us; rep_xs; cls_xs} = let {rep_us; cls_us; rep_xs; cls_xs} =
Cls.fold cls Cls.fold cls
@ -1329,7 +1117,7 @@ let trim ks r =
let keep = Trm.Set.diff cls drop in let keep = Trm.Set.diff cls drop in
match match
Trm.Set.reduce keep ~f:(fun x y -> Trm.Set.reduce keep ~f:(fun x y ->
if prefer x y < 0 then x else y ) if Theory.prefer x y < 0 then x else y )
with with
| Some rep' -> | Some rep' ->
(* add mappings from each keeper to the new representative *) (* add mappings from each keeper to the new representative *)

@ -11,7 +11,9 @@
Functions that return contexts that might be stronger than their Functions that return contexts that might be stronger than their
argument contexts accept and return a set of variables. The input set is argument contexts accept and return a set of variables. The input set is
the variables with which any generated variables must be chosen fresh, the variables with which any generated variables must be chosen fresh,
and the output set is the variables that have been generated. *) and the output set is the variables that have been generated. If the
empty set is given, then no fresh variables are generated and equations
that cannot be solved without generating fresh variables are dropped. *)
open Exp open Exp

@ -0,0 +1,229 @@
(*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*)
(** Theory Solver *)
(* Theory equation solver state ===========================================*)
type t =
{ wrt: Var.Set.t
; no_fresh: bool
; fresh: Var.Set.t
; solved: (Trm.t * Trm.t) list option
; pending: (Trm.t * Trm.t) list }
let pp ppf = function
| {solved= None} -> Format.fprintf ppf "unsat"
| {solved= Some solved; fresh; pending} ->
Format.fprintf ppf "%a%a : %a" Var.Set.pp_xs fresh
(List.pp ";@ " (fun ppf (var, rep) ->
Format.fprintf ppf "@[%a ↦ %a@]" Trm.pp var Trm.pp rep ))
solved
(List.pp ";@ " (fun ppf (a, b) ->
Format.fprintf ppf "@[%a = %a@]" Trm.pp a Trm.pp b ))
pending
(* Classification of terms ================================================*)
type kind = Interpreted | Atomic | Uninterpreted
[@@deriving compare, equal]
let classify e =
match (e : Trm.t) with
| Var _ | Z _ | Q _ | Concat [||] | Apply (_, [||]) -> Atomic
| Arith a -> (
match Trm.Arith.classify a with
| Trm _ | Const _ -> violates Trm.invariant e
| Interpreted -> Interpreted
| Uninterpreted -> Uninterpreted )
| Splat _ | Sized _ | Extract _ | Concat _ -> Interpreted
| Apply _ -> Uninterpreted
let is_interpreted e = equal_kind (classify e) Interpreted
let is_uninterpreted e = equal_kind (classify e) Uninterpreted
(* Solving equations ======================================================*)
(** prefer representative terms that are minimal in the order s.t. Var <
Sized < Extract < Concat < others, then using height of sequence
nesting, and then using Trm.compare *)
let prefer e f =
let rank e =
match (e : Trm.t) with
| Var _ -> 0
| Sized _ -> 1
| Extract _ -> 2
| Concat _ -> 3
| _ -> 4
in
let o = compare (rank e) (rank f) in
if o <> 0 then o
else
let o = compare (Trm.height e) (Trm.height f) in
if o <> 0 then o else Trm.compare e f
(** orient equations based on representative preference *)
let orient e f =
match Sign.of_int (prefer e f) with
| Neg -> Some (e, f)
| Zero -> None
| Pos -> Some (f, e)
let add_solved ~var ~rep s =
match s with
| {solved= None} -> s
| {solved= Some solved} -> {s with solved= Some ((var, rep) :: solved)}
let add_pending a b s = {s with pending= (a, b) :: s.pending}
let fresh name s =
if s.no_fresh then None
else
let x, wrt = Var.fresh name ~wrt:s.wrt in
let fresh = Var.Set.add x s.fresh in
Some (Trm.var x, {s with wrt; fresh})
let solve_poly p q s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp p Trm.pp q)
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
match Trm.sub p q with
| Z z -> if Z.equal Z.zero z then s else {s with solved= None}
| Var _ as var -> add_solved ~var ~rep:Trm.zero s
| p_q -> (
match Trm.Arith.solve_zero_eq p_q with
| Some (var, rep) ->
add_solved ~var:(Trm.arith var) ~rep:(Trm.arith rep) s
| None -> add_solved ~var:p_q ~rep:Trm.zero s )
(* α[o,l) = β ==> l = |β| ∧ α = (⟨n,c⟩[0,o) ^ β ^ ⟨n,c⟩[o+l,n-o-l)) where n
= |α| and c fresh *)
let solve_extract a o l b s =
[%trace]
~call:(fun {pf} ->
pf "@ %a = %a" Trm.pp (Trm.extract ~seq:a ~off:o ~len:l) Trm.pp b )
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
match fresh "c" s with
| None -> s
| Some (c, s) ->
let n = Trm.seq_size_exn a in
let n_c = Trm.sized ~siz:n ~seq:c in
let o_l = Trm.add o l in
let n_o_l = Trm.sub n o_l in
let c0 = Trm.extract ~seq:n_c ~off:Trm.zero ~len:o in
let c1 = Trm.extract ~seq:n_c ~off:o_l ~len:n_o_l in
let b, s =
match Trm.seq_size b with
| None -> (Trm.sized ~siz:l ~seq:b, s)
| Some m -> (b, add_pending l m s)
in
add_pending a (Trm.concat [|c0; b; c1|]) s
(* α₀^…^αᵢ^αⱼ^…^αᵥ = β ==> |α₀^…^αᵥ| = |β| ∧ … ∧ αⱼ = β[n₀+…+nᵢ,nⱼ) ∧ …
where n |α| and m = |β| *)
let solve_concat a0V b m s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp (Trm.concat a0V) Trm.pp b)
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
let s, n0V =
Iter.fold (Array.to_iter a0V) (s, Trm.zero) ~f:(fun aJ (s, oI) ->
let nJ = Trm.seq_size_exn aJ in
let oJ = Trm.add oI nJ in
let s = add_pending aJ (Trm.extract ~seq:b ~off:oI ~len:nJ) s in
(s, oJ) )
in
add_pending n0V m s
let solve d e s =
[%trace]
~call:(fun {pf} -> pf "@ %a = %a" Trm.pp d Trm.pp e)
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
match orient d e with
(* e' = f' ==> true when e' ≡ f' *)
| None -> s
(* i = j ==> false when i ≠ j *)
| Some (Z _, Z _) | Some (Q _, Q _) -> {s with solved= None}
(*
* Concat
*)
(* ⟨0,a⟩ = β ==> a = β = ⟨⟩ *)
| Some (Sized {siz= n; seq= a}, b) when n == Trm.zero ->
s
|> add_pending a (Trm.concat [||])
|> add_pending b (Trm.concat [||])
| Some (b, Sized {siz= n; seq= a}) when n == Trm.zero ->
s
|> add_pending a (Trm.concat [||])
|> add_pending b (Trm.concat [||])
(* ⟨n,0⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,0⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Sized {siz= n; seq} as b), Concat a0V) when seq == Trm.zero ->
solve_concat a0V b n s
(* ⟨n,e^⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,e^⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) ->
solve_concat a0V b n s
| Some ((Var _ as v), (Concat a0V as c)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then
(* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *)
add_solved ~var:v ~rep:c s
else
(* v = α₀^…^αᵥ ==> ⟨|α₀^…^αᵥ|,v⟩ = α₀^…^αᵥ when v ∈ fv(α₀^…^αᵥ) *)
let m = Trm.seq_size_exn c in
solve_concat a0V (Trm.sized ~siz:m ~seq:v) m s
(* α₀^…^αᵥ = β₀^…^βᵤ ==> … ∧ αⱼ = (β₀^…^βᵤ)[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some (Concat a0V, (Concat _ as c)) ->
solve_concat a0V c (Trm.seq_size_exn c) s
(* α[o,l) = α₀^…^αᵥ ==> … ∧ αⱼ = α[o,l)[n₀+…+nᵢ,nⱼ) ∧ … *)
| Some ((Extract {len= l} as e), Concat a0V) -> solve_concat a0V e l s
(*
* Extract
*)
| Some ((Var _ as v), (Extract {len= l} as e)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then
(* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *)
add_solved ~var:v ~rep:e s
else
(* v = α[o,l) ==> α[o,l) ↦ ⟨l,v⟩ when v ∈ fv(α[o,l)) *)
add_solved ~var:e ~rep:(Trm.sized ~siz:l ~seq:v) s
(* α[o,l) = β ==> … ∧ α = _^β^_ *)
| Some (Extract {seq= a; off= o; len= l}, e) -> solve_extract a o l e s
(*
* Sized
*)
(* v = ⟨n,a⟩ ==> v = a *)
| Some ((Var _ as v), Sized {seq= a}) -> s |> add_pending v a
(* ⟨n,a⟩ = ⟨m,b⟩ ==> n = m ∧ a = β *)
| Some (Sized {siz= n; seq= a}, Sized {siz= m; seq= b}) ->
s |> add_pending n m |> add_pending a b
(* ⟨n,a⟩ = β ==> n = |β| ∧ a = β *)
| Some (Sized {siz= n; seq= a}, b) ->
s
|> Option.fold ~f:(add_pending n) (Trm.seq_size b)
|> add_pending a b
(*
* Splat
*)
(* a^ = b^ ==> a = b *)
| Some (Splat a, Splat b) -> s |> add_pending a b
(*
* Arithmetic
*)
(* p = q ==> p-q = 0 *)
| Some (((Arith _ | Z _ | Q _) as p), q | q, ((Arith _ | Z _ | Q _) as p))
->
solve_poly p q s
(*
* Uninterpreted
*)
(* r = v ==> v ↦ r *)
| Some (rep, var) ->
assert (not (is_interpreted var)) ;
assert (not (is_interpreted rep)) ;
add_solved ~var ~rep s

@ -0,0 +1,27 @@
(*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*)
(** Theory Solver *)
type t =
{ wrt: Var.Set.t
; no_fresh: bool
; fresh: Var.Set.t
; solved: (Trm.t * Trm.t) list option
; pending: (Trm.t * Trm.t) list }
val pp : t pp
type kind = Interpreted | Atomic | Uninterpreted
[@@deriving compare, equal]
val classify : Trm.t -> kind
val is_interpreted : Trm.t -> bool
val is_uninterpreted : Trm.t -> bool
val prefer : Trm.t -> Trm.t -> int
val solve_concat : Trm.t array -> Trm.t -> Trm.t -> t -> t
val solve : Trm.t -> Trm.t -> t -> t
Loading…
Cancel
Save