You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

259 lines
8.4 KiB

(*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*)
(** Theory Solver *)
(* Theory equation solver state ===========================================*)
(** An equation solved for [var]: it is printed as [var ↦ rep] and is
    recorded by [add_solved]. *)
type oriented_equality = {var: Trm.t; rep: Trm.t}
(** State of the theory equation solver. *)
type t =
{ wrt: Var.Set.t
(** variable context passed as [~wrt] to [Var.fresh] and replaced by the
    context it returns; presumably the names that fresh variables must
    avoid — confirm against [Var.fresh] *)
; no_fresh: bool
(** when [true], [fresh] refuses to create variables and returns [None] *)
; fresh: Var.Set.t
(** variables created by [fresh] so far *)
; solved: oriented_equality list option
(** solved equations, most recently added first; [None] means the state
    has been found unsatisfiable (printed as ["unsat"]) *)
; pending: (Trm.t * Trm.t) list }
(** equations queued by [add_pending], most recently added first; they are
    not processed by the functions visible here — presumably drained by
    the caller of [solve] *)
(** Print a solver state. An unsatisfiable state prints as ["unsat"].
    Otherwise the output is the fresh variables, then the solved equations
    as [x ↦ t] separated by [;], then [" : "], then the pending equations
    as [a = b] separated by [;]. *)
let pp ppf s =
  let pp_solved ppf {var; rep} =
    Format.fprintf ppf "@[%a ↦ %a@]" Trm.pp var Trm.pp rep
  in
  let pp_pending ppf (a, b) =
    Format.fprintf ppf "@[%a = %a@]" Trm.pp a Trm.pp b
  in
  match s.solved with
  | None -> Format.fprintf ppf "unsat"
  | Some solved ->
      Format.fprintf ppf "%a%a : %a" Var.Set.pp_xs s.fresh
        (List.pp ";@ " pp_solved) solved (List.pp ";@ " pp_pending)
        s.pending
(* Classification of terms ================================================*)
(** Coarse classification of terms, as computed by [classify]:
    - [InterpApp]: [Splat], [Sized], [Extract], non-empty [Concat], and
      [Arith] terms that [Trm.Arith.is_uninterpreted] does not flag;
    - [NonInterpAtom]: variables and [Apply] with no arguments;
    - [InterpAtom]: numerals [Z], [Q] and the empty [Concat];
    - [UninterpApp]: [Apply] with arguments and [Arith] terms flagged by
      [Trm.Arith.is_uninterpreted]. *)
type kind = InterpApp | NonInterpAtom | InterpAtom | UninterpApp
[@@deriving compare, equal]
(** Classify term [e] into a [kind]. For arithmetic terms that are not
    uninterpreted, an assertion checks that [Trm.Arith.classify] agrees
    that the term is interpreted. A [Trm] or [Const] result there would
    violate [Trm.invariant]. *)
let classify e =
  match (e : Trm.t) with
  (* atoms: the empty-argument cases must precede the general ones below *)
  | Var _ | Apply (_, [||]) -> NonInterpAtom
  | Z _ | Q _ | Concat [||] -> InterpAtom
  (* applications *)
  | Apply _ -> UninterpApp
  | Splat _ | Sized _ | Extract _ | Concat _ -> InterpApp
  | Arith a when Trm.Arith.is_uninterpreted a -> UninterpApp
  | Arith a ->
      assert (
        match Trm.Arith.classify a with
        | Interpreted -> true
        | Uninterpreted -> false
        | Trm _ | Const _ -> violates Trm.invariant e ) ;
      InterpApp
(** [true] iff [e] is an application of an interpreted operator. *)
let is_interpreted e = classify e |> equal_kind InterpApp

(** [true] iff [e] is an application of an uninterpreted operator. *)
let is_uninterpreted e = classify e |> equal_kind UninterpApp

(** [true] iff [e] is a non-interpreted atom or an uninterpreted
    application, i.e. neither an interpreted atom nor an interpreted
    application. *)
let is_noninterpreted e =
  match classify e with
  | NonInterpAtom | UninterpApp -> true
  | InterpAtom | InterpApp -> false
