[sledge] Rework term and arithmetic definitions to avoid recursive modules

Summary:
Terms include Arithmetic terms, which are polynomials over terms
themselves. Monomials are represented as maps from
terms (multiplicative factors) to integers (their powers). Polynomials
are represented as maps from monomials (indeterminates) to
rationals (coefficients). In particular, terms are represented using
maps whose keys are terms themselves. This is currently implemented
using recursive modules.

This diff uses the Comparer-based interface of Maps to express this
cycle as recursive *types* rather than recursive *modules*, see the
very beginning of trm.ml. The rest of the changes are driven by the
need to expose the Arithmetic.t type at toplevel, outside the functor
that defines the arithmetic operations, and changes to stage the
definition of term and polynomial operations to remove unnecessary
recursion.

One might hope that these changes are just moving code around, but due
to how recursive modules are implemented, this refactoring is
motivated by performance profiling. In every cycle between recursive
modules, at least one of the modules must be "safe". A "safe" module
is one where all exposed values have function type. This allows the
compiler to initialize that module with functions that immediately
raise an exception, define the other modules using it, and then tie
the recursive knot by backpatching the safe module with the actual
functions at the end. This implementation works, but has the
consequence that the compiler must treat calls to functions of safe
recursive modules as indirect calls to unknown functions. This means
that they are not inlined or even called by symbol, and instead
calling them involves spilling registers if needed, loading their
address from memory, calling them by address, and restoring any
spilled registers. For operations like Trm.compare that are a handful
of instructions on the hot path, this is a significant
difference. Since terms are the keys of maps and sets in the core of
the first-order equality solver, those map operations are very very
hot.

Reviewed By: jvillard

Differential Revision: D26250533

fbshipit-source-id: f79334c68
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 24ca0666d3
commit 5704b160bf

