[sledge] Rework term and arithmetic definitions to avoid recursive modules

Summary:
Terms include Arithmetic terms, which are polynomials over terms
themselves. Monomials are represented as maps from
terms (multiplicative factors) to integers (their powers). Polynomials
are represented as maps from monomials (indeterminates) to
rationals (coefficients). In particular, terms are represented using
maps whose keys are terms themselves. This is currently implemented
using recursive modules.

This diff uses the Comparer-based interface of Maps to express this
cycle as recursive *types* rather than recursive *modules*, see the
very beginning of trm.ml. The rest of the changes are driven by the
need to expose the Arithmetic.t type at toplevel, outside the functor
that defines the arithmetic operations, and changes to stage the
definition of term and polynomial operations to remove unnecessary
recursion.

One might hope that these changes are just moving code around, but due
to how recursive modules are implemented, this refactoring is
motivated by performance profiling. In every cycle between recursive
modules, at least one of the modules must be "safe". A "safe" module
is one where all exposed values have function type. This allows the
compiler to initialize that module with functions that immediately
raise an exception, define the other modules using it, and then tie
the recursive knot by backpatching the safe module with the actual
functions at the end. This implementation works, but has the
consequence that the compiler must treat calls to functions of safe
recursive modules as indirect calls to unknown functions. This means
that they are not inlined or even called by symbol, and instead
calling them involves spilling registers if needed, loading their
address from memory, calling them by address, and restoring any
spilled registers. For operations like Trm.compare that are a handful
of instructions on the hot path, this is a significant
difference. Since terms are the keys of maps and sets in the core of
the first-order equality solver, those map operations are very very
hot.

Reviewed By: jvillard

Differential Revision: D26250533

fbshipit-source-id: f79334c68
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 24ca0666d3
commit 5704b160bf

