[sledge] Rework term and arithmetic definitions to avoid recursive modules

Summary:
Terms include Arithmetic terms, which are polynomials over terms
themselves. Monomials are represented as maps from
terms (multiplicative factors) to integers (their powers). Polynomials
are represented as maps from monomials (indeterminates) to
rationals (coefficients). In particular, terms are represented using
maps whose keys are terms themselves. This is currently implemented
using recursive modules.

This diff uses the Comparer-based interface of Maps to express this
cycle as recursive *types* rather than recursive *modules*, see the
very beginning of trm.ml. The rest of the changes are driven by the
need to expose the Arithmetic.t type at toplevel, outside the functor
that defines the arithmetic operations, and changes to stage the
definition of term and polynomial operations to remove unnecessary
recursion.

One might hope that these changes are just moving code around, but due
to how recursive modules are implemented, this refactoring is
motivated by performance profiling. In every cycle between recursive
modules, at least one of the modules must be "safe". A "safe" module
is one where all exposed values have function type. This allows the
compiler to initialize that module with functions that immediately
raise an exception, define the other modules using it, and then tie
the recursive knot by backpatching the safe module with the actual
functions at the end. This implementation works, but has the
consequence that the compiler must treat calls to functions of safe
recursive modules as indirect calls to unknown functions. This means
that they are not inlined or even called by symbol, and instead
calling them involves spilling registers if needed, loading their
address from memory, calling them by address, and restoring any
spilled registers. For operations like Trm.compare that are a handful
of instructions on the hot path, this is a significant
difference. Since terms are the keys of maps and sets in the core of
the first-order equality solver, those map operations are very very
hot.

Reviewed By: jvillard

Differential Revision: D26250533

fbshipit-source-id: f79334c68
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 24ca0666d3
commit 5704b160bf