@ -19,17 +19,27 @@ module Q = struct
include Comparer.Make (Q) include Comparer.Make (Q)
end end
module Representation type ('trm, 'compare_trm) mono = ('trm, int, 'compare_trm) Multiset.t
(Var : Var_intf.S) [@@deriving compare, equal, sexp]
(Trm : INDETERMINATE with type var := Var.t) =
type 'compare_trm compare_mono =
('compare_trm, Int.compare) Multiset.compare
[@@deriving compare, equal, sexp]
type ('trm, 'compare_trm) t =
(('trm, 'compare_trm) mono, Q.t, 'compare_trm compare_mono) Multiset.t
[@@deriving compare, equal, sexp]
module Make (Trm0 : sig
type t [@@deriving equal, sexp]
include Comparer.S with type t := t
end) =
struct struct
module Prod = struct module Prod = Multiset.Make (Trm0) (Int)
include Multiset.Make (Trm) (Int)
include Provide_of_sexp (Trm)
end
module Mono = struct module Mono = struct
type t = Prod.t [@@deriving compare, equal, sexp] type t = Prod.t [@@deriving compare, equal, sexp_of]
let num_den m = Prod.partition m ~f:(fun _ i -> i >= 0) let num_den m = Prod.partition m ~f:(fun _ i -> i >= 0)
@ -79,11 +89,7 @@ struct
Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm)) Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm))
end end
module Sum = struct module Sum = Multiset.Make (Prod) (Q)
include Multiset.Make (Prod) (Q)
include Provide_of_sexp (Prod)
end
module Poly = Sum module Poly = Sum
include Poly include Poly
@ -104,6 +110,11 @@ struct
(Sum.pp "@ + " pp_coeff_mono) (Sum.pp "@ + " pp_coeff_mono)
poly poly
let trms poly =
Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ ->
Prod.iter mono ~f:(fun trm _ -> f trm) ) )
(* core invariant *) (* core invariant *)
let mono_invariant mono = let mono_invariant mono =
@ -119,10 +130,14 @@ struct
assert (not (Q.equal Q.zero coeff)) ; assert (not (Q.equal Q.zero coeff)) ;
mono_invariant mono ) mono_invariant mono )
(* embed a term into a polynomial *)
let trm trm = Sum.of_ (Mono.of_ trm 1) Q.one
(* constants *) (* constants *)
let const q = Sum.of_ Mono.one q |> check invariant let const q = Sum.of_ Mono.one q |> check invariant
let zero = const Q.zero |> check (fun p -> assert (Sum.is_empty p)) let zero = const Q.zero |> check (fun p -> assert (Sum.is_empty p))
let one = const Q.one
(* core constructors *) (* core constructors *)
@ -149,7 +164,7 @@ struct
(* projections and embeddings *) (* projections and embeddings *)
type kind = Trm of Trm.t | Const of Q.t | Interpreted | Uninterpreted type kind = Trm of Trm0.t | Const of Q.t | Interpreted | Uninterpreted
let classify poly = let classify poly =
match Sum.classify poly with match Sum.classify poly with
@ -189,7 +204,12 @@ struct
| None -> None | None -> None
end end
module Make (Embed : EMBEDDING with type trm := Trm.t and type t := t) = include S0
module Embed
(Var : Var_intf.S)
(Trm : TRM with type t = Trm0.t with type var := Var.t)
(Embed : EMBEDDING with type trm := Trm0.t with type t := t) =
struct struct
module Mono = struct module Mono = struct
include Mono include Mono
@ -202,16 +222,24 @@ struct
include Poly include Poly
include S0 include S0
(** hide S0.trm and S0.trms that ignore the embedding, shadowed below *)
let[@warning "-32"] trm, trms = ((), ())
let pp = ppx Trm.pp let pp = ppx Trm.pp
(** Embed a polynomial into a term *)
let trm_of_poly = Embed.to_trm
(** Embed a monomial into a term, flattening if possible *) (** Embed a monomial into a term, flattening if possible *)
let trm_of_mono mono = let trm_of_mono mono =
match Mono.get_trm mono with match Mono.get_trm mono with
| Some trm -> trm | Some trm -> trm
| None -> Embed.to_trm (Sum.of_ mono Q.one) | None -> trm_of_poly (Sum.of_ mono Q.one)
(* traverse *) (* traverse *)
let vars poly = Iter.flat_map ~f:Trm.vars (S0.trms poly)
let monos poly = let monos poly =
Iter.from_iter (fun f -> Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ -> Sum.iter poly ~f:(fun mono _ ->
@ -222,10 +250,6 @@ struct
| Some mono -> Mono.trms mono | Some mono -> Mono.trms mono
| None -> Iter.map ~f:trm_of_mono (monos poly) | None -> Iter.map ~f:trm_of_mono (monos poly)
let vars p = Iter.flat_map ~f:Trm.vars (trms p)
(* invariant *)
let mono_invariant mono = let mono_invariant mono =
mono_invariant mono ; mono_invariant mono ;
let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in
@ -309,7 +333,7 @@ struct
let trm trm = let trm trm =
( match Embed.get_arith trm with ( match Embed.get_arith trm with
| Some poly -> poly | Some poly -> poly
| None -> Sum.of_ (Mono.of_ trm 1) Q.one ) | None -> S0.trm trm )
|> check (fun poly -> |> check (fun poly ->
assert (equal poly (CM.to_poly (CM.of_trm trm))) ) assert (equal poly (CM.to_poly (CM.of_trm trm))) )
@ -322,7 +346,7 @@ struct
let pow base power = CM.to_poly (CM.of_trm base ~power) let pow base power = CM.to_poly (CM.of_trm base ~power)
(* map over [trms] *) (** map over [trms] *)
let map poly ~f = let map poly ~f =
[%trace] [%trace]
~call:(fun {pf} -> pf "@ %a" pp poly) ~call:(fun {pf} -> pf "@ %a" pp poly)
@ -398,7 +422,7 @@ struct
| Some _ as soln -> soln | Some _ as soln -> soln
| None -> solve_poly (Sum.add mono coeff rejected) poly | None -> solve_poly (Sum.add mono coeff rejected) poly
(* solve [0 = e] *) (** solve [0 = e] *)
let solve_zero_eq ?for_ e = let solve_zero_eq ?for_ e =
[%trace] [%trace]
~call:(fun {pf} -> ~call:(fun {pf} ->

@ -9,7 +9,25 @@
include module type of Arithmetic_intf include module type of Arithmetic_intf
module Representation (** Arithmetic terms, e.g. polynomials, polymorphic in the type of
indeterminates. *)
type ('trm, 'cmp) t [@@deriving compare, equal, sexp]
(** Functor that, given a totally ordered type of indeterminate terms,
builds an implementation of the embedding-independent arithmetic
operations, and a functor that, given an embedding of arithmetic terms
into indeterminate terms, builds an implementation of the arithmetic
operations. *)
module Make (Ord : sig
type t [@@deriving equal, sexp]
include Comparer.S with type t := t
end) : sig
include S0 with type t = (Ord.t, Ord.compare) t with type trm := Ord.t
module Embed
(Var : Var_intf.S) (Var : Var_intf.S)
(Indeterminate : INDETERMINATE with type var := Var.t) : (Trm : TRM with type t = Ord.t with type var := Var.t)
REPRESENTATION with type var := Var.t with type trm := Indeterminate.t (_ : EMBEDDING with type trm := Trm.t and type t := t) :
S with type trm := Trm.t with type t = t
end

@ -14,31 +14,29 @@ module type EMBEDDING = sig
val to_trm : t -> trm val to_trm : t -> trm
(** Embedding from [t] to [trm]: [to_trm a] is arithmetic term [a] (** Embedding from [t] to [trm]: [to_trm a] is arithmetic term [a]
embedded in an indeterminate term. *) embedded into an indeterminate term. *)
val get_arith : trm -> t option val get_arith : trm -> t option
(** Partial projection from [trm] to [t]: [get_arith x] is [Some a] if (** Partial projection from [trm] to [t]: [get_arith x] is [Some a] iff
[x = to_trm a]. This is used to flatten indeterminates that are [x = to_trm a]. *)
actually arithmetic for the client, thereby enabling arithmetic
operations to be interpreted more often. *)
end end
(** Indeterminate terms, treated as atomic / variables except when they can (** Indeterminate terms, treated as atomic / variables except when they can
be flattened using [EMBEDDING.get_arith]. *) be flattened using {!EMBEDDING.get_arith}. *)
module type INDETERMINATE = sig module type TRM = sig
type t [@@deriving compare, equal, sexp] include Comparer.S
include Comparer.S with type t := t val pp : t pp
type var type var
val pp : t pp
val vars : t -> var iter val vars : t -> var iter
end end
module type S = sig (** Arithmetic terms, e.g. polynomials [t] over indeterminate terms [trm] *)
module type S0 = sig
type trm type trm
type t [@@deriving compare, equal, sexp] type t [@@deriving compare, equal]
val ppx : trm pp -> t pp val ppx : trm pp -> t pp
@ -48,8 +46,10 @@ module type S = sig
(** [trm x] represents the indeterminate term [x] *) (** [trm x] represents the indeterminate term [x] *)
val const : Q.t -> t val const : Q.t -> t
(** [const q] represents the constant [q] *) (** [const q] represents the rational constant [q] *)
val zero : t
val one : t
val neg : t -> t val neg : t -> t
val add : t -> t -> t val add : t -> t -> t
val sub : t -> t -> t val sub : t -> t -> t
@ -84,6 +84,22 @@ module type S = sig
val is_uninterpreted : t -> bool val is_uninterpreted : t -> bool
(** [is_uninterpreted a] iff [classify a = Uninterpreted] *) (** [is_uninterpreted a] iff [classify a = Uninterpreted] *)
(** Traverse *)
val trms : t -> trm iter
(** [trms a] enumerates the indeterminate terms appearing in [a].
Considering an arithmetic term as a polynomial,
[trms (c × (Σ c × Π
X^p))] is the sequence of terms [X] for each [i] and
[j]. *)
end
(** Arithmetic terms, where an embedding {!EMBEDDING.get_arith} into
indeterminate terms is used to implicitly flatten arithmetic terms that
are embedded into general terms to the underlying arithmetic term. *)
module type S = sig
include S0
(** Construct nonlinear arithmetic terms using the embedding, enabling (** Construct nonlinear arithmetic terms using the embedding, enabling
interpreting associativity, commutatitivity, and unit laws, but not interpreting associativity, commutatitivity, and unit laws, but not
the full nonlinear arithmetic theory. *) the full nonlinear arithmetic theory. *)
@ -95,8 +111,8 @@ module type S = sig
(** Traverse *) (** Traverse *)
val trms : t -> trm iter val trms : t -> trm iter
(** [trms a] is the maximal foreign or noninterpreted proper subterms of (** [trms a] enumerates the maximal foreign or noninterpreted proper
[a]. Considering an arithmetic term as a polynomial, subterms of [a]. Considering an arithmetic term as a polynomial,
[trms (c × (Σ c × Π [trms (c × (Σ c × Π
X^p))] is the sequence of monomials X^p))] is the sequence of monomials
[Π X^p] for each [i]. If the arithmetic [Π X^p] for each [i]. If the arithmetic
@ -115,16 +131,3 @@ module type S = sig
expressed as [e = f] for some monomial subterm [e] of [d]. If [for_] expressed as [e = f] for some monomial subterm [e] of [d]. If [for_]
is passed, then the subterm [e] must be [for_]. *) is passed, then the subterm [e] must be [for_]. *)
end end
(** A type [t] representing arithmetic terms over indeterminates [trm]
together with a functor [Make] that takes an [EMBEDDING] of arithmetic
terms [t] into indeterminates [trm] and produces an implementation of
the primary interface [S]. *)
module type REPRESENTATION = sig
type t [@@deriving compare, equal, sexp]
type var
type trm
module Make (_ : EMBEDDING with type trm := trm and type t := t) :
S with type trm := trm with type t := t
end

@ -123,9 +123,7 @@ end = struct
[not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for [not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for
[ys xs]. *) [ys xs]. *)
let is_valid_eq xs e f = let is_valid_eq xs e f =
let is_var_in xs e = let is_var_in xs e = Trm.Set.mem e (xs : Var.Set.t :> Trm.Set.t) in
Option.exists ~f:(fun x -> Var.Set.mem x xs) (Var.of_trm e)
in
let noninterp_with_solvable_var_in xs e = let noninterp_with_solvable_var_in xs e =
is_var_in xs e is_var_in xs e
|| Theory.is_noninterpreted e || Theory.is_noninterpreted e
@ -922,10 +920,11 @@ let trim ks x =
Cls.add rep (Option.value cls0 ~default:Cls.empty) ) ) Cls.add rep (Option.value cls0 ~default:Cls.empty) ) )
in in
(* enumerate expanded classes and update solution subst *) (* enumerate expanded classes and update solution subst *)
let kills = Trm.Set.of_vars ks in
Trm.Map.fold clss x ~f:(fun ~key:a' ~data:ecls x -> Trm.Map.fold clss x ~f:(fun ~key:a' ~data:ecls x ->
(* remove mappings for non-rep class elements to kill *) (* remove mappings for non-rep class elements to kill *)
let keep, drop = Trm.Set.diff_inter (Cls.to_set ecls) kills in let keep, drop =
Trm.Set.diff_inter (Cls.to_set ecls) (ks : Var.Set.t :> Trm.Set.t)
in
if Trm.Set.is_empty drop then x if Trm.Set.is_empty drop then x
else else
let rep = Trm.Set.fold ~f:Subst.remove drop x.rep in let rep = Trm.Set.fold ~f:Subst.remove drop x.rep in

@ -196,7 +196,7 @@ let solve d e s =
| Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) -> | Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) ->
solve_concat a0V b n s solve_concat a0V b n s
| Some ((Var _ as v), (Concat a0V as c)) -> | Some ((Var _ as v), (Concat a0V as c)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then if not (Trm.Set.mem v (Trm.fv c :> Trm.Set.t)) then
(* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *) (* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *)
add_solved ~var:v ~rep:c s add_solved ~var:v ~rep:c s
else else
@ -212,7 +212,7 @@ let solve d e s =
* Extract * Extract
*) *)
| Some ((Var _ as v), (Extract {len= l} as e)) -> | Some ((Var _ as v), (Extract {len= l} as e)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then if not (Trm.Set.mem v (Trm.fv e :> Trm.Set.t)) then
(* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *) (* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *)
add_solved ~var:v ~rep:e s add_solved ~var:v ~rep:e s
else else

@ -7,45 +7,20 @@
(** Terms *) (** Terms *)
(** Representation of Arithmetic terms *) (* Define term type using polymorphic arithmetic type, with derived compare,
module rec Arith0 : equal, and sexp_of functions *)
(Arithmetic.REPRESENTATION module Trm1 = struct
with type var := Trm.Var1.t type compare [@@deriving compare, equal, sexp]
with type trm := Trm.t) =
Arithmetic.Representation
(Trm.Var1)
(struct
include Trm
include Comparer.Make (Trm)
end)
(** Arithmetic terms *)
and Arith : (Arithmetic.S with type trm := Trm.t with type t = Arith0.t) =
struct
include Arith0
include Make (struct type arith = (t, compare) Arithmetic.t
let to_trm = Trm._Arith
let get_arith (e : Trm.t) =
match e with
| Z z -> Some (Arith.const (Q.of_z z))
| Q q -> Some (Arith.const q)
| Arith a -> Some a
| _ -> None
end)
end
(** Terms, built from variables and applications of function symbols from and t =
various theories. Denote functions from structures to values. *)
and Trm : sig
type t = private
(* variables *) (* variables *)
| Var of {id: int; name: string} | Var of {id: int; name: string [@ignore]}
(* arithmetic *) (* arithmetic *)
| Z of Z.t | Z of Z.t
| Q of Q.t | Q of Q.t
| Arith of Arith.t | Arith of arith
(* sequences (of flexible size) *) (* sequences (of flexible size) *)
| Splat of t | Splat of t
| Sized of {seq: t; siz: t} | Sized of {seq: t; siz: t}
@ -54,50 +29,20 @@ and Trm : sig
(* uninterpreted *) (* uninterpreted *)
| Apply of Funsym.t * t array | Apply of Funsym.t * t array
[@@deriving compare, equal, sexp] [@@deriving compare, equal, sexp]
end
(** Variable terms, represented as a subtype of general terms *) (* Add comparer, needed to instantiate arithmetic and containers *)
module Var1 : sig module Trm2 = struct
type trm := t include Comparer.Counterfeit (Trm1)
include Trm1
end
include Var_intf.S with type t = private trm (* Specialize arithmetic type and define operations using comparer *)
module Arith0 = Arithmetic.Make (Trm2)
val of_ : trm -> t (* Add ppx, defined recursively with Arith0.ppx *)
val of_trm : trm -> t option module Trm3 = struct
end include Trm2
val ppx : Var1.strength -> t pp
val pp : t pp
include Invariant.S with type t := t
val _Var : int -> string -> t
val _Z : Z.t -> t
val _Q : Q.t -> t
val _Arith : Arith.t -> t
val _Splat : t -> t
val _Sized : seq:t -> siz:t -> t
val _Extract : seq:t -> off:t -> len:t -> t
val _Concat : t array -> t
val _Apply : Funsym.t -> t array -> t
val add : t -> t -> t
val sub : t -> t -> t
val seq_size_exn : t -> t
val seq_size : t -> t option
val get_z : t -> Z.t option
val get_q : t -> Q.t option
val vars : t -> Var1.t iter
end = struct
type t =
| Var of {id: int; name: string [@ignore]}
| Z of Z.t
| Q of Q.t
| Arith of Arith.t
| Splat of t
| Sized of {seq: t; siz: t}
| Extract of {seq: t; off: t; len: t}
| Concat of t array
| Apply of Funsym.t * t array
[@@deriving compare, equal, sexp]
(* nul-terminated string value represented by a concatenation *) (* nul-terminated string value represented by a concatenation *)
let string_of_concat xs = let string_of_concat xs =
@ -136,7 +81,7 @@ end = struct
| Some `Anonymous -> Trace.pp_styled `Cyan "_" fs ) | Some `Anonymous -> Trace.pp_styled `Cyan "_" fs )
| Z z -> Trace.pp_styled `Magenta "%a" fs Z.pp z | Z z -> Trace.pp_styled `Magenta "%a" fs Z.pp z
| Q q -> Trace.pp_styled `Magenta "%a" fs Q.pp q | Q q -> Trace.pp_styled `Magenta "%a" fs Q.pp q
| Arith a -> Arith.ppx (ppx strength) fs a | Arith a -> Arith0.ppx (ppx strength) fs a
| Splat x -> pf "%a^" pp x | Splat x -> pf "%a^" pp x
| Sized {seq; siz} -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp seq | Sized {seq; siz} -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp seq
| Extract {seq; off; len} -> pf "%a[%a,%a)" pp seq pp off pp len | Extract {seq; off; len} -> pf "%a[%a,%a)" pp seq pp off pp len
@ -157,19 +102,31 @@ end = struct
pp fs trm pp fs trm
let pp = ppx (fun _ -> None) let pp = ppx (fun _ -> None)
let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y
end
(* Define containers over terms *)
module Set = struct
include Set.Make (Trm3)
include Provide_of_sexp (Trm3)
include Provide_pp (Trm3)
end
module Map = struct
include Map.Make (Trm3)
include Provide_of_sexp (Trm3)
end
(* Define variables as a subtype of terms *)
module Var = struct
open Trm3
(* Define variables as a subtype of terms *)
module Var1 = struct
module V = struct module V = struct
module T = struct
type nonrec t = t [@@deriving compare, equal, sexp] type nonrec t = t [@@deriving compare, equal, sexp]
type strength = t -> [`Universal | `Existential | `Anonymous] option type strength = t -> [`Universal | `Existential | `Anonymous] option
let pp = pp let pp = pp
let ppx = ppx let ppx = ppx
end
include T
let invariant x = let invariant x =
let@ () = Invariant.invariant [%here] x [%sexp_of: t] in let@ () = Invariant.invariant [%here] x [%sexp_of: t] in
@ -182,22 +139,16 @@ end = struct
let name = function Var v -> v.name | x -> violates invariant x let name = function Var v -> v.name | x -> violates invariant x
module Set = struct module Set = struct
module S = NS.Set.Make (T) include Set
include S
include Provide_of_sexp (T)
include Provide_pp (T)
let ppx strength vs = S.pp_full (ppx strength) vs let ppx strength vs = pp_full (ppx strength) vs
let pp_xs fs xs = let pp_xs fs xs =
if not (is_empty xs) then if not (is_empty xs) then
Format.fprintf fs "@<2>∃ @[%a@] .@;<1 2>" pp xs Format.fprintf fs "@<2>∃ @[%a@] .@;<1 2>" pp xs
end end
module Map = struct module Map = Map
include NS.Map.Make (T)
include Provide_of_sexp (T)
end
let fresh name ~wrt = let fresh name ~wrt =
let max = let max =
@ -219,26 +170,44 @@ end = struct
include V include V
module Subst = Subst.Make (V) module Subst = Subst.Make (V)
end end
(* Add definitions needed for arithmetic embedding into terms *)
module Trm = struct
include Trm3
(** Invariant *)
let invariant e = let invariant e =
let@ () = Invariant.invariant [%here] e [%sexp_of: t] in let@ () = Invariant.invariant [%here] e [%sexp_of: t] in
match e with match e with
| Q q -> assert (not (Z.equal Z.one (Q.den q))) | Q q -> assert (not (Z.equal Z.one (Q.den q)))
| Arith a -> ( | Arith a -> (
match Arith.classify a with match Arith0.classify a with
| Trm _ | Const _ -> assert false | Trm _ | Const _ -> assert false
| _ -> () ) | _ -> () )
| _ -> () | _ -> ()
(** Destruct *) (** Traverse *)
let get_z = function Z z -> Some z | _ -> None let rec iter_vars e ~f =
let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None match e with
| Var _ as v -> f (Var.of_ v)
| Z _ | Q _ -> ()
| Splat x -> iter_vars ~f x
| Sized {seq= x; siz= y} ->
iter_vars ~f x ;
iter_vars ~f y
| Extract {seq= x; off= y; len= z} ->
iter_vars ~f x ;
iter_vars ~f y ;
iter_vars ~f z
| Concat xs | Apply (_, xs) -> Array.iter ~f:(iter_vars ~f) xs
| Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith0.trms a)
(** Construct *) let vars e = Iter.from_labelled_iter (iter_vars e)
let _Var id name = Var {id; name} |> check invariant (** Construct *)
(* statically allocated since they are tested with == *) (* statically allocated since they are tested with == *)
let zero = Z Z.zero |> check invariant let zero = Z Z.zero |> check invariant
@ -253,20 +222,63 @@ end = struct
|> check invariant |> check invariant
let _Arith a = let _Arith a =
( match Arith.classify a with ( match Arith0.classify a with
| Trm e -> e | Trm e -> e
| Const q -> _Q q | Const q -> _Q q
| _ -> Arith a ) | _ -> Arith a )
|> check invariant |> check invariant
end
include Trm
(* Instantiate arithmetic with embedding into terms, yielding full
Arithmetic interface *)
module Arith =
Arith0.Embed (Var) (Trm)
(struct
let to_trm = _Arith
let get_arith e =
match e with
| Z z -> Some (Arith0.const (Q.of_z z))
| Q q -> Some (Arith0.const q)
| Arith a -> Some a
| _ -> None
end)
(* Full Trm definition, using full arithmetic interface *)
(** Destruct *)
let get_z = function Z z -> Some z | _ -> None
let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None
(** Construct *)
let add x y = _Arith Arith.(add (trm x) (trm y)) (* variables *)
let sub x y = _Arith Arith.(sub (trm x) (trm y))
let var v = (v : Var.t :> t)
(* arithmetic *)
let integer z = _Z z
let rational q = _Q q
let neg x = _Arith Arith.(neg (trm x))
let add x y = _Arith Arith.(add (trm x) (trm y))
let sub x y = _Arith Arith.(sub (trm x) (trm y))
let mulq q x = _Arith Arith.(mulc q (trm x))
let mul x y = _Arith (Arith.mul x y)
let div x y = _Arith (Arith.div x y)
let pow x i = _Arith (Arith.pow x i)
let arith = _Arith
let _Splat x = (* sequences *)
let splat x =
(* 0^ ==> 0 *) (* 0^ ==> 0 *)
(if x == zero then x else Splat x) |> check invariant (if x == zero then x else Splat x) |> check invariant
let seq_size_exn = let seq_size_exn =
let invalid = Invalid_argument "seq_size_exn" in let invalid = Invalid_argument "seq_size_exn" in
let rec seq_size_exn = function let rec seq_size_exn = function
| Sized {siz= n} | Extract {len= n} -> n | Sized {siz= n} | Extract {len= n} -> n
@ -276,28 +288,27 @@ end = struct
in in
seq_size_exn seq_size_exn
let seq_size e = let seq_size e = try Some (seq_size_exn e) with Invalid_argument _ -> None
try Some (seq_size_exn e) with Invalid_argument _ -> None
let _Sized ~seq ~siz = let sized ~seq ~siz =
( match seq_size seq with ( match seq_size seq with
(* ⟨n,α⟩ ==> α when n ≡ |α| *) (* ⟨n,α⟩ ==> α when n ≡ |α| *)
| Some n when equal siz n -> seq | Some n when equal siz n -> seq
| _ -> Sized {seq; siz} ) | _ -> Sized {seq; siz} )
|> check invariant |> check invariant
let partial_compare x y = let partial_compare x y =
match sub x y with match sub x y with
| Z z -> Some (Int.sign (Z.sign z)) | Z z -> Some (Int.sign (Z.sign z))
| Q q -> Some (Int.sign (Q.sign q)) | Q q -> Some (Int.sign (Q.sign q))
| _ -> None | _ -> None
let partial_ge x y = let partial_ge x y =
match partial_compare x y with Some (Pos | Zero) -> true | _ -> false match partial_compare x y with Some (Pos | Zero) -> true | _ -> false
let empty_seq = Concat [||] let empty_seq = Concat [||]
let rec _Extract ~seq ~off ~len = let rec extract ~seq ~off ~len =
[%trace] [%trace]
~call:(fun {pf} -> pf "@ %a" pp (Extract {seq; off; len})) ~call:(fun {pf} -> pf "@ %a" pp (Extract {seq; off; len}))
~retn:(fun {pf} -> pf "%a" pp) ~retn:(fun {pf} -> pf "%a" pp)
@ -309,13 +320,13 @@ end = struct
match seq with match seq with
(* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *) (* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *)
| Extract {seq= a; off= m; len= k} when partial_ge k o_l -> | Extract {seq= a; off= m; len= k} when partial_ge k o_l ->
_Extract ~seq:a ~off:(add m off) ~len extract ~seq:a ~off:(add m off) ~len
(* ⟨n,0⟩[o,l) ==> ⟨l,0⟩ when n ≥ o+l *) (* ⟨n,0⟩[o,l) ==> ⟨l,0⟩ when n ≥ o+l *)
| Sized {siz= n; seq} when seq == zero && partial_ge n o_l -> | Sized {siz= n; seq} when seq == zero && partial_ge n o_l ->
_Sized ~seq ~siz:len sized ~seq ~siz:len
(* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *) (* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *)
| Sized {siz= n; seq= Splat _ as e} when partial_ge n o_l -> | Sized {siz= n; seq= Splat _ as e} when partial_ge n o_l ->
_Sized ~seq:e ~siz:len sized ~seq:e ~siz:len
(* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *) (* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *)
| Sized {siz= n} when equal off zero && equal n len -> seq | Sized {siz= n} when equal off zero && equal n len -> seq
(* For (α₀^α₁)[o,l) there are 3 cases: (* For (α₀^α₁)[o,l) there are 3 cases:
@ -344,7 +355,7 @@ end = struct
Array.fold_map_until na1N (l, off) Array.fold_map_until na1N (l, off)
~f:(fun naI (l, oI) -> ~f:(fun naI (l, oI) ->
if Z.equal Z.zero l then if Z.equal Z.zero l then
`Continue (_Extract ~seq:naI ~off:oI ~len:zero, (l, oI)) `Continue (extract ~seq:naI ~off:oI ~len:zero, (l, oI))
else else
let nI = seq_size_exn naI in let nI = seq_size_exn naI in
let oI_nI = sub oI nI in let oI_nI = sub oI nI in
@ -354,15 +365,15 @@ end = struct
let lI = Z.(max zero (min l (neg z))) in let lI = Z.(max zero (min l (neg z))) in
let l = Z.(l - lI) in let l = Z.(l - lI) in
`Continue `Continue
(_Extract ~seq:naI ~off:oI ~len:(_Z lI), (l, oJ)) (extract ~seq:naI ~off:oI ~len:(_Z lI), (l, oJ))
| _ -> `Stop (Extract {seq; off; len}) ) | _ -> `Stop (Extract {seq; off; len}) )
~finish:(fun (e1N, _) -> _Concat e1N) ~finish:(fun (e1N, _) -> concat e1N)
| _ -> Extract {seq; off; len} ) | _ -> Extract {seq; off; len} )
(* α[o,l) *) (* α[o,l) *)
| _ -> Extract {seq; off; len} ) | _ -> Extract {seq; off; len} )
|> check invariant |> check invariant
and _Concat xs = and concat xs =
[%trace] [%trace]
~call:(fun {pf} -> pf "@ %a" pp (Concat xs)) ~call:(fun {pf} -> pf "@ %a" pp (Concat xs))
~retn:(fun {pf} -> pf "%a" pp) ~retn:(fun {pf} -> pf "%a" pp)
@ -380,102 +391,29 @@ end = struct
, Extract {seq= na'; off= o_k; len= l} ) , Extract {seq= na'; off= o_k; len= l} )
when equal na na' && equal o_k (add o k) && partial_ge n (add o_k l) when equal na na' && equal o_k (add o k) && partial_ge n (add o_k l)
-> ->
Some (_Extract ~seq:na ~off:o ~len:(add k l)) Some (extract ~seq:na ~off:o ~len:(add k l))
(* ⟨m,0⟩^⟨n,0⟩ ==> ⟨m+n,0⟩ *) (* ⟨m,0⟩^⟨n,0⟩ ==> ⟨m+n,0⟩ *)
| Sized {siz= m; seq= a}, Sized {siz= n; seq= a'} | Sized {siz= m; seq= a}, Sized {siz= n; seq= a'}
when a == zero && a' == zero -> when a == zero && a' == zero ->
Some (_Sized ~seq:a ~siz:(add m n)) Some (sized ~seq:a ~siz:(add m n))
(* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *) (* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *)
| Sized {siz= m; seq= Splat _ as a}, Sized {siz= n; seq= a'} | Sized {siz= m; seq= Splat _ as a}, Sized {siz= n; seq= a'}
when equal a a' -> when equal a a' ->
Some (_Sized ~seq:a ~siz:(add m n)) Some (sized ~seq:a ~siz:(add m n))
| _ -> None | _ -> None
in in
let xs = flatten xs in let xs = flatten xs in
let xs = Array.reduce_adjacent ~f:simp_adjacent xs in let xs = Array.reduce_adjacent ~f:simp_adjacent xs in
(if Array.length xs = 1 then xs.(0) else Concat xs) |> check invariant (if Array.length xs = 1 then xs.(0) else Concat xs) |> check invariant
let _Apply f es = (* uninterpreted *)
let apply f es =
( match Funsym.eval ~equal ~get_z ~ret_z:_Z ~get_q ~ret_q:_Q f es with ( match Funsym.eval ~equal ~get_z ~ret_z:_Z ~get_q ~ret_q:_Q f es with
| Some c -> c | Some c -> c
| None -> Apply (f, es) ) | None -> Apply (f, es) )
|> check invariant |> check invariant
(** Traverse *)
let rec iter_vars e ~f =
match e with
| Var _ as v -> f (Var1.of_ v)
| Z _ | Q _ -> ()
| Splat x -> iter_vars ~f x
| Sized {seq= x; siz= y} ->
iter_vars ~f x ;
iter_vars ~f y
| Extract {seq= x; off= y; len= z} ->
iter_vars ~f x ;
iter_vars ~f y ;
iter_vars ~f z
| Concat xs | Apply (_, xs) -> Array.iter ~f:(iter_vars ~f) xs
| Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith.trms a)
let vars e = Iter.from_labelled_iter (iter_vars e)
end
include Trm
module Var = Var1
module Set = struct
include Set.Make (Trm)
include Provide_of_sexp (Trm)
include Provide_pp (Trm)
let of_vars : Var.Set.t -> t =
fun vs ->
of_iter
(Iter.map ~f:(fun v -> (v : Var.t :> Trm.t)) (Var.Set.to_iter vs))
end
module Map = struct
include Map.Make (Trm)
include Provide_of_sexp (Trm)
end
type arith = Arith.t
let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y
(** Construct *)
(* variables *)
let var v = (v : Var.t :> t)
(* arithmetic *)
let zero = _Z Z.zero
let one = _Z Z.one
let integer z = _Z z
let rational q = _Q q
let neg x = _Arith Arith.(neg (trm x))
let add = Trm.add
let sub = Trm.sub
let mulq q x = _Arith Arith.(mulc q (trm x))
let mul x y = _Arith (Arith.mul x y)
let div x y = _Arith (Arith.div x y)
let pow x i = _Arith (Arith.pow x i)
let arith = _Arith
(* sequences *)
let splat = _Splat
let sized = _Sized
let extract = _Extract
let concat elts = _Concat elts
(* uninterpreted *)
let apply sym args = _Apply sym args
(** Traverse *) (** Traverse *)
let trms = function let trms = function
@ -505,25 +443,25 @@ let rec map_vars e ~f =
| Var _ as v -> (f (Var.of_ v) : Var.t :> t) | Var _ as v -> (f (Var.of_ v) : Var.t :> t)
| Z _ | Q _ -> e | Z _ | Q _ -> e
| Arith a -> map1 (Arith.map ~f:(map_vars ~f)) e _Arith a | Arith a -> map1 (Arith.map ~f:(map_vars ~f)) e _Arith a
| Splat x -> map1 (map_vars ~f) e _Splat x | Splat x -> map1 (map_vars ~f) e splat x
| Sized {seq; siz} -> | Sized {seq; siz} ->
map2 (map_vars ~f) e (fun seq siz -> _Sized ~seq ~siz) seq siz map2 (map_vars ~f) e (fun seq siz -> sized ~seq ~siz) seq siz
| Extract {seq; off; len} -> | Extract {seq; off; len} ->
map3 (map_vars ~f) e map3 (map_vars ~f) e
(fun seq off len -> _Extract ~seq ~off ~len) (fun seq off len -> extract ~seq ~off ~len)
seq off len seq off len
| Concat xs -> mapN (map_vars ~f) e _Concat xs | Concat xs -> mapN (map_vars ~f) e concat xs
| Apply (g, xs) -> mapN (map_vars ~f) e (_Apply g) xs | Apply (g, xs) -> mapN (map_vars ~f) e (apply g) xs
let map e ~f = let map e ~f =
match e with match e with
| Var _ | Z _ | Q _ -> e | Var _ | Z _ | Q _ -> e
| Arith a -> map1 (Arith.map ~f) e _Arith a | Arith a -> map1 (Arith.map ~f) e _Arith a
| Splat x -> map1 f e _Splat x | Splat x -> map1 f e splat x
| Sized {seq; siz} -> map2 f e (fun seq siz -> _Sized ~seq ~siz) seq siz | Sized {seq; siz} -> map2 f e (fun seq siz -> sized ~seq ~siz) seq siz
| Extract {seq; off; len} -> | Extract {seq; off; len} ->
map3 f e (fun seq off len -> _Extract ~seq ~off ~len) seq off len map3 f e (fun seq off len -> extract ~seq ~off ~len) seq off len
| Concat xs -> mapN f e _Concat xs | Concat xs -> mapN f e concat xs
| Apply (g, xs) -> mapN f e (_Apply g) xs | Apply (g, xs) -> mapN f e (apply g) xs
let fold_map e = fold_map_from_map map e let fold_map e = fold_map_from_map map e

