[sledge] Rework term and arithmetic definitions to avoid recursive modules

Summary:
Terms include Arithmetic terms, which are polynomials over terms
themselves. Monomials are represented as maps from
terms (multiplicative factors) to integers (their powers). Polynomials
are represented as maps from monomials (indeterminates) to
rationals (coefficients). In particular, terms are represented using
maps whose keys are terms themselves. This is currently implemented
using recursive modules.

This diff uses the Comparer-based interface of Maps to express this
cycle as recursive *types* rather than recursive *modules*, see the
very beginning of trm.ml. The rest of the changes are driven by the
need to expose the Arithmetic.t type at toplevel, outside the functor
that defines the arithmetic operations, and changes to stage the
definition of term and polynomial operations to remove unnecessary
recursion.

One might hope that these changes are just moving code around, but due
to how recursive modules are implemented, this refactoring is
motivated by performance profiling. In every cycle between recursive
modules, at least one of the modules must be "safe". A "safe" module
is one where all exposed values have function type. This allows the
compiler to initialize that module with functions that immediately
raise an exception, define the other modules using it, and then tie
the recursive knot by backpatching the safe module with the actual
functions at the end. This implementation works, but has the
consequence that the compiler must treat calls to functions of safe
recursive modules as indirect calls to unknown functions. This means
that they are not inlined or even called by symbol, and instead
calling them involves spilling registers if needed, loading their
address from memory, calling them by address, and restoring any
spilled registers. For operations like Trm.compare that are a handful
of instructions on the hot path, this is a significant
difference. Since terms are the keys of maps and sets in the core of
the first-order equality solver, those map operations are very very
hot.

Reviewed By: jvillard

Differential Revision: D26250533

fbshipit-source-id: f79334c68
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent 24ca0666d3
commit 5704b160bf