@ -19,17 +19,27 @@ module Q = struct
include Comparer.Make (Q)
end
module Representation
(Var : Var_intf.S)
(Trm : INDETERMINATE with type var := Var.t) =
type ('trm, 'compare_trm) mono = ('trm, int, 'compare_trm) Multiset.t
[@@deriving compare, equal, sexp]
type 'compare_trm compare_mono =
('compare_trm, Int.compare) Multiset.compare
[@@deriving compare, equal, sexp]
type ('trm, 'compare_trm) t =
(('trm, 'compare_trm) mono, Q.t, 'compare_trm compare_mono) Multiset.t
[@@deriving compare, equal, sexp]
module Make (Trm0 : sig
type t [@@deriving equal, sexp]
include Comparer.S with type t := t
end) =
struct
module Prod = struct
include Multiset.Make (Trm) (Int)
include Provide_of_sexp (Trm)
end
module Prod = Multiset.Make (Trm0) (Int)
module Mono = struct
type t = Prod.t [@@deriving compare, equal, sexp]
type t = Prod.t [@@deriving compare, equal, sexp_of]
let num_den m = Prod.partition m ~f:(fun _ i -> i >= 0)
@ -79,11 +89,7 @@ struct
Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm))
end
module Sum = struct
include Multiset.Make (Prod) (Q)
include Provide_of_sexp (Prod)
end
module Sum = Multiset.Make (Prod) (Q)
module Poly = Sum
include Poly
@ -104,6 +110,11 @@ struct
(Sum.pp "@ + " pp_coeff_mono)
poly
let trms poly =
Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ ->
Prod.iter mono ~f:(fun trm _ -> f trm) ) )
(* core invariant *)
let mono_invariant mono =
@ -119,10 +130,14 @@ struct
assert (not (Q.equal Q.zero coeff)) ;
mono_invariant mono )
(* embed a term into a polynomial *)
let trm trm = Sum.of_ (Mono.of_ trm 1) Q.one
(* constants *)
let const q = Sum.of_ Mono.one q |> check invariant
let zero = const Q.zero |> check (fun p -> assert (Sum.is_empty p))
let one = const Q.one
(* core constructors *)
@ -149,7 +164,7 @@ struct
(* projections and embeddings *)
type kind = Trm of Trm.t | Const of Q.t | Interpreted | Uninterpreted
type kind = Trm of Trm0.t | Const of Q.t | Interpreted | Uninterpreted
let classify poly =
match Sum.classify poly with
@ -189,7 +204,12 @@ struct
| None -> None
end
module Make (Embed : EMBEDDING with type trm := Trm.t and type t := t) =
include S0
module Embed
(Var : Var_intf.S)
(Trm : TRM with type t = Trm0.t with type var := Var.t)
(Embed : EMBEDDING with type trm := Trm0.t with type t := t) =
struct
module Mono = struct
include Mono
@ -202,16 +222,24 @@ struct
include Poly
include S0
(** hide S0.trm and S0.trms that ignore the embedding, shadowed below *)
let[@warning "-32"] trm, trms = ((), ())
let pp = ppx Trm.pp
(** Embed a polynomial into a term *)
let trm_of_poly = Embed.to_trm
(** Embed a monomial into a term, flattening if possible *)
let trm_of_mono mono =
match Mono.get_trm mono with
| Some trm -> trm
| None -> Embed.to_trm (Sum.of_ mono Q.one)
| None -> trm_of_poly (Sum.of_ mono Q.one)
(* traverse *)
let vars poly = Iter.flat_map ~f:Trm.vars (S0.trms poly)
let monos poly =
Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ ->
@ -222,10 +250,6 @@ struct
| Some mono -> Mono.trms mono
| None -> Iter.map ~f:trm_of_mono (monos poly)
let vars p = Iter.flat_map ~f:Trm.vars (trms p)
(* invariant *)
let mono_invariant mono =
mono_invariant mono ;
let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in
@ -309,7 +333,7 @@ struct
let trm trm =
( match Embed.get_arith trm with
| Some poly -> poly
| None -> Sum.of_ (Mono.of_ trm 1) Q.one )
| None -> S0.trm trm )
|> check (fun poly ->
assert (equal poly (CM.to_poly (CM.of_trm trm))) )
@ -322,7 +346,7 @@ struct
let pow base power = CM.to_poly (CM.of_trm base ~power)
(* map over [trms] *)
(** map over [trms] *)
let map poly ~f =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp poly)
@ -398,7 +422,7 @@ struct
| Some _ as soln -> soln
| None -> solve_poly (Sum.add mono coeff rejected) poly
(* solve [0 = e] *)
(** solve [0 = e] *)
let solve_zero_eq ?for_ e =
[%trace]
~call:(fun {pf} ->

@ -9,7 +9,25 @@
include module type of Arithmetic_intf
module Representation
(Var : Var_intf.S)
(Indeterminate : INDETERMINATE with type var := Var.t) :
REPRESENTATION with type var := Var.t with type trm := Indeterminate.t
(** Arithmetic terms, e.g. polynomials, polymorphic in the type of
indeterminates. *)
type ('trm, 'cmp) t [@@deriving compare, equal, sexp]
(** Functor that, given a totally ordered type of indeterminate terms,
builds an implementation of the embedding-independent arithmetic
operations, and a functor that, given an embedding of arithmetic terms
into indeterminate terms, builds an implementation of the arithmetic
operations. *)
module Make (Ord : sig
type t [@@deriving equal, sexp]
include Comparer.S with type t := t
end) : sig
include S0 with type t = (Ord.t, Ord.compare) t with type trm := Ord.t
module Embed
(Var : Var_intf.S)
(Trm : TRM with type t = Ord.t with type var := Var.t)
(_ : EMBEDDING with type trm := Trm.t and type t := t) :
S with type trm := Trm.t with type t = t
end

@ -14,31 +14,29 @@ module type EMBEDDING = sig
val to_trm : t -> trm
(** Embedding from [t] to [trm]: [to_trm a] is arithmetic term [a]
embedded in an indeterminate term. *)
embedded into an indeterminate term. *)
val get_arith : trm -> t option
(** Partial projection from [trm] to [t]: [get_arith x] is [Some a] if
[x = to_trm a]. This is used to flatten indeterminates that are
actually arithmetic for the client, thereby enabling arithmetic
operations to be interpreted more often. *)
(** Partial projection from [trm] to [t]: [get_arith x] is [Some a] iff
[x = to_trm a]. *)
end
(** Indeterminate terms, treated as atomic / variables except when they can
be flattened using [EMBEDDING.get_arith]. *)
module type INDETERMINATE = sig
type t [@@deriving compare, equal, sexp]
be flattened using {!EMBEDDING.get_arith}. *)
module type TRM = sig
include Comparer.S
include Comparer.S with type t := t
val pp : t pp
type var
val pp : t pp
val vars : t -> var iter
end
module type S = sig
(** Arithmetic terms, e.g. polynomials [t] over indeterminate terms [trm] *)
module type S0 = sig
type trm
type t [@@deriving compare, equal, sexp]
type t [@@deriving compare, equal]
val ppx : trm pp -> t pp
@ -48,8 +46,10 @@ module type S = sig
(** [trm x] represents the indeterminate term [x] *)
val const : Q.t -> t
(** [const q] represents the constant [q] *)
(** [const q] represents the rational constant [q] *)
val zero : t
val one : t
val neg : t -> t
val add : t -> t -> t
val sub : t -> t -> t
@ -84,6 +84,22 @@ module type S = sig
val is_uninterpreted : t -> bool
(** [is_uninterpreted a] iff [classify a = Uninterpreted] *)
(** Traverse *)
val trms : t -> trm iter
(** [trms a] enumerates the indeterminate terms appearing in [a].
Considering an arithmetic term as a polynomial,
[trms (c × (Σ c × Π
X^p))] is the sequence of terms [X] for each [i] and
[j]. *)
end
(** Arithmetic terms, where an embedding {!EMBEDDING.get_arith} into
indeterminate terms is used to implicitly flatten arithmetic terms that
are embedded into general terms to the underlying arithmetic term. *)
module type S = sig
include S0
(** Construct nonlinear arithmetic terms using the embedding, enabling
interpreting associativity, commutatitivity, and unit laws, but not
the full nonlinear arithmetic theory. *)
@ -95,8 +111,8 @@ module type S = sig
(** Traverse *)
val trms : t -> trm iter
(** [trms a] is the maximal foreign or noninterpreted proper subterms of
[a]. Considering an arithmetic term as a polynomial,
(** [trms a] enumerates the maximal foreign or noninterpreted proper
subterms of [a]. Considering an arithmetic term as a polynomial,
[trms (c × (Σ c × Π
X^p))] is the sequence of monomials
[Π X^p] for each [i]. If the arithmetic
@ -115,16 +131,3 @@ module type S = sig
expressed as [e = f] for some monomial subterm [e] of [d]. If [for_]
is passed, then the subterm [e] must be [for_]. *)
end
(** A type [t] representing arithmetic terms over indeterminates [trm]
together with a functor [Make] that takes an [EMBEDDING] of arithmetic
terms [t] into indeterminates [trm] and produces an implementation of
the primary interface [S]. *)
module type REPRESENTATION = sig
type t [@@deriving compare, equal, sexp]
type var
type trm
module Make (_ : EMBEDDING with type trm := trm and type t := t) :
S with type trm := trm with type t := t
end

@ -123,9 +123,7 @@ end = struct
[not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for
[ys xs]. *)
let is_valid_eq xs e f =
let is_var_in xs e =
Option.exists ~f:(fun x -> Var.Set.mem x xs) (Var.of_trm e)
in
let is_var_in xs e = Trm.Set.mem e (xs : Var.Set.t :> Trm.Set.t) in
let noninterp_with_solvable_var_in xs e =
is_var_in xs e
|| Theory.is_noninterpreted e
@ -922,10 +920,11 @@ let trim ks x =
Cls.add rep (Option.value cls0 ~default:Cls.empty) ) )
in
(* enumerate expanded classes and update solution subst *)
let kills = Trm.Set.of_vars ks in
Trm.Map.fold clss x ~f:(fun ~key:a' ~data:ecls x ->
(* remove mappings for non-rep class elements to kill *)
let keep, drop = Trm.Set.diff_inter (Cls.to_set ecls) kills in
let keep, drop =
Trm.Set.diff_inter (Cls.to_set ecls) (ks : Var.Set.t :> Trm.Set.t)
in
if Trm.Set.is_empty drop then x
else
let rep = Trm.Set.fold ~f:Subst.remove drop x.rep in

@ -196,7 +196,7 @@ let solve d e s =
| Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) ->
solve_concat a0V b n s
| Some ((Var _ as v), (Concat a0V as c)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then
if not (Trm.Set.mem v (Trm.fv c :> Trm.Set.t)) then
(* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *)
add_solved ~var:v ~rep:c s
else
@ -212,7 +212,7 @@ let solve d e s =
* Extract
*)
| Some ((Var _ as v), (Extract {len= l} as e)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then
if not (Trm.Set.mem v (Trm.fv e :> Trm.Set.t)) then
(* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *)
add_solved ~var:v ~rep:e s
else

@ -7,45 +7,20 @@
(** Terms *)
(** Representation of Arithmetic terms *)
module rec Arith0 :
(Arithmetic.REPRESENTATION
with type var := Trm.Var1.t
with type trm := Trm.t) =
Arithmetic.Representation
(Trm.Var1)
(struct
include Trm
include Comparer.Make (Trm)
end)
(* Define term type using polymorphic arithmetic type, with derived compare,
equal, and sexp_of functions *)
module Trm1 = struct
type compare [@@deriving compare, equal, sexp]
(** Arithmetic terms *)
and Arith : (Arithmetic.S with type trm := Trm.t with type t = Arith0.t) =
struct
include Arith0
include Make (struct
let to_trm = Trm._Arith
let get_arith (e : Trm.t) =
match e with
| Z z -> Some (Arith.const (Q.of_z z))
| Q q -> Some (Arith.const q)
| Arith a -> Some a
| _ -> None
end)
end
type arith = (t, compare) Arithmetic.t
(** Terms, built from variables and applications of function symbols from
various theories. Denote functions from structures to values. *)
and Trm : sig
type t = private
and t =
(* variables *)
| Var of {id: int; name: string}
| Var of {id: int; name: string [@ignore]}
(* arithmetic *)
| Z of Z.t
| Q of Q.t
| Arith of Arith.t
| Arith of arith
(* sequences (of flexible size) *)
| Splat of t
| Sized of {seq: t; siz: t}
@ -54,50 +29,20 @@ and Trm : sig
(* uninterpreted *)
| Apply of Funsym.t * t array
[@@deriving compare, equal, sexp]
end
(** Variable terms, represented as a subtype of general terms *)
module Var1 : sig
type trm := t
include Var_intf.S with type t = private trm
(* Add comparer, needed to instantiate arithmetic and containers *)
module Trm2 = struct
include Comparer.Counterfeit (Trm1)
include Trm1
end
val of_ : trm -> t
val of_trm : trm -> t option
end
(* Specialize arithmetic type and define operations using comparer *)
module Arith0 = Arithmetic.Make (Trm2)
val ppx : Var1.strength -> t pp
val pp : t pp
include Invariant.S with type t := t
val _Var : int -> string -> t
val _Z : Z.t -> t
val _Q : Q.t -> t
val _Arith : Arith.t -> t
val _Splat : t -> t
val _Sized : seq:t -> siz:t -> t
val _Extract : seq:t -> off:t -> len:t -> t
val _Concat : t array -> t
val _Apply : Funsym.t -> t array -> t
val add : t -> t -> t
val sub : t -> t -> t
val seq_size_exn : t -> t
val seq_size : t -> t option
val get_z : t -> Z.t option
val get_q : t -> Q.t option
val vars : t -> Var1.t iter
end = struct
type t =
| Var of {id: int; name: string [@ignore]}
| Z of Z.t
| Q of Q.t
| Arith of Arith.t
| Splat of t
| Sized of {seq: t; siz: t}
| Extract of {seq: t; off: t; len: t}
| Concat of t array
| Apply of Funsym.t * t array
[@@deriving compare, equal, sexp]
(* Add ppx, defined recursively with Arith0.ppx *)
module Trm3 = struct
include Trm2
(* nul-terminated string value represented by a concatenation *)
let string_of_concat xs =
@ -136,7 +81,7 @@ end = struct
| Some `Anonymous -> Trace.pp_styled `Cyan "_" fs )
| Z z -> Trace.pp_styled `Magenta "%a" fs Z.pp z
| Q q -> Trace.pp_styled `Magenta "%a" fs Q.pp q
| Arith a -> Arith.ppx (ppx strength) fs a
| Arith a -> Arith0.ppx (ppx strength) fs a
| Splat x -> pf "%a^" pp x
| Sized {seq; siz} -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp seq
| Extract {seq; off; len} -> pf "%a[%a,%a)" pp seq pp off pp len
@ -157,255 +102,97 @@ end = struct
pp fs trm
let pp = ppx (fun _ -> None)
let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y
end
(* Define variables as a subtype of terms *)
module Var1 = struct
module V = struct
module T = struct
type nonrec t = t [@@deriving compare, equal, sexp]
type strength = t -> [`Universal | `Existential | `Anonymous] option
let pp = pp
let ppx = ppx
end
include T
let invariant x =
let@ () = Invariant.invariant [%here] x [%sexp_of: t] in
match x with
| Var _ -> ()
| _ -> fail "non-var: %a" Sexp.pp_hum (sexp_of_t x) ()
let make ~id ~name = Var {id; name} |> check invariant
let id = function Var v -> v.id | x -> violates invariant x
let name = function Var v -> v.name | x -> violates invariant x
module Set = struct
module S = NS.Set.Make (T)
include S
include Provide_of_sexp (T)
include Provide_pp (T)
let ppx strength vs = S.pp_full (ppx strength) vs
let pp_xs fs xs =
if not (is_empty xs) then
Format.fprintf fs "@<2>∃ @[%a@] .@;<1 2>" pp xs
end
module Map = struct
include NS.Map.Make (T)
include Provide_of_sexp (T)
end
let fresh name ~wrt =
let max =
match Set.max_elt wrt with None -> 0 | Some m -> max 0 (id m)
in
let x' = make ~id:(max + 1) ~name in
(x', Set.add x' wrt)
let freshen v ~wrt = fresh (name v) ~wrt
let program ?(name = "") ~id =
assert (id > 0) ;
make ~id:(-id) ~name
let identified ~name ~id = make ~id ~name
let of_ v = v |> check invariant
let of_trm = function Var _ as v -> Some v | _ -> None
end
(* Define containers over terms *)
module Set = struct
include Set.Make (Trm3)
include Provide_of_sexp (Trm3)
include Provide_pp (Trm3)
end
include V
module Subst = Subst.Make (V)
end
module Map = struct
include Map.Make (Trm3)
include Provide_of_sexp (Trm3)
end
let invariant e =
let@ () = Invariant.invariant [%here] e [%sexp_of: t] in
match e with
| Q q -> assert (not (Z.equal Z.one (Q.den q)))
| Arith a -> (
match Arith.classify a with
| Trm _ | Const _ -> assert false
| _ -> () )
| _ -> ()
(* Define variables as a subtype of terms *)
module Var = struct
open Trm3
(** Destruct *)
module V = struct
type nonrec t = t [@@deriving compare, equal, sexp]
type strength = t -> [`Universal | `Existential | `Anonymous] option
let get_z = function Z z -> Some z | _ -> None
let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None
let pp = pp
let ppx = ppx
(** Construct *)
let invariant x =
let@ () = Invariant.invariant [%here] x [%sexp_of: t] in
match x with
| Var _ -> ()
| _ -> fail "non-var: %a" Sexp.pp_hum (sexp_of_t x) ()
let _Var id name = Var {id; name} |> check invariant
let make ~id ~name = Var {id; name} |> check invariant
let id = function Var v -> v.id | x -> violates invariant x
let name = function Var v -> v.name | x -> violates invariant x
(* statically allocated since they are tested with == *)
let zero = Z Z.zero |> check invariant
let one = Z Z.one |> check invariant
module Set = struct
include Set
let _Z z =
(if Z.equal Z.zero z then zero else if Z.equal Z.one z then one else Z z)
|> check invariant
let ppx strength vs = pp_full (ppx strength) vs
let _Q q =
(if Z.equal Z.one (Q.den q) then _Z (Q.num q) else Q q)
|> check invariant
let pp_xs fs xs =
if not (is_empty xs) then
Format.fprintf fs "@<2>∃ @[%a@] .@;<1 2>" pp xs
end
let _Arith a =
( match Arith.classify a with
| Trm e -> e
| Const q -> _Q q
| _ -> Arith a )
|> check invariant
module Map = Map
let add x y = _Arith Arith.(add (trm x) (trm y))
let sub x y = _Arith Arith.(sub (trm x) (trm y))
let fresh name ~wrt =
let max =
match Set.max_elt wrt with None -> 0 | Some m -> max 0 (id m)
in
let x' = make ~id:(max + 1) ~name in
(x', Set.add x' wrt)
let _Splat x =
(* 0^ ==> 0 *)
(if x == zero then x else Splat x) |> check invariant
let freshen v ~wrt = fresh (name v) ~wrt
let seq_size_exn =
let invalid = Invalid_argument "seq_size_exn" in
let rec seq_size_exn = function
| Sized {siz= n} | Extract {len= n} -> n
| Concat a0U ->
Array.fold ~f:(fun aJ a0I -> add a0I (seq_size_exn aJ)) a0U zero
| _ -> raise invalid
in
seq_size_exn
let program ?(name = "") ~id =
assert (id > 0) ;
make ~id:(-id) ~name
let seq_size e =
try Some (seq_size_exn e) with Invalid_argument _ -> None
let identified ~name ~id = make ~id ~name
let of_ v = v |> check invariant
let of_trm = function Var _ as v -> Some v | _ -> None
end
let _Sized ~seq ~siz =
( match seq_size seq with
(* ⟨n,α⟩ ==> α when n ≡ |α| *)
| Some n when equal siz n -> seq
| _ -> Sized {seq; siz} )
|> check invariant
include V
module Subst = Subst.Make (V)
end
let partial_compare x y =
match sub x y with
| Z z -> Some (Int.sign (Z.sign z))
| Q q -> Some (Int.sign (Q.sign q))
| _ -> None
(* Add definitions needed for arithmetic embedding into terms *)
module Trm = struct
include Trm3
let partial_ge x y =
match partial_compare x y with Some (Pos | Zero) -> true | _ -> false
let empty_seq = Concat [||]
let rec _Extract ~seq ~off ~len =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp (Extract {seq; off; len}))
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
(* _[_,0) ==> ⟨⟩ *)
( if equal len zero then empty_seq
else
let o_l = add off len in
match seq with
(* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *)
| Extract {seq= a; off= m; len= k} when partial_ge k o_l ->
_Extract ~seq:a ~off:(add m off) ~len
(* ⟨n,0⟩[o,l) ==> ⟨l,0⟩ when n ≥ o+l *)
| Sized {siz= n; seq} when seq == zero && partial_ge n o_l ->
_Sized ~seq ~siz:len
(* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *)
| Sized {siz= n; seq= Splat _ as e} when partial_ge n o_l ->
_Sized ~seq:e ~siz:len
(* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *)
| Sized {siz= n} when equal off zero && equal n len -> seq
(* For (α₀^α₁)[o,l) there are 3 cases:
*
* ...^...
* [,)
* o < o+l |α| : (α^α)[o,l) ==> α[o,l) ^ α[0,0)
*
* ...^...
* [ , )
* o |α| < o+l : (α^α)[o,l) ==> α[o,|α|-o) ^ α[0,l-(|α|-o))
*
* ...^...
* [,)
* |α| o : (α^α)[o,l) ==> α[o,0) ^ α[o-|α|,l)
*
* So in general:
*
* (α^α)[o,l) ==> α[o,l) ^ α[o,l-l)
* where l = max 0 (min l |α|-o)
* o = max 0 o-|α|
*)
| Concat na1N -> (
match len with
| Z l ->
Array.fold_map_until na1N (l, off)
~f:(fun naI (l, oI) ->
if Z.equal Z.zero l then
`Continue (_Extract ~seq:naI ~off:oI ~len:zero, (l, oI))
else
let nI = seq_size_exn naI in
let oI_nI = sub oI nI in
match oI_nI with
| Z z ->
let oJ = if Z.sign z <= 0 then zero else oI_nI in
let lI = Z.(max zero (min l (neg z))) in
let l = Z.(l - lI) in
`Continue
(_Extract ~seq:naI ~off:oI ~len:(_Z lI), (l, oJ))
| _ -> `Stop (Extract {seq; off; len}) )
~finish:(fun (e1N, _) -> _Concat e1N)
| _ -> Extract {seq; off; len} )
(* α[o,l) *)
| _ -> Extract {seq; off; len} )
|> check invariant
(** Invariant *)
and _Concat xs =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp (Concat xs))
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
(* (α^(β^γ)) ==> (α^β^γ) *)
let flatten xs =
if Array.exists ~f:(function Concat _ -> true | _ -> false) xs then
Array.flat_map ~f:(function Concat s -> s | e -> [|e|]) xs
else xs
in
let simp_adjacent e f =
match (e, f) with
(* ⟨n,a⟩[o,k)^⟨n,a⟩[o+k,l) ==> ⟨n,a⟩[o,k+l) when n ≥ o+k+l *)
| ( Extract {seq= Sized {siz= n} as na; off= o; len= k}
, Extract {seq= na'; off= o_k; len= l} )
when equal na na' && equal o_k (add o k) && partial_ge n (add o_k l)
->
Some (_Extract ~seq:na ~off:o ~len:(add k l))
(* ⟨m,0⟩^⟨n,0⟩ ==> ⟨m+n,0⟩ *)
| Sized {siz= m; seq= a}, Sized {siz= n; seq= a'}
when a == zero && a' == zero ->
Some (_Sized ~seq:a ~siz:(add m n))
(* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *)
| Sized {siz= m; seq= Splat _ as a}, Sized {siz= n; seq= a'}
when equal a a' ->
Some (_Sized ~seq:a ~siz:(add m n))
| _ -> None
in
let xs = flatten xs in
let xs = Array.reduce_adjacent ~f:simp_adjacent xs in
(if Array.length xs = 1 then xs.(0) else Concat xs) |> check invariant
let _Apply f es =
( match Funsym.eval ~equal ~get_z ~ret_z:_Z ~get_q ~ret_q:_Q f es with
| Some c -> c
| None -> Apply (f, es) )
|> check invariant
let invariant e =
let@ () = Invariant.invariant [%here] e [%sexp_of: t] in
match e with
| Q q -> assert (not (Z.equal Z.one (Q.den q)))
| Arith a -> (
match Arith0.classify a with
| Trm _ | Const _ -> assert false
| _ -> () )
| _ -> ()
(** Traverse *)
let rec iter_vars e ~f =
match e with
| Var _ as v -> f (Var1.of_ v)
| Var _ as v -> f (Var.of_ v)
| Z _ | Q _ -> ()
| Splat x -> iter_vars ~f x
| Sized {seq= x; siz= y} ->
@ -416,33 +203,55 @@ end = struct
iter_vars ~f y ;
iter_vars ~f z
| Concat xs | Apply (_, xs) -> Array.iter ~f:(iter_vars ~f) xs
| Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith.trms a)
| Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith0.trms a)
let vars e = Iter.from_labelled_iter (iter_vars e)
(** Construct *)
(* statically allocated since they are tested with == *)
let zero = Z Z.zero |> check invariant
let one = Z Z.one |> check invariant
let _Z z =
(if Z.equal Z.zero z then zero else if Z.equal Z.one z then one else Z z)
|> check invariant
let _Q q =
(if Z.equal Z.one (Q.den q) then _Z (Q.num q) else Q q)
|> check invariant
let _Arith a =
( match Arith0.classify a with
| Trm e -> e
| Const q -> _Q q
| _ -> Arith a )
|> check invariant
end
include Trm
module Var = Var1
module Set = struct
include Set.Make (Trm)
include Provide_of_sexp (Trm)
include Provide_pp (Trm)
let of_vars : Var.Set.t -> t =
fun vs ->
of_iter
(Iter.map ~f:(fun v -> (v : Var.t :> Trm.t)) (Var.Set.to_iter vs))
end
(* Instantiate arithmetic with embedding into terms, yielding full
Arithmetic interface *)
module Arith =
Arith0.Embed (Var) (Trm)
(struct
let to_trm = _Arith
let get_arith e =
match e with
| Z z -> Some (Arith0.const (Q.of_z z))
| Q q -> Some (Arith0.const q)
| Arith a -> Some a
| _ -> None
end)
module Map = struct
include Map.Make (Trm)
include Provide_of_sexp (Trm)
end
(* Full Trm definition, using full arithmetic interface *)
type arith = Arith.t
(** Destruct *)
let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y
let get_z = function Z z -> Some z | _ -> None
let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None
(** Construct *)
@ -452,13 +261,11 @@ let var v = (v : Var.t :> t)
(* arithmetic *)
let zero = _Z Z.zero
let one = _Z Z.one
let integer z = _Z z
let rational q = _Q q
let neg x = _Arith Arith.(neg (trm x))
let add = Trm.add
let sub = Trm.sub
let add x y = _Arith Arith.(add (trm x) (trm y))
let sub x y = _Arith Arith.(sub (trm x) (trm y))
let mulq q x = _Arith Arith.(mulc q (trm x))
let mul x y = _Arith (Arith.mul x y)
let div x y = _Arith (Arith.div x y)
@ -467,14 +274,145 @@ let arith = _Arith
(* sequences *)
let splat = _Splat
let sized = _Sized
let extract = _Extract
let concat elts = _Concat elts
let splat x =
(* 0^ ==> 0 *)
(if x == zero then x else Splat x) |> check invariant
let seq_size_exn =
let invalid = Invalid_argument "seq_size_exn" in
let rec seq_size_exn = function
| Sized {siz= n} | Extract {len= n} -> n
| Concat a0U ->
Array.fold ~f:(fun aJ a0I -> add a0I (seq_size_exn aJ)) a0U zero
| _ -> raise invalid
in
seq_size_exn
let seq_size e = try Some (seq_size_exn e) with Invalid_argument _ -> None
let sized ~seq ~siz =
( match seq_size seq with
(* ⟨n,α⟩ ==> α when n ≡ |α| *)
| Some n when equal siz n -> seq
| _ -> Sized {seq; siz} )
|> check invariant
let partial_compare x y =
match sub x y with
| Z z -> Some (Int.sign (Z.sign z))
| Q q -> Some (Int.sign (Q.sign q))
| _ -> None
let partial_ge x y =
match partial_compare x y with Some (Pos | Zero) -> true | _ -> false
let empty_seq = Concat [||]
let rec extract ~seq ~off ~len =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp (Extract {seq; off; len}))
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
(* _[_,0) ==> ⟨⟩ *)
( if equal len zero then empty_seq
else
let o_l = add off len in
match seq with
(* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *)
| Extract {seq= a; off= m; len= k} when partial_ge k o_l ->
extract ~seq:a ~off:(add m off) ~len
(* ⟨n,0⟩[o,l) ==> ⟨l,0⟩ when n ≥ o+l *)
| Sized {siz= n; seq} when seq == zero && partial_ge n o_l ->
sized ~seq ~siz:len
(* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *)
| Sized {siz= n; seq= Splat _ as e} when partial_ge n o_l ->
sized ~seq:e ~siz:len
(* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *)
| Sized {siz= n} when equal off zero && equal n len -> seq
(* For (α₀^α₁)[o,l) there are 3 cases:
*
* ...^...
* [,)
* o < o+l |α| : (α^α)[o,l) ==> α[o,l) ^ α[0,0)
*
* ...^...
* [ , )
* o |α| < o+l : (α^α)[o,l) ==> α[o,|α|-o) ^ α[0,l-(|α|-o))
*
* ...^...
* [,)
* |α| o : (α^α)[o,l) ==> α[o,0) ^ α[o-|α|,l)
*
* So in general:
*
* (α^α)[o,l) ==> α[o,l) ^ α[o,l-l)
* where l = max 0 (min l |α|-o)
* o = max 0 o-|α|
*)
| Concat na1N -> (
match len with
| Z l ->
Array.fold_map_until na1N (l, off)
~f:(fun naI (l, oI) ->
if Z.equal Z.zero l then
`Continue (extract ~seq:naI ~off:oI ~len:zero, (l, oI))
else
let nI = seq_size_exn naI in
let oI_nI = sub oI nI in
match oI_nI with
| Z z ->
let oJ = if Z.sign z <= 0 then zero else oI_nI in
let lI = Z.(max zero (min l (neg z))) in
let l = Z.(l - lI) in
`Continue
(extract ~seq:naI ~off:oI ~len:(_Z lI), (l, oJ))
| _ -> `Stop (Extract {seq; off; len}) )
~finish:(fun (e1N, _) -> concat e1N)
| _ -> Extract {seq; off; len} )
(* α[o,l) *)
| _ -> Extract {seq; off; len} )
|> check invariant
and concat xs =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp (Concat xs))
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
(* (α^(β^γ)) ==> (α^β^γ) *)
let flatten xs =
if Array.exists ~f:(function Concat _ -> true | _ -> false) xs then
Array.flat_map ~f:(function Concat s -> s | e -> [|e|]) xs
else xs
in
let simp_adjacent e f =
match (e, f) with
(* ⟨n,a⟩[o,k)^⟨n,a⟩[o+k,l) ==> ⟨n,a⟩[o,k+l) when n ≥ o+k+l *)
| ( Extract {seq= Sized {siz= n} as na; off= o; len= k}
, Extract {seq= na'; off= o_k; len= l} )
when equal na na' && equal o_k (add o k) && partial_ge n (add o_k l)
->
Some (extract ~seq:na ~off:o ~len:(add k l))
(* ⟨m,0⟩^⟨n,0⟩ ==> ⟨m+n,0⟩ *)
| Sized {siz= m; seq= a}, Sized {siz= n; seq= a'}
when a == zero && a' == zero ->
Some (sized ~seq:a ~siz:(add m n))
(* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *)
| Sized {siz= m; seq= Splat _ as a}, Sized {siz= n; seq= a'}
when equal a a' ->
Some (sized ~seq:a ~siz:(add m n))
| _ -> None
in
let xs = flatten xs in
let xs = Array.reduce_adjacent ~f:simp_adjacent xs in
(if Array.length xs = 1 then xs.(0) else Concat xs) |> check invariant
(* uninterpreted *)
let apply sym args = _Apply sym args
let apply f es =
( match Funsym.eval ~equal ~get_z ~ret_z:_Z ~get_q ~ret_q:_Q f es with
| Some c -> c
| None -> Apply (f, es) )
|> check invariant
(** Traverse *)
@ -505,25 +443,25 @@ let rec map_vars e ~f =
| Var _ as v -> (f (Var.of_ v) : Var.t :> t)
| Z _ | Q _ -> e
| Arith a -> map1 (Arith.map ~f:(map_vars ~f)) e _Arith a
| Splat x -> map1 (map_vars ~f) e _Splat x
| Splat x -> map1 (map_vars ~f) e splat x
| Sized {seq; siz} ->
map2 (map_vars ~f) e (fun seq siz -> _Sized ~seq ~siz) seq siz
map2 (map_vars ~f) e (fun seq siz -> sized ~seq ~siz) seq siz
| Extract {seq; off; len} ->
map3 (map_vars ~f) e
(fun seq off len -> _Extract ~seq ~off ~len)
(fun seq off len -> extract ~seq ~off ~len)
seq off len
| Concat xs -> mapN (map_vars ~f) e _Concat xs
| Apply (g, xs) -> mapN (map_vars ~f) e (_Apply g) xs
| Concat xs -> mapN (map_vars ~f) e concat xs
| Apply (g, xs) -> mapN (map_vars ~f) e (apply g) xs
let map e ~f =
match e with
| Var _ | Z _ | Q _ -> e
| Arith a -> map1 (Arith.map ~f) e _Arith a
| Splat x -> map1 f e _Splat x
| Sized {seq; siz} -> map2 f e (fun seq siz -> _Sized ~seq ~siz) seq siz
| Splat x -> map1 f e splat x
| Sized {seq; siz} -> map2 f e (fun seq siz -> sized ~seq ~siz) seq siz
| Extract {seq; off; len} ->
map3 f e (fun seq off len -> _Extract ~seq ~off ~len) seq off len
| Concat xs -> mapN f e _Concat xs
| Apply (g, xs) -> mapN f e (_Apply g) xs
map3 f e (fun seq off len -> extract ~seq ~off ~len) seq off len
| Concat xs -> mapN f e concat xs
| Apply (g, xs) -> mapN f e (apply g) xs
let fold_map e = fold_map_from_map map e

@ -9,9 +9,11 @@
type arith
(** Terms, built from variables and applications of function symbols from
various theories. Denote functions from structures to values. *)
type t = private
(* variables *)
| Var of {id: int; name: string}
| Var of {id: int; name: string [@ignore]}
(* arithmetic *)
| Z of Z.t
| Q of Q.t
@ -25,24 +27,15 @@ type t = private
| Apply of Funsym.t * t array
[@@deriving compare, equal, sexp]
(** Arithmetic terms *)
module Arith : Arithmetic.S with type trm := t with type t = arith
module Var : sig
type trm := t
include Var_intf.S with type t = private trm
val of_ : trm -> t
val of_trm : trm -> t option
end
module Set : sig
include Set.S with type elt := t
val t_of_sexp : Sexp.t -> t
val pp : t pp
val pp_diff : (t * t) pp
val of_vars : Var.Set.t -> t
end
module Map : sig
@ -51,6 +44,16 @@ module Map : sig
val t_of_sexp : (Sexp.t -> 'a) -> Sexp.t -> 'a t
end
(** Variable terms, represented as a subtype of general terms *)
module Var : sig
type trm := t
include
Var_intf.S with type t = private trm with type Set.t = private Set.t
val of_trm : trm -> t option
end
val ppx : Var.strength -> t pp
val pp : t pp
val pp_diff : (t * t) pp

Loading…
Cancel
Save