@ -9,9 +9,11 @@
type arith type arith
(** Terms, built from variables and applications of function symbols from
various theories. Denote functions from structures to values. *)
type t = private type t = private
(* variables *) (* variables *)
| Var of {id: int; name: string} | Var of {id: int; name: string [@ignore]}
(* arithmetic *) (* arithmetic *)
| Z of Z.t | Z of Z.t
| Q of Q.t | Q of Q.t
@ -25,24 +27,15 @@ type t = private
| Apply of Funsym.t * t array | Apply of Funsym.t * t array
[@@deriving compare, equal, sexp] [@@deriving compare, equal, sexp]
(** Arithmetic terms *)
module Arith : Arithmetic.S with type trm := t with type t = arith module Arith : Arithmetic.S with type trm := t with type t = arith
module Var : sig
type trm := t
include Var_intf.S with type t = private trm
val of_ : trm -> t
val of_trm : trm -> t option
end
module Set : sig module Set : sig
include Set.S with type elt := t include Set.S with type elt := t
val t_of_sexp : Sexp.t -> t val t_of_sexp : Sexp.t -> t
val pp : t pp val pp : t pp
val pp_diff : (t * t) pp val pp_diff : (t * t) pp
val of_vars : Var.Set.t -> t
end end
module Map : sig module Map : sig
@ -51,6 +44,16 @@ module Map : sig
val t_of_sexp : (Sexp.t -> 'a) -> Sexp.t -> 'a t val t_of_sexp : (Sexp.t -> 'a) -> Sexp.t -> 'a t
end end
(** Variable terms, represented as a subtype of general terms *)
module Var : sig
type trm := t
include
Var_intf.S with type t = private trm with type Set.t = private Set.t
val of_trm : trm -> t option
end
val ppx : Var.strength -> t pp val ppx : Var.strength -> t pp
val pp : t pp val pp : t pp
val pp_diff : (t * t) pp val pp_diff : (t * t) pp

Loading…
Cancel
Save