(** Iterate over the maximal non-interpreted subterms of [e]. Interpreted
    atoms contribute nothing, and interpreted applications are descended
    into. Non-interpreted atoms and uninterpreted applications are yielded
    as they are. *)
let rec solvables e =
  match classify e with
  | NonInterpAtom | UninterpApp -> Iter.return e
  | InterpApp -> solvable_trms e
  | InterpAtom -> Iter.empty

(** [solvables] of each immediate subterm of [e], concatenated. *)
and solvable_trms e = Iter.flat_map (Trm.trms e) ~f:solvables
(** Rebuild [e] by applying [f] to each subterm that [solvables] would
    yield. Interpreted atoms are returned unchanged, and interpreted
    applications are mapped over recursively. *)
let rec map_solvables e ~f =
  match classify e with
  | InterpApp -> Trm.map e ~f:(fun sub -> map_solvables sub ~f)
  | NonInterpAtom | UninterpApp -> f e
  | InterpAtom -> e
(* Solving equations ======================================================*)
(** Comparison used to choose representatives: the result is negative when
    [e] is preferred over [f], positive when [f] is preferred, and zero only
    when [Trm.compare e f] is zero. Terms are ranked by constructor as
    Var < Sized < Extract < Concat < everything else. Ties are broken by
    [Trm.height] (sequence nesting depth), then by [Trm.compare]. *)
let prefer e f =
  let rank (t : Trm.t) =
    match t with
    | Var _ -> 0
    | Sized _ -> 1
    | Extract _ -> 2
    | Concat _ -> 3
    | _ -> 4
  in
  match compare (rank e) (rank f) with
  | 0 -> (
    match compare (Trm.height e) (Trm.height f) with
    | 0 -> Trm.compare e f
    | by_height -> by_height )
  | by_rank -> by_rank
(** Orient the equation [e = f] by [prefer]. The result is [Some (x, y)]
    where [x] is whichever of [e] and [f] is preferred, or [None] when the
    two terms compare equal. *)
let orient e f =
  let ord = prefer e f in
  match Sign.of_int ord with
  | Zero -> None
  | Pos -> Some (f, e)
  | Neg -> Some (e, f)
(** Record the solved equation [var ↦ rep] in [s]. An unsatisfiable state
    (no solved list) is returned as is. *)
let add_solved ~var ~rep s =
  match s.solved with
  | None -> s
  | Some eqs -> {s with solved= Some ({var; rep} :: eqs)}