@ -19,17 +19,27 @@ module Q = struct
include Comparer.Make (Q) include Comparer.Make (Q)
end end
module Representation type ('trm, 'compare_trm) mono = ('trm, int, 'compare_trm) Multiset.t
(Var : Var_intf.S) [@@deriving compare, equal, sexp]
(Trm : INDETERMINATE with type var := Var.t) =
type 'compare_trm compare_mono =
('compare_trm, Int.compare) Multiset.compare
[@@deriving compare, equal, sexp]
type ('trm, 'compare_trm) t =
(('trm, 'compare_trm) mono, Q.t, 'compare_trm compare_mono) Multiset.t
[@@deriving compare, equal, sexp]
module Make (Trm0 : sig
type t [@@deriving equal, sexp]
include Comparer.S with type t := t
end) =
struct struct
module Prod = struct module Prod = Multiset.Make (Trm0) (Int)
include Multiset.Make (Trm) (Int)
include Provide_of_sexp (Trm)
end
module Mono = struct module Mono = struct
type t = Prod.t [@@deriving compare, equal, sexp] type t = Prod.t [@@deriving compare, equal, sexp_of]
let num_den m = Prod.partition m ~f:(fun _ i -> i >= 0) let num_den m = Prod.partition m ~f:(fun _ i -> i >= 0)
@ -79,11 +89,7 @@ struct
Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm)) Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm))
end end
module Sum = struct module Sum = Multiset.Make (Prod) (Q)
include Multiset.Make (Prod) (Q)
include Provide_of_sexp (Prod)
end
module Poly = Sum module Poly = Sum
include Poly include Poly
@ -104,6 +110,11 @@ struct
(Sum.pp "@ + " pp_coeff_mono) (Sum.pp "@ + " pp_coeff_mono)
poly poly
let trms poly =
Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ ->
Prod.iter mono ~f:(fun trm _ -> f trm) ) )
(* core invariant *) (* core invariant *)
let mono_invariant mono = let mono_invariant mono =
@ -119,10 +130,14 @@ struct
assert (not (Q.equal Q.zero coeff)) ; assert (not (Q.equal Q.zero coeff)) ;
mono_invariant mono ) mono_invariant mono )
(* embed a term into a polynomial *)
let trm trm = Sum.of_ (Mono.of_ trm 1) Q.one
(* constants *) (* constants *)
let const q = Sum.of_ Mono.one q |> check invariant let const q = Sum.of_ Mono.one q |> check invariant
let zero = const Q.zero |> check (fun p -> assert (Sum.is_empty p)) let zero = const Q.zero |> check (fun p -> assert (Sum.is_empty p))
let one = const Q.one
(* core constructors *) (* core constructors *)
@ -149,7 +164,7 @@ struct
(* projections and embeddings *) (* projections and embeddings *)
type kind = Trm of Trm.t | Const of Q.t | Interpreted | Uninterpreted type kind = Trm of Trm0.t | Const of Q.t | Interpreted | Uninterpreted
let classify poly = let classify poly =
match Sum.classify poly with match Sum.classify poly with
@ -189,7 +204,12 @@ struct
| None -> None | None -> None
end end
module Make (Embed : EMBEDDING with type trm := Trm.t and type t := t) = include S0
module Embed
(Var : Var_intf.S)
(Trm : TRM with type t = Trm0.t with type var := Var.t)
(Embed : EMBEDDING with type trm := Trm0.t with type t := t) =
struct struct
module Mono = struct module Mono = struct
include Mono include Mono
@ -202,16 +222,24 @@ struct
include Poly include Poly
include S0 include S0
(** hide S0.trm and S0.trms that ignore the embedding, shadowed below *)
let[@warning "-32"] trm, trms = ((), ())
let pp = ppx Trm.pp let pp = ppx Trm.pp
(** Embed a polynomial into a term *)
let trm_of_poly = Embed.to_trm
(** Embed a monomial into a term, flattening if possible *) (** Embed a monomial into a term, flattening if possible *)
let trm_of_mono mono = let trm_of_mono mono =
match Mono.get_trm mono with match Mono.get_trm mono with
| Some trm -> trm | Some trm -> trm
| None -> Embed.to_trm (Sum.of_ mono Q.one) | None -> trm_of_poly (Sum.of_ mono Q.one)
(* traverse *) (* traverse *)
let vars poly = Iter.flat_map ~f:Trm.vars (S0.trms poly)
let monos poly = let monos poly =
Iter.from_iter (fun f -> Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ -> Sum.iter poly ~f:(fun mono _ ->
@ -222,10 +250,6 @@ struct
| Some mono -> Mono.trms mono | Some mono -> Mono.trms mono
| None -> Iter.map ~f:trm_of_mono (monos poly) | None -> Iter.map ~f:trm_of_mono (monos poly)
let vars p = Iter.flat_map ~f:Trm.vars (trms p)
(* invariant *)
let mono_invariant mono = let mono_invariant mono =
mono_invariant mono ; mono_invariant mono ;
let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in
@ -309,7 +333,7 @@ struct
let trm trm = let trm trm =
( match Embed.get_arith trm with ( match Embed.get_arith trm with
| Some poly -> poly | Some poly -> poly
| None -> Sum.of_ (Mono.of_ trm 1) Q.one ) | None -> S0.trm trm )
|> check (fun poly -> |> check (fun poly ->
assert (equal poly (CM.to_poly (CM.of_trm trm))) ) assert (equal poly (CM.to_poly (CM.of_trm trm))) )
@ -322,7 +346,7 @@ struct
let pow base power = CM.to_poly (CM.of_trm base ~power) let pow base power = CM.to_poly (CM.of_trm base ~power)
(* map over [trms] *) (** map over [trms] *)
let map poly ~f = let map poly ~f =
[%trace] [%trace]
~call:(fun {pf} -> pf "@ %a" pp poly) ~call:(fun {pf} -> pf "@ %a" pp poly)
@ -398,7 +422,7 @@ struct
| Some _ as soln -> soln | Some _ as soln -> soln
| None -> solve_poly (Sum.add mono coeff rejected) poly | None -> solve_poly (Sum.add mono coeff rejected) poly
(* solve [0 = e] *) (** solve [0 = e] *)
let solve_zero_eq ?for_ e = let solve_zero_eq ?for_ e =
[%trace] [%trace]
~call:(fun {pf} -> ~call:(fun {pf} ->

@ -9,7 +9,25 @@
include module type of Arithmetic_intf include module type of Arithmetic_intf
module Representation (** Arithmetic terms, e.g. polynomials, polymorphic in the type of
(Var : Var_intf.S) indeterminates. *)
(Indeterminate : INDETERMINATE with type var := Var.t) : type ('trm, 'cmp) t [@@deriving compare, equal, sexp]
REPRESENTATION with type var := Var.t with type trm := Indeterminate.t
(** Functor that, given a totally ordered type of indeterminate terms,
builds an implementation of the embedding-independent arithmetic
operations, and a functor that, given an embedding of arithmetic terms
into indeterminate terms, builds an implementation of the arithmetic
operations. *)
module Make (Ord : sig
type t [@@deriving equal, sexp]
include Comparer.S with type t := t
end) : sig
include S0 with type t = (Ord.t, Ord.compare) t with type trm := Ord.t
module Embed
(Var : Var_intf.S)
(Trm : TRM with type t = Ord.t with type var := Var.t)
(_ : EMBEDDING with type trm := Trm.t and type t := t) :
S with type trm := Trm.t with type t = t
end

@ -14,31 +14,29 @@ module type EMBEDDING = sig
val to_trm : t -> trm val to_trm : t -> trm
(** Embedding from [t] to [trm]: [to_trm a] is arithmetic term [a] (** Embedding from [t] to [trm]: [to_trm a] is arithmetic term [a]
embedded in an indeterminate term. *) embedded into an indeterminate term. *)
val get_arith : trm -> t option val get_arith : trm -> t option
(** Partial projection from [trm] to [t]: [get_arith x] is [Some a] if (** Partial projection from [trm] to [t]: [get_arith x] is [Some a] iff
[x = to_trm a]. This is used to flatten indeterminates that are [x = to_trm a]. *)
actually arithmetic for the client, thereby enabling arithmetic
operations to be interpreted more often. *)
end end
(** Indeterminate terms, treated as atomic / variables except when they can (** Indeterminate terms, treated as atomic / variables except when they can
be flattened using [EMBEDDING.get_arith]. *) be flattened using {!EMBEDDING.get_arith}. *)
module type INDETERMINATE = sig module type TRM = sig
type t [@@deriving compare, equal, sexp] include Comparer.S
include Comparer.S with type t := t val pp : t pp
type var type var
val pp : t pp
val vars : t -> var iter val vars : t -> var iter
end end
module type S = sig (** Arithmetic terms, e.g. polynomials [t] over indeterminate terms [trm] *)
module type S0 = sig
type trm type trm
type t [@@deriving compare, equal, sexp] type t [@@deriving compare, equal]
val ppx : trm pp -> t pp val ppx : trm pp -> t pp
@ -48,8 +46,10 @@ module type S = sig
(** [trm x] represents the indeterminate term [x] *) (** [trm x] represents the indeterminate term [x] *)
val const : Q.t -> t val const : Q.t -> t
(** [const q] represents the constant [q] *) (** [const q] represents the rational constant [q] *)
val zero : t
val one : t
val neg : t -> t val neg : t -> t
val add : t -> t -> t val add : t -> t -> t
val sub : t -> t -> t val sub : t -> t -> t
@ -84,6 +84,22 @@ module type S = sig
val is_uninterpreted : t -> bool val is_uninterpreted : t -> bool
(** [is_uninterpreted a] iff [classify a = Uninterpreted] *) (** [is_uninterpreted a] iff [classify a = Uninterpreted] *)
(** Traverse *)
val trms : t -> trm iter
(** [trms a] enumerates the indeterminate terms appearing in [a].
Considering an arithmetic term as a polynomial,
[trms (c × (Σ c × Π
X^p))] is the sequence of terms [X] for each [i] and
[j]. *)
end
(** Arithmetic terms, where an embedding {!EMBEDDING.get_arith} into
indeterminate terms is used to implicitly flatten arithmetic terms that
are embedded into general terms to the underlying arithmetic term. *)
module type S = sig
include S0
(** Construct nonlinear arithmetic terms using the embedding, enabling (** Construct nonlinear arithmetic terms using the embedding, enabling
interpreting associativity, commutatitivity, and unit laws, but not interpreting associativity, commutatitivity, and unit laws, but not
the full nonlinear arithmetic theory. *) the full nonlinear arithmetic theory. *)
@ -95,8 +111,8 @@ module type S = sig
(** Traverse *) (** Traverse *)
val trms : t -> trm iter val trms : t -> trm iter
(** [trms a] is the maximal foreign or noninterpreted proper subterms of (** [trms a] enumerates the maximal foreign or noninterpreted proper
[a]. Considering an arithmetic term as a polynomial, subterms of [a]. Considering an arithmetic term as a polynomial,
[trms (c × (Σ c × Π [trms (c × (Σ c × Π
X^p))] is the sequence of monomials X^p))] is the sequence of monomials
[Π X^p] for each [i]. If the arithmetic [Π X^p] for each [i]. If the arithmetic
@ -115,16 +131,3 @@ module type S = sig
expressed as [e = f] for some monomial subterm [e] of [d]. If [for_] expressed as [e = f] for some monomial subterm [e] of [d]. If [for_]
is passed, then the subterm [e] must be [for_]. *) is passed, then the subterm [e] must be [for_]. *)
end end
(** A type [t] representing arithmetic terms over indeterminates [trm]
together with a functor [Make] that takes an [EMBEDDING] of arithmetic
terms [t] into indeterminates [trm] and produces an implementation of
the primary interface [S]. *)
module type REPRESENTATION = sig
type t [@@deriving compare, equal, sexp]
type var
type trm
module Make (_ : EMBEDDING with type trm := trm and type t := t) :
S with type trm := trm with type t := t
end

@ -123,9 +123,7 @@ end = struct
[not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for [not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for
[ys xs]. *) [ys xs]. *)
let is_valid_eq xs e f = let is_valid_eq xs e f =
let is_var_in xs e = let is_var_in xs e = Trm.Set.mem e (xs : Var.Set.t :> Trm.Set.t) in
Option.exists ~f:(fun x -> Var.Set.mem x xs) (Var.of_trm e)
in
let noninterp_with_solvable_var_in xs e = let noninterp_with_solvable_var_in xs e =
is_var_in xs e is_var_in xs e
|| Theory.is_noninterpreted e || Theory.is_noninterpreted e
@ -922,10 +920,11 @@ let trim ks x =
Cls.add rep (Option.value cls0 ~default:Cls.empty) ) ) Cls.add rep (Option.value cls0 ~default:Cls.empty) ) )
in in
(* enumerate expanded classes and update solution subst *) (* enumerate expanded classes and update solution subst *)
let kills = Trm.Set.of_vars ks in
Trm.Map.fold clss x ~f:(fun ~key:a' ~data:ecls x -> Trm.Map.fold clss x ~f:(fun ~key:a' ~data:ecls x ->
(* remove mappings for non-rep class elements to kill *) (* remove mappings for non-rep class elements to kill *)
let keep, drop = Trm.Set.diff_inter (Cls.to_set ecls) kills in let keep, drop =
Trm.Set.diff_inter (Cls.to_set ecls) (ks : Var.Set.t :> Trm.Set.t)
in
if Trm.Set.is_empty drop then x if Trm.Set.is_empty drop then x
else else
let rep = Trm.Set.fold ~f:Subst.remove drop x.rep in let rep = Trm.Set.fold ~f:Subst.remove drop x.rep in

@ -196,7 +196,7 @@ let solve d e s =
| Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) -> | Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) ->
solve_concat a0V b n s solve_concat a0V b n s
| Some ((Var _ as v), (Concat a0V as c)) -> | Some ((Var _ as v), (Concat a0V as c)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then if not (Trm.Set.mem v (Trm.fv c :> Trm.Set.t)) then
(* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *) (* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *)
add_solved ~var:v ~rep:c s add_solved ~var:v ~rep:c s
else else
@ -212,7 +212,7 @@ let solve d e s =
* Extract * Extract
*) *)
| Some ((Var _ as v), (Extract {len= l} as e)) -> | Some ((Var _ as v), (Extract {len= l} as e)) ->
if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then if not (Trm.Set.mem v (Trm.fv e :> Trm.Set.t)) then
(* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *) (* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *)
add_solved ~var:v ~rep:e s add_solved ~var:v ~rep:e s
else else