@ -19,17 +19,27 @@ module Q = struct
include Comparer.Make (Q)
end
module Representation
(Var : Var_intf.S)
(Trm : INDETERMINATE with type var := Var.t) =
type ('trm, 'compare_trm) mono = ('trm, int, 'compare_trm) Multiset.t
[@@deriving compare, equal, sexp]
type 'compare_trm compare_mono =
('compare_trm, Int.compare) Multiset.compare
[@@deriving compare, equal, sexp]
type ('trm, 'compare_trm) t =
(('trm, 'compare_trm) mono, Q.t, 'compare_trm compare_mono) Multiset.t
[@@deriving compare, equal, sexp]
module Make (Trm0 : sig
type t [@@deriving equal, sexp]
include Comparer.S with type t := t
end) =
struct
module Prod = struct
include Multiset.Make (Trm) (Int)
include Provide_of_sexp (Trm)
end
module Prod = Multiset.Make (Trm0) (Int)
module Mono = struct
type t = Prod.t [@@deriving compare, equal, sexp]
type t = Prod.t [@@deriving compare, equal, sexp_of]
let num_den m = Prod.partition m ~f:(fun _ i -> i >= 0)
@ -79,11 +89,7 @@ struct
Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm))
end
module Sum = struct
include Multiset.Make (Prod) (Q)
include Provide_of_sexp (Prod)
end
module Sum = Multiset.Make (Prod) (Q)
module Poly = Sum
include Poly
@ -104,6 +110,11 @@ struct
(Sum.pp "@ + " pp_coeff_mono)
poly
let trms poly =
Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ ->
Prod.iter mono ~f:(fun trm _ -> f trm) ) )
(* core invariant *)
let mono_invariant mono =
@ -119,10 +130,14 @@ struct
assert (not (Q.equal Q.zero coeff)) ;
mono_invariant mono )
(* embed a term into a polynomial *)
let trm trm = Sum.of_ (Mono.of_ trm 1) Q.one
(* constants *)
let const q = Sum.of_ Mono.one q |> check invariant
let zero = const Q.zero |> check (fun p -> assert (Sum.is_empty p))
let one = const Q.one
(* core constructors *)
@ -149,7 +164,7 @@ struct
(* projections and embeddings *)
type kind = Trm of Trm.t | Const of Q.t | Interpreted | Uninterpreted
type kind = Trm of Trm0.t | Const of Q.t | Interpreted | Uninterpreted
let classify poly =
match Sum.classify poly with
@ -189,7 +204,12 @@ struct
| None -> None
end
module Make (Embed : EMBEDDING with type trm := Trm.t and type t := t) =
include S0
module Embed
(Var : Var_intf.S)
(Trm : TRM with type t = Trm0.t with type var := Var.t)
(Embed : EMBEDDING with type trm := Trm0.t with type t := t) =
struct
module Mono = struct
include Mono
@ -202,16 +222,24 @@ struct
include Poly
include S0
(** hide S0.trm and S0.trms that ignore the embedding, shadowed below *)
let[@warning "-32"] trm, trms = ((), ())
let pp = ppx Trm.pp
(** Embed a polynomial into a term *)
let trm_of_poly = Embed.to_trm
(** Embed a monomial into a term, flattening if possible *)
let trm_of_mono mono =
match Mono.get_trm mono with
| Some trm -> trm
| None -> Embed.to_trm (Sum.of_ mono Q.one)
| None -> trm_of_poly (Sum.of_ mono Q.one)
(* traverse *)
let vars poly = Iter.flat_map ~f:Trm.vars (S0.trms poly)
let monos poly =
Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ ->
@ -222,10 +250,6 @@ struct
| Some mono -> Mono.trms mono
| None -> Iter.map ~f:trm_of_mono (monos poly)
let vars p = Iter.flat_map ~f:Trm.vars (trms p)
(* invariant *)
let mono_invariant mono =
mono_invariant mono ;
let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in
@ -309,7 +333,7 @@ struct
let trm trm =
( match Embed.get_arith trm with
| Some poly -> poly
| None -> Sum.of_ (Mono.of_ trm 1) Q.one )
| None -> S0.trm trm )
|> check (fun poly ->
assert (equal poly (CM.to_poly (CM.of_trm trm))) )
@ -322,7 +346,7 @@ struct
let pow base power = CM.to_poly (CM.of_trm base ~power)
(* map over [trms] *)
(** map over [trms] *)
let map poly ~f =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp poly)
@ -398,7 +422,7 @@ struct
| Some _ as soln -> soln
| None -> solve_poly (Sum.add mono coeff rejected) poly
(* solve [0 = e] *)
(** solve [0 = e] *)
let solve_zero_eq ?for_ e =
[%trace]
~call:(fun {pf} ->

@ -9,7 +9,25 @@
include module type of Arithmetic_intf
module Representation
(** Arithmetic terms, e.g. polynomials, polymorphic in the type of
indeterminates. *)
type ('trm, 'cmp) t [@@deriving compare, equal, sexp]
(** Functor that, given a totally ordered type of indeterminate terms,
builds an implementation of the embedding-independent arithmetic
operations, and a functor that, given an embedding of arithmetic terms
into indeterminate terms, builds an implementation of the arithmetic
operations. *)
module Make (Ord : sig
type t [@@deriving equal, sexp]
include Comparer.S with type t := t
end) : sig
include S0 with type t = (Ord.t, Ord.compare) t with type trm := Ord.t
module Embed
(Var : Var_intf.S)
(Indeterminate : INDETERMINATE with type var := Var.t) :
REPRESENTATION with type var := Var.t with type trm := Indeterminate.t
(Trm : TRM with type t = Ord.t with type var := Var.t)
(_ : EMBEDDING with type trm := Trm.t and type t := t) :
S with type trm := Trm.t with type t = t
end

@ -14,31 +14,29 @@ module type EMBEDDING = sig
val to_trm : t -> trm
(** Embedding from [t] to [trm]: [to_trm a] is arithmetic term [a]
embedded in an indeterminate term. *)
embedded into an indeterminate term. *)
val get_arith : trm -> t option
(** Partial projection from [trm] to [t]: [get_arith x] is [Some a] if
[x = to_trm a]. This is used to flatten indeterminates that are
actually arithmetic for the client, thereby enabling arithmetic
operations to be interpreted more often. *)
(** Partial projection from [trm] to [t]: [get_arith x] is [Some a] iff
[x = to_trm a]. *)
end
(** Indeterminate terms, treated as atomic / variables except when they can
be flattened using [EMBEDDING.get_arith]. *)
module type INDETERMINATE = sig
type t [@@deriving compare, equal, sexp]
be flattened using {!EMBEDDING.get_arith}. *)
module type TRM = sig
include Comparer.S
include Comparer.S with type t := t
val pp : t pp
type var
val pp : t pp
val vars : t -> var iter
end
module type S = sig
(** Arithmetic terms, e.g. polynomials [t] over indeterminate terms [trm] *)
module type S0 = sig
type trm
type t [@@deriving compare, equal, sexp]
type t [@@deriving compare, equal]
val ppx : trm pp -> t pp
@ -48,8 +46,10 @@ module type S = sig
(** [trm x] represents the indeterminate term [x] *)
val const : Q.t -> t
(** [const q] represents the constant [q] *)
(** [const q] represents the rational constant [q] *)
val zero : t
val one : t
val neg : t -> t
val add : t -> t -> t
val sub : t -> t -> t
@ -84,6 +84,22 @@ module type S = sig
val is_uninterpreted : t -> bool
(** [is_uninterpreted a] iff [classify a = Uninterpreted] *)
(** Traverse *)
val trms : t -> trm iter
(** [trms a] enumerates the indeterminate terms appearing in [a].
Considering an arithmetic term as a polynomial,
[trms (c × (Σ c × Π
X^p))] is the sequence of terms [X] for each [i] and
[j]. *)
end
(** Arithmetic terms, where an embedding {!EMBEDDING.get_arith} into
indeterminate terms is used to implicitly flatten arithmetic terms that
are embedded into general terms to the underlying arithmetic term. *)
module type S = sig
include S0
(** Construct nonlinear arithmetic terms using the embedding, enabling
interpreting associativity, commutatitivity, and unit laws, but not
the full nonlinear arithmetic theory. *)
@ -95,8 +111,8 @@ module type S = sig
(** Traverse *)
val trms : t -> trm iter
(** [trms a] is the maximal foreign or noninterpreted proper subterms of
[a]. Considering an arithmetic term as a polynomial,
(** [trms a] enumerates the maximal foreign or noninterpreted proper
subterms of [a]. Considering an arithmetic term as a polynomial,
[trms (c × (Σ c × Π
X^p))] is the sequence of monomials
[Π X^p] for each [i]. If the arithmetic
@ -115,16 +131,3 @@ module type S = sig
expressed as [e = f] for some monomial subterm [e] of [d]. If [for_]
is passed, then the subterm [e] must be [for_]. *)
end
(** A type [t] representing arithmetic terms over indeterminates [trm]
together with a functor [Make] that takes an [EMBEDDING] of arithmetic
terms [t] into indeterminates [trm] and produces an implementation of
the primary interface [S]. *)
module type REPRESENTATION = sig
type t [@@deriving compare, equal, sexp]
type var
type trm
module Make (_ : EMBEDDING with type trm := trm and type t := t) :
S with type trm := trm with type t := t
end

@ -123,9 +123,7 @@ end = struct
[not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for
[ys xs]. *)
let is_valid_eq xs e f =
let is_var_in xs e =
Option.exists ~f:(fun x -> Var.Set.mem x xs) (Var.of_trm e)
in
let is_var_in xs e = Trm.Set.mem e (xs : Var.Set.t :> Trm.Set.t) in
let noninterp_with_solvable_var_in xs e =
is_var_in xs e
|| Theory.is_noninterpreted e
@ -922,10 +920,11 @@ let trim ks x =
Cls.add rep (Option.value cls0 ~default:Cls.empty) ) )
in
(* enumerate expanded classes and update solution subst *)
let kills = Trm.Set.of_vars ks in
Trm.Map.fold clss x ~f:(fun ~key:a' ~data:ecls x ->
(* remove mappings for non-rep class elements to kill *)
let keep, drop = Trm.Set.diff_inter (Cls.to_set ecls) kills in
let keep, drop =
Trm.Set.diff_inter (Cls.to_set ecls) (ks : Var.Set.t :> Trm.Set.t)
in
if Trm.Set.is_empty drop then x
else
let rep = Trm.Set.fold ~f:Subst.remove drop x.rep in

@ -196,7 +196,7 @@ let solve d e s =
| Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) ->
solve_concat a0V b n s
| Some ((Var _ as v), (Concat a0V as c)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then
if not (Trm.Set.mem v (Trm.fv c :> Trm.Set.t)) then
(* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *)
add_solved ~var:v ~rep:c s
else
@ -212,7 +212,7 @@ let solve d e s =
* Extract
*)
| Some ((Var _ as v), (Extract {len= l} as e)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then
if not (Trm.Set.mem v (Trm.fv e :> Trm.Set.t)) then
(* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *)
add_solved ~var:v ~rep:e s
else

@ -7,45 +7,20 @@
(** Terms *)
(** Representation of Arithmetic terms *)
module rec Arith0 :
(Arithmetic.REPRESENTATION
with type var := Trm.Var1.t
with type trm := Trm.t) =
Arithmetic.Representation
(Trm.Var1)
(struct
include Trm
include Comparer.Make (Trm)
end)
(** Arithmetic terms *)
and Arith : (Arithmetic.S with type trm := Trm.t with type t = Arith0.t) =
struct
include Arith0
include Make (struct
let to_trm = Trm._Arith
(* Define term type using polymorphic arithmetic type, with derived compare,
equal, and sexp_of functions *)
module Trm1 = struct
type compare [@@deriving compare, equal, sexp]
let get_arith (e : Trm.t) =
match e with
| Z z -> Some (Arith.const (Q.of_z z))
| Q q -> Some (Arith.const q)
| Arith a -> Some a
| _ -> None
end)
end
type arith = (t, compare) Arithmetic.t
(** Terms, built from variables and applications of function symbols from
various theories. Denote functions from structures to values. *)
and Trm : sig
type t = private
and t =
(* variables *)
| Var of {id: int; name: string}
| Var of {id: int; name: string [@ignore]}
(* arithmetic *)
| Z of Z.t
| Q of Q.t
| Arith of Arith.t
| Arith of arith
(* sequences (of flexible size) *)
| Splat of t
| Sized of {seq: t; siz: t}
@ -54,50 +29,20 @@ and Trm : sig
(* uninterpreted *)
| Apply of Funsym.t * t array
[@@deriving compare, equal, sexp]
end
(** Variable terms, represented as a subtype of general terms *)
module Var1 : sig
type trm := t
include Var_intf.S with type t = private trm
val of_ : trm -> t
val of_trm : trm -> t option
(* Add comparer, needed to instantiate arithmetic and containers *)
module Trm2 = struct
include Comparer.Counterfeit (Trm1)
include Trm1
end
val ppx : Var1.strength -> t pp
val pp : t pp
include Invariant.S with type t := t
val _Var : int -> string -> t
val _Z : Z.t -> t
val _Q : Q.t -> t
val _Arith : Arith.t -> t
val _Splat : t -> t
val _Sized : seq:t -> siz:t -> t
val _Extract : seq:t -> off:t -> len:t -> t
val _Concat : t array -> t
val _Apply : Funsym.t -> t array -> t
val add : t -> t -> t
val sub : t -> t -> t
val seq_size_exn : t -> t
val seq_size : t -> t option
val get_z : t -> Z.t option
val get_q : t -> Q.t option
val vars : t -> Var1.t iter
end = struct
type t =
| Var of {id: int; name: string [@ignore]}
| Z of Z.t
| Q of Q.t
| Arith of Arith.t
| Splat of t
| Sized of {seq: t; siz: t}
| Extract of {seq: t; off: t; len: t}
| Concat of t array
| Apply of Funsym.t * t array
[@@deriving compare, equal, sexp]
(* Specialize arithmetic type and define operations using comparer *)
module Arith0 = Arithmetic.Make (Trm2)
(* Add ppx, defined recursively with Arith0.ppx *)
module Trm3 = struct
include Trm2
(* nul-terminated string value represented by a concatenation *)
let string_of_concat xs =
@ -136,7 +81,7 @@ end = struct
| Some `Anonymous -> Trace.pp_styled `Cyan "_" fs )
| Z z -> Trace.pp_styled `Magenta "%a" fs Z.pp z
| Q q -> Trace.pp_styled `Magenta "%a" fs Q.pp q
| Arith a -> Arith.ppx (ppx strength) fs a
| Arith a -> Arith0.ppx (ppx strength) fs a
| Splat x -> pf "%a^" pp x
| Sized {seq; siz} -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp seq
| Extract {seq; off; len} -> pf "%a[%a,%a)" pp seq pp off pp len
@ -157,19 +102,31 @@ end = struct
pp fs trm
let pp = ppx (fun _ -> None)
let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y
end
(* Define containers over terms *)
module Set = struct
include Set.Make (Trm3)
include Provide_of_sexp (Trm3)
include Provide_pp (Trm3)
end
module Map = struct
include Map.Make (Trm3)
include Provide_of_sexp (Trm3)
end
(* Define variables as a subtype of terms *)
module Var1 = struct
module Var = struct
open Trm3
module V = struct
module T = struct
type nonrec t = t [@@deriving compare, equal, sexp]
type strength = t -> [`Universal | `Existential | `Anonymous] option
let pp = pp
let ppx = ppx
end
include T
let invariant x =
let@ () = Invariant.invariant [%here] x [%sexp_of: t] in
@ -182,22 +139,16 @@ end = struct
let name = function Var v -> v.name | x -> violates invariant x
module Set = struct
module S = NS.Set.Make (T)
include S
include Provide_of_sexp (T)
include Provide_pp (T)
include Set
let ppx strength vs = S.pp_full (ppx strength) vs
let ppx strength vs = pp_full (ppx strength) vs
let pp_xs fs xs =
if not (is_empty xs) then
Format.fprintf fs "@<2>∃ @[%a@] .@;<1 2>" pp xs
end
module Map = struct
include NS.Map.Make (T)
include Provide_of_sexp (T)
end
module Map = Map
let fresh name ~wrt =
let max =
@ -221,24 +172,42 @@ end = struct
module Subst = Subst.Make (V)
end
(* Add definitions needed for arithmetic embedding into terms *)
module Trm = struct
include Trm3
(** Invariant *)
let invariant e =
let@ () = Invariant.invariant [%here] e [%sexp_of: t] in
match e with
| Q q -> assert (not (Z.equal Z.one (Q.den q)))
| Arith a -> (
match Arith.classify a with
match Arith0.classify a with
| Trm _ | Const _ -> assert false
| _ -> () )
| _ -> ()
(** Destruct *)
(** Traverse *)
let get_z = function Z z -> Some z | _ -> None
let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None
let rec iter_vars e ~f =
match e with
| Var _ as v -> f (Var.of_ v)
| Z _ | Q _ -> ()
| Splat x -> iter_vars ~f x
| Sized {seq= x; siz= y} ->
iter_vars ~f x ;
iter_vars ~f y
| Extract {seq= x; off= y; len= z} ->
iter_vars ~f x ;
iter_vars ~f y ;
iter_vars ~f z
| Concat xs | Apply (_, xs) -> Array.iter ~f:(iter_vars ~f) xs
| Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith0.trms a)
(** Construct *)
let vars e = Iter.from_labelled_iter (iter_vars e)
let _Var id name = Var {id; name} |> check invariant
(** Construct *)
(* statically allocated since they are tested with == *)
let zero = Z Z.zero |> check invariant
@ -253,16 +222,59 @@ end = struct
|> check invariant
let _Arith a =
( match Arith.classify a with
( match Arith0.classify a with
| Trm e -> e
| Const q -> _Q q
| _ -> Arith a )
|> check invariant
end
include Trm
(* Instantiate arithmetic with embedding into terms, yielding full
Arithmetic interface *)
module Arith =
Arith0.Embed (Var) (Trm)
(struct
let to_trm = _Arith
let get_arith e =
match e with
| Z z -> Some (Arith0.const (Q.of_z z))
| Q q -> Some (Arith0.const q)
| Arith a -> Some a
| _ -> None
end)
(* Full Trm definition, using full arithmetic interface *)
(** Destruct *)
let get_z = function Z z -> Some z | _ -> None
let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None
(** Construct *)
(* variables *)
let var v = (v : Var.t :> t)
(* arithmetic *)
let integer z = _Z z
let rational q = _Q q
let neg x = _Arith Arith.(neg (trm x))
let add x y = _Arith Arith.(add (trm x) (trm y))
let sub x y = _Arith Arith.(sub (trm x) (trm y))
let mulq q x = _Arith Arith.(mulc q (trm x))
let mul x y = _Arith (Arith.mul x y)
let div x y = _Arith (Arith.div x y)
let pow x i = _Arith (Arith.pow x i)
let arith = _Arith
(* sequences *)
let _Splat x =
let splat x =
(* 0^ ==> 0 *)
(if x == zero then x else Splat x) |> check invariant
@ -276,10 +288,9 @@ end = struct
in
seq_size_exn
let seq_size e =
try Some (seq_size_exn e) with Invalid_argument _ -> None
let seq_size e = try Some (seq_size_exn e) with Invalid_argument _ -> None
let _Sized ~seq ~siz =
let sized ~seq ~siz =
( match seq_size seq with
(* ⟨n,α⟩ ==> α when n ≡ |α| *)
| Some n when equal siz n -> seq
@ -297,7 +308,7 @@ end = struct
let empty_seq = Concat [||]
let rec _Extract ~seq ~off ~len =
let rec extract ~seq ~off ~len =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp (Extract {seq; off; len}))
~retn:(fun {pf} -> pf "%a" pp)
@ -309,13 +320,13 @@ end = struct
match seq with
(* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *)
| Extract {seq= a; off= m; len= k} when partial_ge k o_l ->
_Extract ~seq:a ~off:(add m off) ~len
extract ~seq:a ~off:(add m off) ~len
(* ⟨n,0⟩[o,l) ==> ⟨l,0⟩ when n ≥ o+l *)
| Sized {siz= n; seq} when seq == zero && partial_ge n o_l ->
_Sized ~seq ~siz:len
sized ~seq ~siz:len
(* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *)
| Sized {siz= n; seq= Splat _ as e} when partial_ge n o_l ->
_Sized ~seq:e ~siz:len
sized ~seq:e ~siz:len
(* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *)
| Sized {siz= n} when equal off zero && equal n len -> seq
(* For (α₀^α₁)[o,l) there are 3 cases:
@ -344,7 +355,7 @@ end = struct
Array.fold_map_until na1N (l, off)
~f:(fun naI (l, oI) ->
if Z.equal Z.zero l then
`Continue (_Extract ~seq:naI ~off:oI ~len:zero, (l, oI))
`Continue (extract ~seq:naI ~off:oI ~len:zero, (l, oI))
else
let nI = seq_size_exn naI in
let oI_nI = sub oI nI in
@ -354,15 +365,15 @@ end = struct
let lI = Z.(max zero (min l (neg z))) in
let l = Z.(l - lI) in
`Continue
(_Extract ~seq:naI ~off:oI ~len:(_Z lI), (l, oJ))
(extract ~seq:naI ~off:oI ~len:(_Z lI), (l, oJ))
| _ -> `Stop (Extract {seq; off; len}) )
~finish:(fun (e1N, _) -> _Concat e1N)
~finish:(fun (e1N, _) -> concat e1N)
| _ -> Extract {seq; off; len} )
(* α[o,l) *)
| _ -> Extract {seq; off; len} )
|> check invariant
and _Concat xs =
and concat xs =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp (Concat xs))
~retn:(fun {pf} -> pf "%a" pp)
@ -380,22 +391,24 @@ end = struct
, Extract {seq= na'; off= o_k; len= l} )
when equal na na' && equal o_k (add o k) && partial_ge n (add o_k l)
->
Some (_Extract ~seq:na ~off:o ~len:(add k l))
Some (extract ~seq:na ~off:o ~len:(add k l))
(* ⟨m,0⟩^⟨n,0⟩ ==> ⟨m+n,0⟩ *)
| Sized {siz= m; seq= a}, Sized {siz= n; seq= a'}
when a == zero && a' == zero ->
Some (_Sized ~seq:a ~siz:(add m n))
Some (sized ~seq:a ~siz:(add m n))
(* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *)
| Sized {siz= m; seq= Splat _ as a}, Sized {siz= n; seq= a'}
when equal a a' ->
Some (_Sized ~seq:a ~siz:(add m n))
Some (sized ~seq:a ~siz:(add m n))
| _ -> None
in
let xs = flatten xs in
let xs = Array.reduce_adjacent ~f:simp_adjacent xs in
(if Array.length xs = 1 then xs.(0) else Concat xs) |> check invariant
let _Apply f es =
(* uninterpreted *)
let apply f es =
( match Funsym.eval ~equal ~get_z ~ret_z:_Z ~get_q ~ret_q:_Q f es with
| Some c -> c
| None -> Apply (f, es) )
@ -403,81 +416,6 @@ end = struct
(** Traverse *)
let rec iter_vars e ~f =
match e with
| Var _ as v -> f (Var1.of_ v)
| Z _ | Q _ -> ()
| Splat x -> iter_vars ~f x
| Sized {seq= x; siz= y} ->
iter_vars ~f x ;
iter_vars ~f y
| Extract {seq= x; off= y; len= z} ->
iter_vars ~f x ;
iter_vars ~f y ;
iter_vars ~f z
| Concat xs | Apply (_, xs) -> Array.iter ~f:(iter_vars ~f) xs
| Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith.trms a)
let vars e = Iter.from_labelled_iter (iter_vars e)
end
include Trm
module Var = Var1
module Set = struct
include Set.Make (Trm)
include Provide_of_sexp (Trm)
include Provide_pp (Trm)
let of_vars : Var.Set.t -> t =
fun vs ->
of_iter
(Iter.map ~f:(fun v -> (v : Var.t :> Trm.t)) (Var.Set.to_iter vs))
end
module Map = struct
include Map.Make (Trm)
include Provide_of_sexp (Trm)
end
type arith = Arith.t
let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y
(** Construct *)
(* variables *)
let var v = (v : Var.t :> t)
(* arithmetic *)
let zero = _Z Z.zero
let one = _Z Z.one
let integer z = _Z z
let rational q = _Q q
let neg x = _Arith Arith.(neg (trm x))
let add = Trm.add
let sub = Trm.sub
let mulq q x = _Arith Arith.(mulc q (trm x))
let mul x y = _Arith (Arith.mul x y)
let div x y = _Arith (Arith.div x y)
let pow x i = _Arith (Arith.pow x i)
let arith = _Arith
(* sequences *)
let splat = _Splat
let sized = _Sized
let extract = _Extract
let concat elts = _Concat elts
(* uninterpreted *)
let apply sym args = _Apply sym args
(** Traverse *)
let trms = function
| Var _ | Z _ | Q _ -> Iter.empty
| Arith a -> Arith.trms a
@ -505,25 +443,25 @@ let rec map_vars e ~f =
| Var _ as v -> (f (Var.of_ v) : Var.t :> t)
| Z _ | Q _ -> e
| Arith a -> map1 (Arith.map ~f:(map_vars ~f)) e _Arith a
| Splat x -> map1 (map_vars ~f) e _Splat x
| Splat x -> map1 (map_vars ~f) e splat x
| Sized {seq; siz} ->
map2 (map_vars ~f) e (fun seq siz -> _Sized ~seq ~siz) seq siz
map2 (map_vars ~f) e (fun seq siz -> sized ~seq ~siz) seq siz
| Extract {seq; off; len} ->
map3 (map_vars ~f) e
(fun seq off len -> _Extract ~seq ~off ~len)
(fun seq off len -> extract ~seq ~off ~len)
seq off len
| Concat xs -> mapN (map_vars ~f) e _Concat xs
| Apply (g, xs) -> mapN (map_vars ~f) e (_Apply g) xs
| Concat xs -> mapN (map_vars ~f) e concat xs
| Apply (g, xs) -> mapN (map_vars ~f) e (apply g) xs
let map e ~f =
match e with
| Var _ | Z _ | Q _ -> e
| Arith a -> map1 (Arith.map ~f) e _Arith a
| Splat x -> map1 f e _Splat x
| Sized {seq; siz} -> map2 f e (fun seq siz -> _Sized ~seq ~siz) seq siz
| Splat x -> map1 f e splat x
| Sized {seq; siz} -> map2 f e (fun seq siz -> sized ~seq ~siz) seq siz
| Extract {seq; off; len} ->
map3 f e (fun seq off len -> _Extract ~seq ~off ~len) seq off len
| Concat xs -> mapN f e _Concat xs
| Apply (g, xs) -> mapN f e (_Apply g) xs
map3 f e (fun seq off len -> extract ~seq ~off ~len) seq off len
| Concat xs -> mapN f e concat xs
| Apply (g, xs) -> mapN f e (apply g) xs
let fold_map e = fold_map_from_map map e

@ -9,9 +9,11 @@
type arith
(** Terms, built from variables and applications of function symbols from
various theories. Denote functions from structures to values. *)
type t = private
(* variables *)
| Var of {id: int; name: string}
| Var of {id: int; name: string [@ignore]}
(* arithmetic *)
| Z of Z.t
| Q of Q.t
@ -25,24 +27,15 @@ type t = private
| Apply of Funsym.t * t array
[@@deriving compare, equal, sexp]
(** Arithmetic terms *)
module Arith : Arithmetic.S with type trm := t with type t = arith
module Var : sig
type trm := t
include Var_intf.S with type t = private trm
val of_ : trm -> t
val of_trm : trm -> t option
end
module Set : sig
include Set.S with type elt := t
val t_of_sexp : Sexp.t -> t
val pp : t pp
val pp_diff : (t * t) pp
val of_vars : Var.Set.t -> t
end
module Map : sig
@ -51,6 +44,16 @@ module Map : sig
val t_of_sexp : (Sexp.t -> 'a) -> Sexp.t -> 'a t
end
(** Variable terms, represented as a subtype of general terms *)
module Var : sig
type trm := t
include
Var_intf.S with type t = private trm with type Set.t = private Set.t
val of_trm : trm -> t option
end
val ppx : Var.strength -> t pp
val pp : t pp
val pp_diff : (t * t) pp

Loading…
Cancel
Save