(** Queue the equation [a = b] at the front of [s]'s pending equations. *)
let add_pending a b ({pending; _} as s) = {s with pending= (a, b) :: pending}
(** [fresh name s] creates a variable from [name] using [Var.fresh] with
    respect to [s.wrt]. It returns the variable as a term, together with a
    state whose [wrt] is the updated context and whose [fresh] set includes
    the new variable. Returns [None] when [s.no_fresh] is set. *)
let fresh name s =
  match s.no_fresh with
  | true -> None
  | false ->
      let v, wrt = Var.fresh name ~wrt:s.wrt in
      let fresh = Var.Set.add v s.fresh in
      Some (Trm.var v, {s with wrt; fresh})
(** Solve the arithmetic equation [p = q] by considering [p - q = 0]:
    - a zero integer difference leaves [s] unchanged;
    - a non-zero integer difference makes [s] unsatisfiable;
    - a variable difference is solved to [0];
    - otherwise, if [Trm.Arith.solve_zero_eq] isolates a term, that solution
      is recorded; failing that, the whole difference is mapped to [0]. *)
let solve_poly p q s =
  [%trace]
    ~call:(fun {pf} -> pf "@ %a = %a" Trm.pp p Trm.pp q)
    ~retn:(fun {pf} -> pf "%a" pp)
  @@ fun () ->
  let diff = Trm.sub p q in
  match diff with
  | Z z when Z.equal Z.zero z -> s
  | Z _ -> {s with solved= None}
  | Var _ -> add_solved ~var:diff ~rep:Trm.zero s
  | _ -> (
    match Trm.Arith.solve_zero_eq diff with
    | None -> add_solved ~var:diff ~rep:Trm.zero s
    | Some (x, r) -> add_solved ~var:(Trm.arith x) ~rep:(Trm.arith r) s )
(** Solve α[o,l) = β using the rewrite

    α[o,l) = β ==> l = |β| ∧ α = (⟨n,c⟩[0,o) ^ β ^ ⟨n,c⟩[o+l,n-o-l))

    where n = |α| and c is fresh.

    A fresh variable is always created for [c]. One is also created for [n]
    when [Trm.seq_size a] is unknown. If [b] has a known size [m], the
    equation [l = m] is queued; otherwise [b] is wrapped as ⟨l,b⟩. Returns
    [None] when a needed fresh variable cannot be created (see [fresh]). *)
let solve_extract a o l b s =
  [%trace]
    ~call:(fun {pf} ->
      pf "@ %a = %a" Trm.pp (Trm.extract ~seq:a ~off:o ~len:l) Trm.pp b )
    ~retn:(fun {pf} -> pf "%a" (Option.pp "%a" pp))
  @@ fun () ->
  let* c, s = fresh "c" s in
  let size_of_a =
    match Trm.seq_size a with
    | None -> fresh "n" s
    | Some n -> Some (n, s)
  in
  let+ n, s = size_of_a in
  let padding = Trm.sized ~siz:n ~seq:c in
  let o_plus_l = Trm.add o l in
  let prefix = Trm.extract ~seq:padding ~off:Trm.zero ~len:o in
  let suffix =
    Trm.extract ~seq:padding ~off:o_plus_l ~len:(Trm.sub n o_plus_l)
  in
  let b, s =
    match Trm.seq_size b with
    | Some m -> (b, add_pending l m s)
    | None -> (Trm.sized ~siz:l ~seq:b, s)
  in
  add_pending a (Trm.concat [|prefix; b; suffix|]) s
(** Solve α₀^…^αᵥ = β using the rewrite

    α₀^…^αᵢ^αⱼ^…^αᵥ = β ==> |α₀^…^αᵥ| = |β| ∧ … ∧ αⱼ = β[n₀+…+nᵢ,nⱼ) ∧ …

    where nᵢ = |αᵢ| and m = |β|.

    Each element of [a0V] is queued as equal to the slice of [b] that starts
    at the total size of the preceding elements and has that element's
    size. Finally the total size is queued as equal to [m].
    [Trm.seq_size_exn] is applied to every element of [a0V]. *)
let solve_concat a0V b m s =
  [%trace]
    ~call:(fun {pf} -> pf "@ %a = %a" Trm.pp (Trm.concat a0V) Trm.pp b)
    ~retn:(fun {pf} -> pf "%a" pp)
  @@ fun () ->
  let split_off aJ (s, offset) =
    let len = Trm.seq_size_exn aJ in
    let next = Trm.add offset len in
    let s = add_pending aJ (Trm.extract ~seq:b ~off:offset ~len) s in
    (s, next)
  in
  let s, total = Iter.fold (Array.to_iter a0V) (s, Trm.zero) ~f:split_off in
  add_pending total m s
(** [solve d e s] takes one solving step on the equation [d = e] and
    returns the updated state.

    The equation is first oriented with [orient], so the preferred
    representative comes first. It is then dispatched on the shape of the
    oriented pair. Depending on the case, [s] is:
    - returned unchanged, when the two sides compare equal;
    - marked unsatisfiable, for distinct numerals of the same kind;
    - extended with a solved equation via [add_solved];
    - extended with simpler equations queued via [add_pending], directly or
      through [solve_concat], [solve_extract] or [solve_poly].

    The comment before each case states the rewrite it implements. *)
let solve d e s =
  [%trace]
    ~call:(fun {pf} -> pf "@ %a = %a" Trm.pp d Trm.pp e)
    ~retn:(fun {pf} -> pf "%a" pp)
  @@ fun () ->
  match orient d e with
  (* e' = f' ==> true when e' ≡ f' *)
  | None -> s
  (* i = j ==> false when i ≠ j *)
  | Some (Z _, Z _) | Some (Q _, Q _) -> {s with solved= None}
  (*
   * Concat
   *)
  (* ⟨0,a⟩ = β ==> a = β = ⟨⟩ *)
  | Some (Sized {siz= n; seq= a}, b) when n == Trm.zero ->
      s
      |> add_pending a (Trm.concat [||])
      |> add_pending b (Trm.concat [||])
  | Some (b, Sized {siz= n; seq= a}) when n == Trm.zero ->
      s
      |> add_pending a (Trm.concat [||])
      |> add_pending b (Trm.concat [||])
  (* ⟨n,0⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,0⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
  | Some ((Sized {siz= n; seq} as b), Concat a0V) when seq == Trm.zero ->
      solve_concat a0V b n s
  (* ⟨n,e^⟩ = α₀^…^αᵥ ==> … ∧ αⱼ = ⟨n,e^⟩[n₀+…+nᵢ,nⱼ) ∧ … *)
  | Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) ->
      solve_concat a0V b n s
  | Some ((Var _ as v), (Concat a0V as c)) ->
      if not (Trm.Set.mem v (Trm.fv c :> Trm.Set.t)) then
        (* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *)
        add_solved ~var:v ~rep:c s
      else
        (* v = α₀^…^αᵥ ==> ⟨|α₀^…^αᵥ|,v⟩ = α₀^…^αᵥ when v ∈ fv(α₀^…^αᵥ) *)
        let m = Trm.seq_size_exn c in
        solve_concat a0V (Trm.sized ~siz:m ~seq:v) m s
  (* α₀^…^αᵥ = β₀^…^βᵤ ==> … ∧ αⱼ = (β₀^…^βᵤ)[n₀+…+nᵢ,nⱼ) ∧ … *)
  | Some (Concat a0V, (Concat _ as c)) ->
      solve_concat a0V c (Trm.seq_size_exn c) s
  (* α[o,l) = α₀^…^αᵥ ==> … ∧ αⱼ = α[o,l)[n₀+…+nᵢ,nⱼ) ∧ … *)
  | Some ((Extract {len= l} as e), Concat a0V) -> solve_concat a0V e l s
  (*
   * Extract
   *)
  | Some ((Var _ as v), (Extract {len= l} as e)) ->
      if not (Trm.Set.mem v (Trm.fv e :> Trm.Set.t)) then
        (* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *)
        add_solved ~var:v ~rep:e s
      else
        (* v = α[o,l) ==> α[o,l) ↦ ⟨l,v⟩ when v ∈ fv(α[o,l)) *)
        add_solved ~var:e ~rep:(Trm.sized ~siz:l ~seq:v) s
  (* α[o,l) = β ==> … ∧ α = _^β^_ *)
  | Some (Extract {seq= a; off= o; len= l}, e) ->
      (* [solve_extract] yields [None] only when it cannot create the fresh
         variables it needs; the equation is then left unsolved and [s] is
         returned as is. *)
      Option.value (solve_extract a o l e s) ~default:s
  (*
   * Sized
   *)
  (* v = ⟨n,a⟩ ==> v = a *)
  | Some ((Var _ as v), Sized {seq= a}) -> s |> add_pending v a
  (* ⟨n,a⟩ = ⟨m,b⟩ ==> n = m ∧ a = b *)
  | Some (Sized {siz= n; seq= a}, Sized {siz= m; seq= b}) ->
      s |> add_pending n m |> add_pending a b
  (* ⟨n,a⟩ = β ==> n = |β| ∧ a = β *)
  | Some (Sized {siz= n; seq= a}, b) ->
      s
      |> Option.fold ~f:(add_pending n) (Trm.seq_size b)
      |> add_pending a b
  (*
   * Splat
   *)
  (* a^ = b^ ==> a = b *)
  | Some (Splat a, Splat b) -> s |> add_pending a b
  (*
   * Arithmetic
   *)
  (* p = q ==> p-q = 0 *)
  | Some (((Arith _ | Z _ | Q _) as p), q | q, ((Arith _ | Z _ | Q _) as p))
    ->
      solve_poly p q s
  (*
   * Uninterpreted
   *)
  (* r = v ==> v ↦ r *)
  | Some (rep, var) ->
      assert (is_noninterpreted var) ;
      assert (is_noninterpreted rep) ;
      add_solved ~var ~rep s