@ -7,45 +7,20 @@
(** Terms *) (** Terms *)
(** Representation of Arithmetic terms *) (* Define term type using polymorphic arithmetic type, with derived compare,
module rec Arith0 : equal, and sexp_of functions *)
(Arithmetic.REPRESENTATION module Trm1 = struct
with type var := Trm.Var1.t type compare [@@deriving compare, equal, sexp]
with type trm := Trm.t) =
Arithmetic.Representation
(Trm.Var1)
(struct
include Trm
include Comparer.Make (Trm)
end)
(** Arithmetic terms *) type arith = (t, compare) Arithmetic.t
and Arith : (Arithmetic.S with type trm := Trm.t with type t = Arith0.t) =
struct
include Arith0
include Make (struct
let to_trm = Trm._Arith
let get_arith (e : Trm.t) =
match e with
| Z z -> Some (Arith.const (Q.of_z z))
| Q q -> Some (Arith.const q)
| Arith a -> Some a
| _ -> None
end)
end
(** Terms, built from variables and applications of function symbols from and t =
various theories. Denote functions from structures to values. *)
and Trm : sig
type t = private
(* variables *) (* variables *)
| Var of {id: int; name: string} | Var of {id: int; name: string [@ignore]}
(* arithmetic *) (* arithmetic *)
| Z of Z.t | Z of Z.t
| Q of Q.t | Q of Q.t
| Arith of Arith.t | Arith of arith
(* sequences (of flexible size) *) (* sequences (of flexible size) *)
| Splat of t | Splat of t
| Sized of {seq: t; siz: t} | Sized of {seq: t; siz: t}
@ -54,50 +29,20 @@ and Trm : sig
(* uninterpreted *) (* uninterpreted *)
| Apply of Funsym.t * t array | Apply of Funsym.t * t array
[@@deriving compare, equal, sexp] [@@deriving compare, equal, sexp]
end
(** Variable terms, represented as a subtype of general terms *) (* Add comparer, needed to instantiate arithmetic and containers *)
module Var1 : sig module Trm2 = struct
type trm := t include Comparer.Counterfeit (Trm1)
include Trm1
include Var_intf.S with type t = private trm end
val of_ : trm -> t (* Specialize arithmetic type and define operations using comparer *)
val of_trm : trm -> t option module Arith0 = Arithmetic.Make (Trm2)
end
val ppx : Var1.strength -> t pp (* Add ppx, defined recursively with Arith0.ppx *)
val pp : t pp module Trm3 = struct
include Trm2
include Invariant.S with type t := t
val _Var : int -> string -> t
val _Z : Z.t -> t
val _Q : Q.t -> t
val _Arith : Arith.t -> t
val _Splat : t -> t
val _Sized : seq:t -> siz:t -> t
val _Extract : seq:t -> off:t -> len:t -> t
val _Concat : t array -> t
val _Apply : Funsym.t -> t array -> t
val add : t -> t -> t
val sub : t -> t -> t
val seq_size_exn : t -> t
val seq_size : t -> t option
val get_z : t -> Z.t option
val get_q : t -> Q.t option
val vars : t -> Var1.t iter
end = struct
type t =
| Var of {id: int; name: string [@ignore]}
| Z of Z.t
| Q of Q.t
| Arith of Arith.t
| Splat of t
| Sized of {seq: t; siz: t}
| Extract of {seq: t; off: t; len: t}
| Concat of t array
| Apply of Funsym.t * t array
[@@deriving compare, equal, sexp]
(* nul-terminated string value represented by a concatenation *) (* nul-terminated string value represented by a concatenation *)
let string_of_concat xs = let string_of_concat xs =
@ -136,7 +81,7 @@ end = struct
| Some `Anonymous -> Trace.pp_styled `Cyan "_" fs ) | Some `Anonymous -> Trace.pp_styled `Cyan "_" fs )
| Z z -> Trace.pp_styled `Magenta "%a" fs Z.pp z | Z z -> Trace.pp_styled `Magenta "%a" fs Z.pp z
| Q q -> Trace.pp_styled `Magenta "%a" fs Q.pp q | Q q -> Trace.pp_styled `Magenta "%a" fs Q.pp q
| Arith a -> Arith.ppx (ppx strength) fs a | Arith a -> Arith0.ppx (ppx strength) fs a
| Splat x -> pf "%a^" pp x | Splat x -> pf "%a^" pp x
| Sized {seq; siz} -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp seq | Sized {seq; siz} -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp seq
| Extract {seq; off; len} -> pf "%a[%a,%a)" pp seq pp off pp len | Extract {seq; off; len} -> pf "%a[%a,%a)" pp seq pp off pp len
@ -157,255 +102,97 @@ end = struct
pp fs trm pp fs trm
let pp = ppx (fun _ -> None) let pp = ppx (fun _ -> None)
let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y
end
(* Define variables as a subtype of terms *) (* Define containers over terms *)
module Var1 = struct module Set = struct
module V = struct include Set.Make (Trm3)
module T = struct include Provide_of_sexp (Trm3)
type nonrec t = t [@@deriving compare, equal, sexp] include Provide_pp (Trm3)
type strength = t -> [`Universal | `Existential | `Anonymous] option end
let pp = pp
let ppx = ppx
end
include T
let invariant x =
let@ () = Invariant.invariant [%here] x [%sexp_of: t] in
match x with
| Var _ -> ()
| _ -> fail "non-var: %a" Sexp.pp_hum (sexp_of_t x) ()
let make ~id ~name = Var {id; name} |> check invariant
let id = function Var v -> v.id | x -> violates invariant x
let name = function Var v -> v.name | x -> violates invariant x
module Set = struct
module S = NS.Set.Make (T)
include S
include Provide_of_sexp (T)
include Provide_pp (T)
let ppx strength vs = S.pp_full (ppx strength) vs
let pp_xs fs xs =
if not (is_empty xs) then
Format.fprintf fs "@<2>∃ @[%a@] .@;<1 2>" pp xs
end
module Map = struct
include NS.Map.Make (T)
include Provide_of_sexp (T)
end
let fresh name ~wrt =
let max =
match Set.max_elt wrt with None -> 0 | Some m -> max 0 (id m)
in
let x' = make ~id:(max + 1) ~name in
(x', Set.add x' wrt)
let freshen v ~wrt = fresh (name v) ~wrt
let program ?(name = "") ~id =
assert (id > 0) ;
make ~id:(-id) ~name
let identified ~name ~id = make ~id ~name
let of_ v = v |> check invariant
let of_trm = function Var _ as v -> Some v | _ -> None
end
include V module Map = struct
module Subst = Subst.Make (V) include Map.Make (Trm3)
end include Provide_of_sexp (Trm3)
end
let invariant e = (* Define variables as a subtype of terms *)
let@ () = Invariant.invariant [%here] e [%sexp_of: t] in module Var = struct
match e with open Trm3
| Q q -> assert (not (Z.equal Z.one (Q.den q)))
| Arith a -> (
match Arith.classify a with
| Trm _ | Const _ -> assert false
| _ -> () )
| _ -> ()
(** Destruct *) module V = struct
type nonrec t = t [@@deriving compare, equal, sexp]
type strength = t -> [`Universal | `Existential | `Anonymous] option
let get_z = function Z z -> Some z | _ -> None let pp = pp
let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None let ppx = ppx
(** Construct *) let invariant x =
let@ () = Invariant.invariant [%here] x [%sexp_of: t] in
match x with
| Var _ -> ()
| _ -> fail "non-var: %a" Sexp.pp_hum (sexp_of_t x) ()
let _Var id name = Var {id; name} |> check invariant let make ~id ~name = Var {id; name} |> check invariant
let id = function Var v -> v.id | x -> violates invariant x
let name = function Var v -> v.name | x -> violates invariant x
(* statically allocated since they are tested with == *) module Set = struct
let zero = Z Z.zero |> check invariant include Set
let one = Z Z.one |> check invariant
let _Z z = let ppx strength vs = pp_full (ppx strength) vs
(if Z.equal Z.zero z then zero else if Z.equal Z.one z then one else Z z)
|> check invariant
let _Q q = let pp_xs fs xs =
(if Z.equal Z.one (Q.den q) then _Z (Q.num q) else Q q) if not (is_empty xs) then
|> check invariant Format.fprintf fs "@<2>∃ @[%a@] .@;<1 2>" pp xs
end
let _Arith a = module Map = Map
( match Arith.classify a with
| Trm e -> e
| Const q -> _Q q
| _ -> Arith a )
|> check invariant
let add x y = _Arith Arith.(add (trm x) (trm y)) let fresh name ~wrt =
let sub x y = _Arith Arith.(sub (trm x) (trm y)) let max =
match Set.max_elt wrt with None -> 0 | Some m -> max 0 (id m)
in
let x' = make ~id:(max + 1) ~name in
(x', Set.add x' wrt)
let _Splat x = let freshen v ~wrt = fresh (name v) ~wrt
(* 0^ ==> 0 *)
(if x == zero then x else Splat x) |> check invariant
let seq_size_exn = let program ?(name = "") ~id =
let invalid = Invalid_argument "seq_size_exn" in assert (id > 0) ;
let rec seq_size_exn = function make ~id:(-id) ~name
| Sized {siz= n} | Extract {len= n} -> n
| Concat a0U ->
Array.fold ~f:(fun aJ a0I -> add a0I (seq_size_exn aJ)) a0U zero
| _ -> raise invalid
in
seq_size_exn
let seq_size e = let identified ~name ~id = make ~id ~name
try Some (seq_size_exn e) with Invalid_argument _ -> None let of_ v = v |> check invariant
let of_trm = function Var _ as v -> Some v | _ -> None
end
let _Sized ~seq ~siz = include V
( match seq_size seq with module Subst = Subst.Make (V)
(* ⟨n,α⟩ ==> α when n ≡ |α| *) end
| Some n when equal siz n -> seq
| _ -> Sized {seq; siz} )
|> check invariant
let partial_compare x y = (* Add definitions needed for arithmetic embedding into terms *)
match sub x y with module Trm = struct
| Z z -> Some (Int.sign (Z.sign z)) include Trm3
| Q q -> Some (Int.sign (Q.sign q))
| _ -> None
let partial_ge x y = (** Invariant *)
match partial_compare x y with Some (Pos | Zero) -> true | _ -> false
let empty_seq = Concat [||]
let rec _Extract ~seq ~off ~len =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp (Extract {seq; off; len}))
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
(* _[_,0) ==> ⟨⟩ *)
( if equal len zero then empty_seq
else
let o_l = add off len in
match seq with
(* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *)
| Extract {seq= a; off= m; len= k} when partial_ge k o_l ->
_Extract ~seq:a ~off:(add m off) ~len
(* ⟨n,0⟩[o,l) ==> ⟨l,0⟩ when n ≥ o+l *)
| Sized {siz= n; seq} when seq == zero && partial_ge n o_l ->
_Sized ~seq ~siz:len
(* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *)
| Sized {siz= n; seq= Splat _ as e} when partial_ge n o_l ->
_Sized ~seq:e ~siz:len
(* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *)
| Sized {siz= n} when equal off zero && equal n len -> seq
(* For (α₀^α₁)[o,l) there are 3 cases:
*
* ...^...
* [,)
* o < o+l |α| : (α^α)[o,l) ==> α[o,l) ^ α[0,0)
*
* ...^...
* [ , )
* o |α| < o+l : (α^α)[o,l) ==> α[o,|α|-o) ^ α[0,l-(|α|-o))
*
* ...^...
* [,)
* |α| o : (α^α)[o,l) ==> α[o,0) ^ α[o-|α|,l)
*
* So in general:
*
* (α^α)[o,l) ==> α[o,l) ^ α[o,l-l)
* where l = max 0 (min l |α|-o)
* o = max 0 o-|α|
*)
| Concat na1N -> (
match len with
| Z l ->
Array.fold_map_until na1N (l, off)
~f:(fun naI (l, oI) ->
if Z.equal Z.zero l then
`Continue (_Extract ~seq:naI ~off:oI ~len:zero, (l, oI))
else
let nI = seq_size_exn naI in
let oI_nI = sub oI nI in
match oI_nI with
| Z z ->
let oJ = if Z.sign z <= 0 then zero else oI_nI in
let lI = Z.(max zero (min l (neg z))) in
let l = Z.(l - lI) in
`Continue
(_Extract ~seq:naI ~off:oI ~len:(_Z lI), (l, oJ))
| _ -> `Stop (Extract {seq; off; len}) )
~finish:(fun (e1N, _) -> _Concat e1N)
| _ -> Extract {seq; off; len} )
(* α[o,l) *)
| _ -> Extract {seq; off; len} )
|> check invariant
and _Concat xs = let invariant e =
[%trace] let@ () = Invariant.invariant [%here] e [%sexp_of: t] in
~call:(fun {pf} -> pf "@ %a" pp (Concat xs)) match e with
~retn:(fun {pf} -> pf "%a" pp) | Q q -> assert (not (Z.equal Z.one (Q.den q)))
@@ fun () -> | Arith a -> (
(* (α^(β^γ)) ==> (α^β^γ) *) match Arith0.classify a with
let flatten xs = | Trm _ | Const _ -> assert false
if Array.exists ~f:(function Concat _ -> true | _ -> false) xs then | _ -> () )
Array.flat_map ~f:(function Concat s -> s | e -> [|e|]) xs | _ -> ()
else xs
in
let simp_adjacent e f =
match (e, f) with
(* ⟨n,a⟩[o,k)^⟨n,a⟩[o+k,l) ==> ⟨n,a⟩[o,k+l) when n ≥ o+k+l *)
| ( Extract {seq= Sized {siz= n} as na; off= o; len= k}
, Extract {seq= na'; off= o_k; len= l} )
when equal na na' && equal o_k (add o k) && partial_ge n (add o_k l)
->
Some (_Extract ~seq:na ~off:o ~len:(add k l))
(* ⟨m,0⟩^⟨n,0⟩ ==> ⟨m+n,0⟩ *)
| Sized {siz= m; seq= a}, Sized {siz= n; seq= a'}
when a == zero && a' == zero ->
Some (_Sized ~seq:a ~siz:(add m n))
(* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *)
| Sized {siz= m; seq= Splat _ as a}, Sized {siz= n; seq= a'}
when equal a a' ->
Some (_Sized ~seq:a ~siz:(add m n))
| _ -> None
in
let xs = flatten xs in
let xs = Array.reduce_adjacent ~f:simp_adjacent xs in
(if Array.length xs = 1 then xs.(0) else Concat xs) |> check invariant
let _Apply f es =
( match Funsym.eval ~equal ~get_z ~ret_z:_Z ~get_q ~ret_q:_Q f es with
| Some c -> c
| None -> Apply (f, es) )
|> check invariant
(** Traverse *) (** Traverse *)
let rec iter_vars e ~f = let rec iter_vars e ~f =
match e with match e with
| Var _ as v -> f (Var1.of_ v) | Var _ as v -> f (Var.of_ v)
| Z _ | Q _ -> () | Z _ | Q _ -> ()
| Splat x -> iter_vars ~f x | Splat x -> iter_vars ~f x
| Sized {seq= x; siz= y} -> | Sized {seq= x; siz= y} ->
@ -416,33 +203,55 @@ end = struct
iter_vars ~f y ; iter_vars ~f y ;
iter_vars ~f z iter_vars ~f z
| Concat xs | Apply (_, xs) -> Array.iter ~f:(iter_vars ~f) xs | Concat xs | Apply (_, xs) -> Array.iter ~f:(iter_vars ~f) xs
| Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith.trms a) | Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith0.trms a)
let vars e = Iter.from_labelled_iter (iter_vars e) let vars e = Iter.from_labelled_iter (iter_vars e)
(** Construct *)
(* statically allocated since they are tested with == *)
let zero = Z Z.zero |> check invariant
let one = Z Z.one |> check invariant
let _Z z =
(if Z.equal Z.zero z then zero else if Z.equal Z.one z then one else Z z)
|> check invariant
let _Q q =
(if Z.equal Z.one (Q.den q) then _Z (Q.num q) else Q q)
|> check invariant
let _Arith a =
( match Arith0.classify a with
| Trm e -> e
| Const q -> _Q q
| _ -> Arith a )
|> check invariant
end end
include Trm include Trm
module Var = Var1
module Set = struct (* Instantiate arithmetic with embedding into terms, yielding full
include Set.Make (Trm) Arithmetic interface *)
include Provide_of_sexp (Trm) module Arith =
include Provide_pp (Trm) Arith0.Embed (Var) (Trm)
(struct
let of_vars : Var.Set.t -> t = let to_trm = _Arith
fun vs ->
of_iter let get_arith e =
(Iter.map ~f:(fun v -> (v : Var.t :> Trm.t)) (Var.Set.to_iter vs)) match e with
end | Z z -> Some (Arith0.const (Q.of_z z))
| Q q -> Some (Arith0.const q)
| Arith a -> Some a
| _ -> None
end)
module Map = struct (* Full Trm definition, using full arithmetic interface *)
include Map.Make (Trm)
include Provide_of_sexp (Trm)
end
type arith = Arith.t (** Destruct *)
let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y let get_z = function Z z -> Some z | _ -> None
let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None
(** Construct *) (** Construct *)
@ -452,13 +261,11 @@ let var v = (v : Var.t :> t)
(* arithmetic *) (* arithmetic *)
let zero = _Z Z.zero
let one = _Z Z.one
let integer z = _Z z let integer z = _Z z
let rational q = _Q q let rational q = _Q q
let neg x = _Arith Arith.(neg (trm x)) let neg x = _Arith Arith.(neg (trm x))
let add = Trm.add let add x y = _Arith Arith.(add (trm x) (trm y))
let sub = Trm.sub let sub x y = _Arith Arith.(sub (trm x) (trm y))
let mulq q x = _Arith Arith.(mulc q (trm x)) let mulq q x = _Arith Arith.(mulc q (trm x))
let mul x y = _Arith (Arith.mul x y) let mul x y = _Arith (Arith.mul x y)
let div x y = _Arith (Arith.div x y) let div x y = _Arith (Arith.div x y)
@ -467,14 +274,145 @@ let arith = _Arith
(* sequences *) (* sequences *)
let splat = _Splat let splat x =
let sized = _Sized (* 0^ ==> 0 *)
let extract = _Extract (if x == zero then x else Splat x) |> check invariant
let concat elts = _Concat elts
let seq_size_exn =
let invalid = Invalid_argument "seq_size_exn" in
let rec seq_size_exn = function
| Sized {siz= n} | Extract {len= n} -> n
| Concat a0U ->
Array.fold ~f:(fun aJ a0I -> add a0I (seq_size_exn aJ)) a0U zero
| _ -> raise invalid
in
seq_size_exn
let seq_size e = try Some (seq_size_exn e) with Invalid_argument _ -> None
let sized ~seq ~siz =
( match seq_size seq with
(* ⟨n,α⟩ ==> α when n ≡ |α| *)
| Some n when equal siz n -> seq
| _ -> Sized {seq; siz} )
|> check invariant
let partial_compare x y =
match sub x y with
| Z z -> Some (Int.sign (Z.sign z))
| Q q -> Some (Int.sign (Q.sign q))
| _ -> None
let partial_ge x y =
match partial_compare x y with Some (Pos | Zero) -> true | _ -> false
let empty_seq = Concat [||]
let rec extract ~seq ~off ~len =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp (Extract {seq; off; len}))
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
(* _[_,0) ==> ⟨⟩ *)
( if equal len zero then empty_seq
else
let o_l = add off len in
match seq with
(* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *)
| Extract {seq= a; off= m; len= k} when partial_ge k o_l ->
extract ~seq:a ~off:(add m off) ~len
(* ⟨n,0⟩[o,l) ==> ⟨l,0⟩ when n ≥ o+l *)
| Sized {siz= n; seq} when seq == zero && partial_ge n o_l ->
sized ~seq ~siz:len
(* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *)
| Sized {siz= n; seq= Splat _ as e} when partial_ge n o_l ->
sized ~seq:e ~siz:len
(* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *)
| Sized {siz= n} when equal off zero && equal n len -> seq
(* For (α₀^α₁)[o,l) there are 3 cases:
*
* ...^...
* [,)
* o < o+l |α| : (α^α)[o,l) ==> α[o,l) ^ α[0,0)
*
* ...^...
* [ , )
* o |α| < o+l : (α^α)[o,l) ==> α[o,|α|-o) ^ α[0,l-(|α|-o))
*
* ...^...
* [,)
* |α| o : (α^α)[o,l) ==> α[o,0) ^ α[o-|α|,l)
*
* So in general:
*
* (α^α)[o,l) ==> α[o,l) ^ α[o,l-l)
* where l = max 0 (min l |α|-o)
* o = max 0 o-|α|
*)
| Concat na1N -> (
match len with
| Z l ->
Array.fold_map_until na1N (l, off)
~f:(fun naI (l, oI) ->
if Z.equal Z.zero l then
`Continue (extract ~seq:naI ~off:oI ~len:zero, (l, oI))
else
let nI = seq_size_exn naI in
let oI_nI = sub oI nI in
match oI_nI with
| Z z ->
let oJ = if Z.sign z <= 0 then zero else oI_nI in
let lI = Z.(max zero (min l (neg z))) in
let l = Z.(l - lI) in
`Continue
(extract ~seq:naI ~off:oI ~len:(_Z lI), (l, oJ))
| _ -> `Stop (Extract {seq; off; len}) )
~finish:(fun (e1N, _) -> concat e1N)
| _ -> Extract {seq; off; len} )
(* α[o,l) *)
| _ -> Extract {seq; off; len} )
|> check invariant
and concat xs =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp (Concat xs))
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
(* (α^(β^γ)) ==> (α^β^γ) *)
let flatten xs =
if Array.exists ~f:(function Concat _ -> true | _ -> false) xs then
Array.flat_map ~f:(function Concat s -> s | e -> [|e|]) xs
else xs
in
let simp_adjacent e f =
match (e, f) with
(* ⟨n,a⟩[o,k)^⟨n,a⟩[o+k,l) ==> ⟨n,a⟩[o,k+l) when n ≥ o+k+l *)
| ( Extract {seq= Sized {siz= n} as na; off= o; len= k}
, Extract {seq= na'; off= o_k; len= l} )
when equal na na' && equal o_k (add o k) && partial_ge n (add o_k l)
->
Some (extract ~seq:na ~off:o ~len:(add k l))
(* ⟨m,0⟩^⟨n,0⟩ ==> ⟨m+n,0⟩ *)
| Sized {siz= m; seq= a}, Sized {siz= n; seq= a'}
when a == zero && a' == zero ->
Some (sized ~seq:a ~siz:(add m n))
(* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *)
| Sized {siz= m; seq= Splat _ as a}, Sized {siz= n; seq= a'}
when equal a a' ->
Some (sized ~seq:a ~siz:(add m n))
| _ -> None
in
let xs = flatten xs in
let xs = Array.reduce_adjacent ~f:simp_adjacent xs in
(if Array.length xs = 1 then xs.(0) else Concat xs) |> check invariant
(* uninterpreted *) (* uninterpreted *)
let apply sym args = _Apply sym args let apply f es =
( match Funsym.eval ~equal ~get_z ~ret_z:_Z ~get_q ~ret_q:_Q f es with
| Some c -> c
| None -> Apply (f, es) )
|> check invariant
(** Traverse *) (** Traverse *)
@ -505,25 +443,25 @@ let rec map_vars e ~f =
| Var _ as v -> (f (Var.of_ v) : Var.t :> t) | Var _ as v -> (f (Var.of_ v) : Var.t :> t)
| Z _ | Q _ -> e | Z _ | Q _ -> e
| Arith a -> map1 (Arith.map ~f:(map_vars ~f)) e _Arith a | Arith a -> map1 (Arith.map ~f:(map_vars ~f)) e _Arith a
| Splat x -> map1 (map_vars ~f) e _Splat x | Splat x -> map1 (map_vars ~f) e splat x
| Sized {seq; siz} -> | Sized {seq; siz} ->
map2 (map_vars ~f) e (fun seq siz -> _Sized ~seq ~siz) seq siz map2 (map_vars ~f) e (fun seq siz -> sized ~seq ~siz) seq siz
| Extract {seq; off; len} -> | Extract {seq; off; len} ->
map3 (map_vars ~f) e map3 (map_vars ~f) e
(fun seq off len -> _Extract ~seq ~off ~len) (fun seq off len -> extract ~seq ~off ~len)
seq off len seq off len
| Concat xs -> mapN (map_vars ~f) e _Concat xs | Concat xs -> mapN (map_vars ~f) e concat xs
| Apply (g, xs) -> mapN (map_vars ~f) e (_Apply g) xs | Apply (g, xs) -> mapN (map_vars ~f) e (apply g) xs
let map e ~f = let map e ~f =
match e with match e with
| Var _ | Z _ | Q _ -> e | Var _ | Z _ | Q _ -> e
| Arith a -> map1 (Arith.map ~f) e _Arith a | Arith a -> map1 (Arith.map ~f) e _Arith a
| Splat x -> map1 f e _Splat x | Splat x -> map1 f e splat x
| Sized {seq; siz} -> map2 f e (fun seq siz -> _Sized ~seq ~siz) seq siz | Sized {seq; siz} -> map2 f e (fun seq siz -> sized ~seq ~siz) seq siz
| Extract {seq; off; len} -> | Extract {seq; off; len} ->
map3 f e (fun seq off len -> _Extract ~seq ~off ~len) seq off len map3 f e (fun seq off len -> extract ~seq ~off ~len) seq off len
| Concat xs -> mapN f e _Concat xs | Concat xs -> mapN f e concat xs
| Apply (g, xs) -> mapN f e (_Apply g) xs | Apply (g, xs) -> mapN f e (apply g) xs
let fold_map e = fold_map_from_map map e let fold_map e = fold_map_from_map map e

@ -9,9 +9,11 @@
type arith type arith
(** Terms, built from variables and applications of function symbols from
various theories. Denote functions from structures to values. *)
type t = private type t = private
(* variables *) (* variables *)
| Var of {id: int; name: string} | Var of {id: int; name: string [@ignore]}
(* arithmetic *) (* arithmetic *)
| Z of Z.t | Z of Z.t
| Q of Q.t | Q of Q.t
@ -25,24 +27,15 @@ type t = private
| Apply of Funsym.t * t array | Apply of Funsym.t * t array
[@@deriving compare, equal, sexp] [@@deriving compare, equal, sexp]
(** Arithmetic terms *)
module Arith : Arithmetic.S with type trm := t with type t = arith module Arith : Arithmetic.S with type trm := t with type t = arith
module Var : sig
type trm := t
include Var_intf.S with type t = private trm
val of_ : trm -> t
val of_trm : trm -> t option
end
module Set : sig module Set : sig
include Set.S with type elt := t include Set.S with type elt := t
val t_of_sexp : Sexp.t -> t val t_of_sexp : Sexp.t -> t
val pp : t pp val pp : t pp
val pp_diff : (t * t) pp val pp_diff : (t * t) pp
val of_vars : Var.Set.t -> t
end end
module Map : sig module Map : sig
@ -51,6 +44,16 @@ module Map : sig
val t_of_sexp : (Sexp.t -> 'a) -> Sexp.t -> 'a t val t_of_sexp : (Sexp.t -> 'a) -> Sexp.t -> 'a t
end end
(** Variable terms, represented as a subtype of general terms *)
module Var : sig
type trm := t
include
Var_intf.S with type t = private trm with type Set.t = private Set.t
val of_trm : trm -> t option
end
val ppx : Var.strength -> t pp val ppx : Var.strength -> t pp
val pp : t pp val pp : t pp
val pp_diff : (t * t) pp val pp_diff : (t * t) pp

Loading…
Cancel
Save