diff --git a/sledge/src/fol/arithmetic.ml b/sledge/src/fol/arithmetic.ml index 1d39b872f..dd46d1bc5 100644 --- a/sledge/src/fol/arithmetic.ml +++ b/sledge/src/fol/arithmetic.ml @@ -19,17 +19,27 @@ module Q = struct include Comparer.Make (Q) end -module Representation - (Var : Var_intf.S) - (Trm : INDETERMINATE with type var := Var.t) = +type ('trm, 'compare_trm) mono = ('trm, int, 'compare_trm) Multiset.t +[@@deriving compare, equal, sexp] + +type 'compare_trm compare_mono = + ('compare_trm, Int.compare) Multiset.compare +[@@deriving compare, equal, sexp] + +type ('trm, 'compare_trm) t = + (('trm, 'compare_trm) mono, Q.t, 'compare_trm compare_mono) Multiset.t +[@@deriving compare, equal, sexp] + +module Make (Trm0 : sig + type t [@@deriving equal, sexp] + + include Comparer.S with type t := t +end) = struct - module Prod = struct - include Multiset.Make (Trm) (Int) - include Provide_of_sexp (Trm) - end + module Prod = Multiset.Make (Trm0) (Int) module Mono = struct - type t = Prod.t [@@deriving compare, equal, sexp] + type t = Prod.t [@@deriving compare, equal, sexp_of] let num_den m = Prod.partition m ~f:(fun _ i -> i >= 0) @@ -79,11 +89,7 @@ struct Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm)) end - module Sum = struct - include Multiset.Make (Prod) (Q) - include Provide_of_sexp (Prod) - end - + module Sum = Multiset.Make (Prod) (Q) module Poly = Sum include Poly @@ -104,6 +110,11 @@ struct (Sum.pp "@ + " pp_coeff_mono) poly + let trms poly = + Iter.from_iter (fun f -> + Sum.iter poly ~f:(fun mono _ -> + Prod.iter mono ~f:(fun trm _ -> f trm) ) ) + (* core invariant *) let mono_invariant mono = @@ -119,10 +130,14 @@ struct assert (not (Q.equal Q.zero coeff)) ; mono_invariant mono ) + (* embed a term into a polynomial *) + let trm trm = Sum.of_ (Mono.of_ trm 1) Q.one + (* constants *) let const q = Sum.of_ Mono.one q |> check invariant let zero = const Q.zero |> check (fun p -> assert (Sum.is_empty p)) + let one = const Q.one (* core constructors *) @@ -149,7 +164,7 @@ struct (* projections and embeddings *) - type kind = Trm of Trm.t | Const of Q.t | Interpreted | Uninterpreted + type kind = Trm of Trm0.t | Const of Q.t | Interpreted | Uninterpreted let classify poly = match Sum.classify poly with @@ -189,7 +204,12 @@ struct | None -> None end - module Make (Embed : EMBEDDING with type trm := Trm.t and type t := t) = + include S0 + + module Embed + (Var : Var_intf.S) + (Trm : TRM with type t = Trm0.t with type var := Var.t) + (Embed : EMBEDDING with type trm := Trm0.t with type t := t) = struct module Mono = struct include Mono @@ -202,16 +222,24 @@ struct include Poly include S0 + (** hide S0.trm and S0.trms that ignore the embedding, shadowed below *) + let[@warning "-32"] trm, trms = ((), ()) + let pp = ppx Trm.pp + (** Embed a polynomial into a term *) + let trm_of_poly = Embed.to_trm + (** Embed a monomial into a term, flattening if possible *) let trm_of_mono mono = match Mono.get_trm mono with | Some trm -> trm - | None -> Embed.to_trm (Sum.of_ mono Q.one) + | None -> trm_of_poly (Sum.of_ mono Q.one) (* traverse *) + let vars poly = Iter.flat_map ~f:Trm.vars (S0.trms poly) + let monos poly = Iter.from_iter (fun f -> Sum.iter poly ~f:(fun mono _ -> @@ -222,10 +250,6 @@ struct | Some mono -> Mono.trms mono | None -> Iter.map ~f:trm_of_mono (monos poly) - let vars p = Iter.flat_map ~f:Trm.vars (trms p) - - (* invariant *) - let mono_invariant mono = mono_invariant mono ; let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in @@ -309,7 +333,7 @@ struct let trm trm = ( match Embed.get_arith trm with | Some poly -> poly - | None -> Sum.of_ (Mono.of_ trm 1) Q.one ) + | None -> S0.trm trm ) |> check (fun poly -> assert (equal poly (CM.to_poly (CM.of_trm trm))) ) @@ -322,7 +346,7 @@ struct let pow base power = CM.to_poly (CM.of_trm base ~power) - (* map over [trms] *) + (** map over [trms] *) let map poly ~f = [%trace] ~call:(fun {pf} -> pf "@ %a" pp poly) @@ -398,7 +422,7 @@ struct | Some _ as soln -> soln | None -> solve_poly (Sum.add mono coeff rejected) poly - (* solve [0 = e] *) + (** solve [0 = e] *) let solve_zero_eq ?for_ e = [%trace] ~call:(fun {pf} -> diff --git a/sledge/src/fol/arithmetic.mli b/sledge/src/fol/arithmetic.mli index edb79493b..6ed43a534 100644 --- a/sledge/src/fol/arithmetic.mli +++ b/sledge/src/fol/arithmetic.mli @@ -9,7 +9,25 @@ include module type of Arithmetic_intf -module Representation - (Var : Var_intf.S) - (Indeterminate : INDETERMINATE with type var := Var.t) : - REPRESENTATION with type var := Var.t with type trm := Indeterminate.t +(** Arithmetic terms, e.g. polynomials, polymorphic in the type of + indeterminates. *) +type ('trm, 'cmp) t [@@deriving compare, equal, sexp] + +(** Functor that, given a totally ordered type of indeterminate terms, + builds an implementation of the embedding-independent arithmetic + operations, and a functor that, given an embedding of arithmetic terms + into indeterminate terms, builds an implementation of the arithmetic + operations. *) +module Make (Ord : sig + type t [@@deriving equal, sexp] + + include Comparer.S with type t := t +end) : sig + include S0 with type t = (Ord.t, Ord.compare) t with type trm := Ord.t + + module Embed + (Var : Var_intf.S) + (Trm : TRM with type t = Ord.t with type var := Var.t) + (_ : EMBEDDING with type trm := Trm.t and type t := t) : + S with type trm := Trm.t with type t = t +end diff --git a/sledge/src/fol/arithmetic_intf.ml b/sledge/src/fol/arithmetic_intf.ml index b7375b702..ec813d6bc 100644 --- a/sledge/src/fol/arithmetic_intf.ml +++ b/sledge/src/fol/arithmetic_intf.ml @@ -14,31 +14,29 @@ module type EMBEDDING = sig val to_trm : t -> trm (** Embedding from [t] to [trm]: [to_trm a] is arithmetic term [a] - embedded in an indeterminate term. *) + embedded into an indeterminate term. *) val get_arith : trm -> t option - (** Partial projection from [trm] to [t]: [get_arith x] is [Some a] if - [x = to_trm a]. This is used to flatten indeterminates that are - actually arithmetic for the client, thereby enabling arithmetic - operations to be interpreted more often. *) + (** Partial projection from [trm] to [t]: [get_arith x] is [Some a] iff + [x = to_trm a]. *) end (** Indeterminate terms, treated as atomic / variables except when they can - be flattened using [EMBEDDING.get_arith]. *) -module type INDETERMINATE = sig - type t [@@deriving compare, equal, sexp] + be flattened using {!EMBEDDING.get_arith}. *) +module type TRM = sig + include Comparer.S - include Comparer.S with type t := t + val pp : t pp type var - val pp : t pp val vars : t -> var iter end -module type S = sig +(** Arithmetic terms, e.g. polynomials [t] over indeterminate terms [trm] *) +module type S0 = sig type trm - type t [@@deriving compare, equal, sexp] + type t [@@deriving compare, equal] val ppx : trm pp -> t pp @@ -48,8 +46,10 @@ module type S = sig (** [trm x] represents the indeterminate term [x] *) val const : Q.t -> t - (** [const q] represents the constant [q] *) + (** [const q] represents the rational constant [q] *) + val zero : t + val one : t val neg : t -> t val add : t -> t -> t val sub : t -> t -> t @@ -84,6 +84,22 @@ module type S = sig val is_uninterpreted : t -> bool (** [is_uninterpreted a] iff [classify a = Uninterpreted] *) + (** Traverse *) + + val trms : t -> trm iter + (** [trms a] enumerates the indeterminate terms appearing in [a]. + Considering an arithmetic term as a polynomial, + [trms (c × (Σᵢ₌₁ⁿ cᵢ × Πⱼ₌₁ᵐᵢ + Xᵢⱼ^pᵢⱼ))] is the sequence of terms [Xᵢⱼ] for each [i] and + [j]. *) +end + +(** Arithmetic terms, where an embedding {!EMBEDDING.get_arith} into + indeterminate terms is used to implicitly flatten arithmetic terms that + are embedded into general terms to the underlying arithmetic term. *) +module type S = sig + include S0 + (** Construct nonlinear arithmetic terms using the embedding, enabling interpreting associativity, commutatitivity, and unit laws, but not the full nonlinear arithmetic theory. *) @@ -95,8 +111,8 @@ module type S = sig (** Traverse *) val trms : t -> trm iter - (** [trms a] is the maximal foreign or noninterpreted proper subterms of - [a]. Considering an arithmetic term as a polynomial, + (** [trms a] enumerates the maximal foreign or noninterpreted proper + subterms of [a]. Considering an arithmetic term as a polynomial, [trms (c × (Σᵢ₌₁ⁿ cᵢ × Πⱼ₌₁ᵐᵢ Xᵢⱼ^pᵢⱼ))] is the sequence of monomials [Πⱼ₌₁ᵐᵢ Xᵢⱼ^pᵢⱼ] for each [i]. If the arithmetic @@ -115,16 +131,3 @@ module type S = sig expressed as [e = f] for some monomial subterm [e] of [d]. If [for_] is passed, then the subterm [e] must be [for_]. *) end - -(** A type [t] representing arithmetic terms over indeterminates [trm] - together with a functor [Make] that takes an [EMBEDDING] of arithmetic - terms [t] into indeterminates [trm] and produces an implementation of - the primary interface [S]. *) -module type REPRESENTATION = sig - type t [@@deriving compare, equal, sexp] - type var - type trm - - module Make (_ : EMBEDDING with type trm := trm and type t := t) : - S with type trm := trm with type t := t -end diff --git a/sledge/src/fol/context.ml b/sledge/src/fol/context.ml index 5cbd0bdb6..4e8983153 100644 --- a/sledge/src/fol/context.ml +++ b/sledge/src/fol/context.ml @@ -123,9 +123,7 @@ end = struct [not (is_valid_eq xs e f)] implies [not (is_valid_eq ys e f)] for [ys ⊆ xs]. *) let is_valid_eq xs e f = - let is_var_in xs e = - Option.exists ~f:(fun x -> Var.Set.mem x xs) (Var.of_trm e) - in + let is_var_in xs e = Trm.Set.mem e (xs : Var.Set.t :> Trm.Set.t) in let noninterp_with_solvable_var_in xs e = is_var_in xs e || Theory.is_noninterpreted e @@ -922,10 +920,11 @@ let trim ks x = Cls.add rep (Option.value cls0 ~default:Cls.empty) ) ) in (* enumerate expanded classes and update solution subst *) - let kills = Trm.Set.of_vars ks in Trm.Map.fold clss x ~f:(fun ~key:a' ~data:ecls x -> (* remove mappings for non-rep class elements to kill *) - let keep, drop = Trm.Set.diff_inter (Cls.to_set ecls) kills in + let keep, drop = + Trm.Set.diff_inter (Cls.to_set ecls) (ks : Var.Set.t :> Trm.Set.t) + in if Trm.Set.is_empty drop then x else let rep = Trm.Set.fold ~f:Subst.remove drop x.rep in diff --git a/sledge/src/fol/theory.ml b/sledge/src/fol/theory.ml index 04098ff37..454a837ba 100644 --- a/sledge/src/fol/theory.ml +++ b/sledge/src/fol/theory.ml @@ -196,7 +196,7 @@ let solve d e s = | Some ((Sized {siz= n; seq= Splat _} as b), Concat a0V) -> solve_concat a0V b n s | Some ((Var _ as v), (Concat a0V as c)) -> - if not (Var.Set.mem (Var.of_ v) (Trm.fv c)) then + if not (Trm.Set.mem v (Trm.fv c :> Trm.Set.t)) then (* v = α₀^…^αᵥ ==> v ↦ α₀^…^αᵥ when v ∉ fv(α₀^…^αᵥ) *) add_solved ~var:v ~rep:c s else @@ -212,7 +212,7 @@ let solve d e s = * Extract *) | Some ((Var _ as v), (Extract {len= l} as e)) -> - if not (Var.Set.mem (Var.of_ v) (Trm.fv e)) then + if not (Trm.Set.mem v (Trm.fv e :> Trm.Set.t)) then (* v = α[o,l) ==> v ↦ α[o,l) when v ∉ fv(α[o,l)) *) add_solved ~var:v ~rep:e s else diff --git a/sledge/src/fol/trm.ml b/sledge/src/fol/trm.ml index ea2ee270a..79ba468c6 100644 --- a/sledge/src/fol/trm.ml +++ b/sledge/src/fol/trm.ml @@ -7,45 +7,20 @@ (** Terms *) -(** Representation of Arithmetic terms *) -module rec Arith0 : - (Arithmetic.REPRESENTATION - with type var := Trm.Var1.t - with type trm := Trm.t) = - Arithmetic.Representation - (Trm.Var1) - (struct - include Trm - include Comparer.Make (Trm) - end) +(* Define term type using polymorphic arithmetic type, with derived compare, + equal, and sexp_of functions *) +module Trm1 = struct + type compare [@@deriving compare, equal, sexp] -(** Arithmetic terms *) -and Arith : (Arithmetic.S with type trm := Trm.t with type t = Arith0.t) = -struct - include Arith0 - - include Make (struct - let to_trm = Trm._Arith - - let get_arith (e : Trm.t) = - match e with - | Z z -> Some (Arith.const (Q.of_z z)) - | Q q -> Some (Arith.const q) - | Arith a -> Some a - | _ -> None - end) -end + type arith = (t, compare) Arithmetic.t -(** Terms, built from variables and applications of function symbols from - various theories. Denote functions from structures to values. *) -and Trm : sig - type t = private + and t = (* variables *) - | Var of {id: int; name: string} + | Var of {id: int; name: string [@ignore]} (* arithmetic *) | Z of Z.t | Q of Q.t - | Arith of Arith.t + | Arith of arith (* sequences (of flexible size) *) | Splat of t | Sized of {seq: t; siz: t} @@ -54,50 +29,20 @@ and Trm : sig (* uninterpreted *) | Apply of Funsym.t * t array [@@deriving compare, equal, sexp] +end - (** Variable terms, represented as a subtype of general terms *) - module Var1 : sig - type trm := t - - include Var_intf.S with type t = private trm +(* Add comparer, needed to instantiate arithmetic and containers *) +module Trm2 = struct + include Comparer.Counterfeit (Trm1) + include Trm1 +end - val of_ : trm -> t - val of_trm : trm -> t option - end +(* Specialize arithmetic type and define operations using comparer *) +module Arith0 = Arithmetic.Make (Trm2) - val ppx : Var1.strength -> t pp - val pp : t pp - - include Invariant.S with type t := t - - val _Var : int -> string -> t - val _Z : Z.t -> t - val _Q : Q.t -> t - val _Arith : Arith.t -> t - val _Splat : t -> t - val _Sized : seq:t -> siz:t -> t - val _Extract : seq:t -> off:t -> len:t -> t - val _Concat : t array -> t - val _Apply : Funsym.t -> t array -> t - val add : t -> t -> t - val sub : t -> t -> t - val seq_size_exn : t -> t - val seq_size : t -> t option - val get_z : t -> Z.t option - val get_q : t -> Q.t option - val vars : t -> Var1.t iter -end = struct - type t = - | Var of {id: int; name: string [@ignore]} - | Z of Z.t - | Q of Q.t - | Arith of Arith.t - | Splat of t - | Sized of {seq: t; siz: t} - | Extract of {seq: t; off: t; len: t} - | Concat of t array - | Apply of Funsym.t * t array - [@@deriving compare, equal, sexp] +(* Add ppx, defined recursively with Arith0.ppx *) +module Trm3 = struct + include Trm2 (* nul-terminated string value represented by a concatenation *) let string_of_concat xs = @@ -136,7 +81,7 @@ end = struct | Some `Anonymous -> Trace.pp_styled `Cyan "_" fs ) | Z z -> Trace.pp_styled `Magenta "%a" fs Z.pp z | Q q -> Trace.pp_styled `Magenta "%a" fs Q.pp q - | Arith a -> Arith.ppx (ppx strength) fs a + | Arith a -> Arith0.ppx (ppx strength) fs a | Splat x -> pf "%a^" pp x | Sized {seq; siz} -> pf "@<1>⟨%a,%a@<1>⟩" pp siz pp seq | Extract {seq; off; len} -> pf "%a[%a,%a)" pp seq pp off pp len @@ -157,255 +102,97 @@ end = struct pp fs trm let pp = ppx (fun _ -> None) + let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y +end - (* Define variables as a subtype of terms *) - module Var1 = struct - module V = struct - module T = struct - type nonrec t = t [@@deriving compare, equal, sexp] - type strength = t -> [`Universal | `Existential | `Anonymous] option - - let pp = pp - let ppx = ppx - end - - include T - - let invariant x = - let@ () = Invariant.invariant [%here] x [%sexp_of: t] in - match x with - | Var _ -> () - | _ -> fail "non-var: %a" Sexp.pp_hum (sexp_of_t x) () - - let make ~id ~name = Var {id; name} |> check invariant - let id = function Var v -> v.id | x -> violates invariant x - let name = function Var v -> v.name | x -> violates invariant x - - module Set = struct - module S = NS.Set.Make (T) - include S - include Provide_of_sexp (T) - include Provide_pp (T) - - let ppx strength vs = S.pp_full (ppx strength) vs - - let pp_xs fs xs = - if not (is_empty xs) then - Format.fprintf fs "@<2>∃ @[%a@] .@;<1 2>" pp xs - end - - module Map = struct - include NS.Map.Make (T) - include Provide_of_sexp (T) - end - - let fresh name ~wrt = - let max = - match Set.max_elt wrt with None -> 0 | Some m -> max 0 (id m) - in - let x' = make ~id:(max + 1) ~name in - (x', Set.add x' wrt) - - let freshen v ~wrt = fresh (name v) ~wrt - - let program ?(name = "") ~id = - assert (id > 0) ; - make ~id:(-id) ~name - - let identified ~name ~id = make ~id ~name - let of_ v = v |> check invariant - let of_trm = function Var _ as v -> Some v | _ -> None - end +(* Define containers over terms *) +module Set = struct + include Set.Make (Trm3) + include Provide_of_sexp (Trm3) + include Provide_pp (Trm3) +end - include V - module Subst = Subst.Make (V) - end +module Map = struct + include Map.Make (Trm3) + include Provide_of_sexp (Trm3) +end - let invariant e = - let@ () = Invariant.invariant [%here] e [%sexp_of: t] in - match e with - | Q q -> assert (not (Z.equal Z.one (Q.den q))) - | Arith a -> ( - match Arith.classify a with - | Trm _ | Const _ -> assert false - | _ -> () ) - | _ -> () +(* Define variables as a subtype of terms *) +module Var = struct + open Trm3 - (** Destruct *) + module V = struct + type nonrec t = t [@@deriving compare, equal, sexp] + type strength = t -> [`Universal | `Existential | `Anonymous] option - let get_z = function Z z -> Some z | _ -> None - let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None + let pp = pp + let ppx = ppx - (** Construct *) + let invariant x = + let@ () = Invariant.invariant [%here] x [%sexp_of: t] in + match x with + | Var _ -> () + | _ -> fail "non-var: %a" Sexp.pp_hum (sexp_of_t x) () - let _Var id name = Var {id; name} |> check invariant + let make ~id ~name = Var {id; name} |> check invariant + let id = function Var v -> v.id | x -> violates invariant x + let name = function Var v -> v.name | x -> violates invariant x - (* statically allocated since they are tested with == *) - let zero = Z Z.zero |> check invariant - let one = Z Z.one |> check invariant + module Set = struct + include Set - let _Z z = - (if Z.equal Z.zero z then zero else if Z.equal Z.one z then one else Z z) - |> check invariant + let ppx strength vs = pp_full (ppx strength) vs - let _Q q = - (if Z.equal Z.one (Q.den q) then _Z (Q.num q) else Q q) - |> check invariant + let pp_xs fs xs = + if not (is_empty xs) then + Format.fprintf fs "@<2>∃ @[%a@] .@;<1 2>" pp xs + end - let _Arith a = - ( match Arith.classify a with - | Trm e -> e - | Const q -> _Q q - | _ -> Arith a ) - |> check invariant + module Map = Map - let add x y = _Arith Arith.(add (trm x) (trm y)) - let sub x y = _Arith Arith.(sub (trm x) (trm y)) + let fresh name ~wrt = + let max = + match Set.max_elt wrt with None -> 0 | Some m -> max 0 (id m) + in + let x' = make ~id:(max + 1) ~name in + (x', Set.add x' wrt) - let _Splat x = - (* 0^ ==> 0 *) - (if x == zero then x else Splat x) |> check invariant + let freshen v ~wrt = fresh (name v) ~wrt - let seq_size_exn = - let invalid = Invalid_argument "seq_size_exn" in - let rec seq_size_exn = function - | Sized {siz= n} | Extract {len= n} -> n - | Concat a0U -> - Array.fold ~f:(fun aJ a0I -> add a0I (seq_size_exn aJ)) a0U zero - | _ -> raise invalid - in - seq_size_exn + let program ?(name = "") ~id = + assert (id > 0) ; + make ~id:(-id) ~name - let seq_size e = - try Some (seq_size_exn e) with Invalid_argument _ -> None + let identified ~name ~id = make ~id ~name + let of_ v = v |> check invariant + let of_trm = function Var _ as v -> Some v | _ -> None + end - let _Sized ~seq ~siz = - ( match seq_size seq with - (* ⟨n,α⟩ ==> α when n ≡ |α| *) - | Some n when equal siz n -> seq - | _ -> Sized {seq; siz} ) - |> check invariant + include V + module Subst = Subst.Make (V) +end - let partial_compare x y = - match sub x y with - | Z z -> Some (Int.sign (Z.sign z)) - | Q q -> Some (Int.sign (Q.sign q)) - | _ -> None +(* Add definitions needed for arithmetic embedding into terms *) +module Trm = struct + include Trm3 - let partial_ge x y = - match partial_compare x y with Some (Pos | Zero) -> true | _ -> false - - let empty_seq = Concat [||] - - let rec _Extract ~seq ~off ~len = - [%trace] - ~call:(fun {pf} -> pf "@ %a" pp (Extract {seq; off; len})) - ~retn:(fun {pf} -> pf "%a" pp) - @@ fun () -> - (* _[_,0) ==> ⟨⟩ *) - ( if equal len zero then empty_seq - else - let o_l = add off len in - match seq with - (* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *) - | Extract {seq= a; off= m; len= k} when partial_ge k o_l -> - _Extract ~seq:a ~off:(add m off) ~len - (* ⟨n,0⟩[o,l) ==> ⟨l,0⟩ when n ≥ o+l *) - | Sized {siz= n; seq} when seq == zero && partial_ge n o_l -> - _Sized ~seq ~siz:len - (* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *) - | Sized {siz= n; seq= Splat _ as e} when partial_ge n o_l -> - _Sized ~seq:e ~siz:len - (* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *) - | Sized {siz= n} when equal off zero && equal n len -> seq - (* For (α₀^α₁)[o,l) there are 3 cases: - * - * ⟨...⟩^⟨...⟩ - * [,) - * o < o+l ≤ |α₀| : (α₀^α₁)[o,l) ==> α₀[o,l) ^ α₁[0,0) - * - * ⟨...⟩^⟨...⟩ - * [ , ) - * o ≤ |α₀| < o+l : (α₀^α₁)[o,l) ==> α₀[o,|α₀|-o) ^ α₁[0,l-(|α₀|-o)) - * - * ⟨...⟩^⟨...⟩ - * [,) - * |α₀| ≤ o : (α₀^α₁)[o,l) ==> α₀[o,0) ^ α₁[o-|α₀|,l) - * - * So in general: - * - * (α₀^α₁)[o,l) ==> α₀[o,l₀) ^ α₁[o₁,l-l₀) - * where l₀ = max 0 (min l |α₀|-o) - * o₁ = max 0 o-|α₀| - *) - | Concat na1N -> ( - match len with - | Z l -> - Array.fold_map_until na1N (l, off) - ~f:(fun naI (l, oI) -> - if Z.equal Z.zero l then - `Continue (_Extract ~seq:naI ~off:oI ~len:zero, (l, oI)) - else - let nI = seq_size_exn naI in - let oI_nI = sub oI nI in - match oI_nI with - | Z z -> - let oJ = if Z.sign z <= 0 then zero else oI_nI in - let lI = Z.(max zero (min l (neg z))) in - let l = Z.(l - lI) in - `Continue - (_Extract ~seq:naI ~off:oI ~len:(_Z lI), (l, oJ)) - | _ -> `Stop (Extract {seq; off; len}) ) - ~finish:(fun (e1N, _) -> _Concat e1N) - | _ -> Extract {seq; off; len} ) - (* α[o,l) *) - | _ -> Extract {seq; off; len} ) - |> check invariant + (** Invariant *) - and _Concat xs = - [%trace] - ~call:(fun {pf} -> pf "@ %a" pp (Concat xs)) - ~retn:(fun {pf} -> pf "%a" pp) - @@ fun () -> - (* (α^(β^γ)^δ) ==> (α^β^γ^δ) *) - let flatten xs = - if Array.exists ~f:(function Concat _ -> true | _ -> false) xs then - Array.flat_map ~f:(function Concat s -> s | e -> [|e|]) xs - else xs - in - let simp_adjacent e f = - match (e, f) with - (* ⟨n,a⟩[o,k)^⟨n,a⟩[o+k,l) ==> ⟨n,a⟩[o,k+l) when n ≥ o+k+l *) - | ( Extract {seq= Sized {siz= n} as na; off= o; len= k} - , Extract {seq= na'; off= o_k; len= l} ) - when equal na na' && equal o_k (add o k) && partial_ge n (add o_k l) - -> - Some (_Extract ~seq:na ~off:o ~len:(add k l)) - (* ⟨m,0⟩^⟨n,0⟩ ==> ⟨m+n,0⟩ *) - | Sized {siz= m; seq= a}, Sized {siz= n; seq= a'} - when a == zero && a' == zero -> - Some (_Sized ~seq:a ~siz:(add m n)) - (* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *) - | Sized {siz= m; seq= Splat _ as a}, Sized {siz= n; seq= a'} - when equal a a' -> - Some (_Sized ~seq:a ~siz:(add m n)) - | _ -> None - in - let xs = flatten xs in - let xs = Array.reduce_adjacent ~f:simp_adjacent xs in - (if Array.length xs = 1 then xs.(0) else Concat xs) |> check invariant - - let _Apply f es = - ( match Funsym.eval ~equal ~get_z ~ret_z:_Z ~get_q ~ret_q:_Q f es with - | Some c -> c - | None -> Apply (f, es) ) - |> check invariant + let invariant e = + let@ () = Invariant.invariant [%here] e [%sexp_of: t] in + match e with + | Q q -> assert (not (Z.equal Z.one (Q.den q))) + | Arith a -> ( + match Arith0.classify a with + | Trm _ | Const _ -> assert false + | _ -> () ) + | _ -> () (** Traverse *) let rec iter_vars e ~f = match e with - | Var _ as v -> f (Var1.of_ v) + | Var _ as v -> f (Var.of_ v) | Z _ | Q _ -> () | Splat x -> iter_vars ~f x | Sized {seq= x; siz= y} -> @@ -416,33 +203,55 @@ end = struct iter_vars ~f y ; iter_vars ~f z | Concat xs | Apply (_, xs) -> Array.iter ~f:(iter_vars ~f) xs - | Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith.trms a) + | Arith a -> Iter.iter ~f:(iter_vars ~f) (Arith0.trms a) let vars e = Iter.from_labelled_iter (iter_vars e) + + (** Construct *) + + (* statically allocated since they are tested with == *) + let zero = Z Z.zero |> check invariant + let one = Z Z.one |> check invariant + + let _Z z = + (if Z.equal Z.zero z then zero else if Z.equal Z.one z then one else Z z) + |> check invariant + + let _Q q = + (if Z.equal Z.one (Q.den q) then _Z (Q.num q) else Q q) + |> check invariant + + let _Arith a = + ( match Arith0.classify a with + | Trm e -> e + | Const q -> _Q q + | _ -> Arith a ) + |> check invariant end include Trm -module Var = Var1 -module Set = struct - include Set.Make (Trm) - include Provide_of_sexp (Trm) - include Provide_pp (Trm) - - let of_vars : Var.Set.t -> t = - fun vs -> - of_iter - (Iter.map ~f:(fun v -> (v : Var.t :> Trm.t)) (Var.Set.to_iter vs)) -end +(* Instantiate arithmetic with embedding into terms, yielding full + Arithmetic interface *) +module Arith = + Arith0.Embed (Var) (Trm) + (struct + let to_trm = _Arith + + let get_arith e = + match e with + | Z z -> Some (Arith0.const (Q.of_z z)) + | Q q -> Some (Arith0.const q) + | Arith a -> Some a + | _ -> None + end) -module Map = struct - include Map.Make (Trm) - include Provide_of_sexp (Trm) -end +(* Full Trm definition, using full arithmetic interface *) -type arith = Arith.t +(** Destruct *) -let pp_diff fs (x, y) = Format.fprintf fs "-- %a ++ %a" pp x pp y +let get_z = function Z z -> Some z | _ -> None +let get_q = function Q q -> Some q | Z z -> Some (Q.of_z z) | _ -> None (** Construct *) @@ -452,13 +261,11 @@ let var v = (v : Var.t :> t) (* arithmetic *) -let zero = _Z Z.zero -let one = _Z Z.one let integer z = _Z z let rational q = _Q q let neg x = _Arith Arith.(neg (trm x)) -let add = Trm.add -let sub = Trm.sub +let add x y = _Arith Arith.(add (trm x) (trm y)) +let sub x y = _Arith Arith.(sub (trm x) (trm y)) let mulq q x = _Arith Arith.(mulc q (trm x)) let mul x y = _Arith (Arith.mul x y) let div x y = _Arith (Arith.div x y) @@ -467,14 +274,145 @@ let arith = _Arith (* sequences *) -let splat = _Splat -let sized = _Sized -let extract = _Extract -let concat elts = _Concat elts +let splat x = + (* 0^ ==> 0 *) + (if x == zero then x else Splat x) |> check invariant + +let seq_size_exn = + let invalid = Invalid_argument "seq_size_exn" in + let rec seq_size_exn = function + | Sized {siz= n} | Extract {len= n} -> n + | Concat a0U -> + Array.fold ~f:(fun aJ a0I -> add a0I (seq_size_exn aJ)) a0U zero + | _ -> raise invalid + in + seq_size_exn + +let seq_size e = try Some (seq_size_exn e) with Invalid_argument _ -> None + +let sized ~seq ~siz = + ( match seq_size seq with + (* ⟨n,α⟩ ==> α when n ≡ |α| *) + | Some n when equal siz n -> seq + | _ -> Sized {seq; siz} ) + |> check invariant + +let partial_compare x y = + match sub x y with + | Z z -> Some (Int.sign (Z.sign z)) + | Q q -> Some (Int.sign (Q.sign q)) + | _ -> None + +let partial_ge x y = + match partial_compare x y with Some (Pos | Zero) -> true | _ -> false + +let empty_seq = Concat [||] + +let rec extract ~seq ~off ~len = + [%trace] + ~call:(fun {pf} -> pf "@ %a" pp (Extract {seq; off; len})) + ~retn:(fun {pf} -> pf "%a" pp) + @@ fun () -> + (* _[_,0) ==> ⟨⟩ *) + ( if equal len zero then empty_seq + else + let o_l = add off len in + match seq with + (* α[m,k)[o,l) ==> α[m+o,l) when k ≥ o+l *) + | Extract {seq= a; off= m; len= k} when partial_ge k o_l -> + extract ~seq:a ~off:(add m off) ~len + (* ⟨n,0⟩[o,l) ==> ⟨l,0⟩ when n ≥ o+l *) + | Sized {siz= n; seq} when seq == zero && partial_ge n o_l -> + sized ~seq ~siz:len + (* ⟨n,E^⟩[o,l) ==> ⟨l,E^⟩ when n ≥ o+l *) + | Sized {siz= n; seq= Splat _ as e} when partial_ge n o_l -> + sized ~seq:e ~siz:len + (* ⟨n,a⟩[0,n) ==> ⟨n,a⟩ *) + | Sized {siz= n} when equal off zero && equal n len -> seq + (* For (α₀^α₁)[o,l) there are 3 cases: + * + * ⟨...⟩^⟨...⟩ + * [,) + * o < o+l ≤ |α₀| : (α₀^α₁)[o,l) ==> α₀[o,l) ^ α₁[0,0) + * + * ⟨...⟩^⟨...⟩ + * [ , ) + * o ≤ |α₀| < o+l : (α₀^α₁)[o,l) ==> α₀[o,|α₀|-o) ^ α₁[0,l-(|α₀|-o)) + * + * ⟨...⟩^⟨...⟩ + * [,) + * |α₀| ≤ o : (α₀^α₁)[o,l) ==> α₀[o,0) ^ α₁[o-|α₀|,l) + * + * So in general: + * + * (α₀^α₁)[o,l) ==> α₀[o,l₀) ^ α₁[o₁,l-l₀) + * where l₀ = max 0 (min l |α₀|-o) + * o₁ = max 0 o-|α₀| + *) + | Concat na1N -> ( + match len with + | Z l -> + Array.fold_map_until na1N (l, off) + ~f:(fun naI (l, oI) -> + if Z.equal Z.zero l then + `Continue (extract ~seq:naI ~off:oI ~len:zero, (l, oI)) + else + let nI = seq_size_exn naI in + let oI_nI = sub oI nI in + match oI_nI with + | Z z -> + let oJ = if Z.sign z <= 0 then zero else oI_nI in + let lI = Z.(max zero (min l (neg z))) in + let l = Z.(l - lI) in + `Continue + (extract ~seq:naI ~off:oI ~len:(_Z lI), (l, oJ)) + | _ -> `Stop (Extract {seq; off; len}) ) + ~finish:(fun (e1N, _) -> concat e1N) + | _ -> Extract {seq; off; len} ) + (* α[o,l) *) + | _ -> Extract {seq; off; len} ) + |> check invariant + +and concat xs = + [%trace] + ~call:(fun {pf} -> pf "@ %a" pp (Concat xs)) + ~retn:(fun {pf} -> pf "%a" pp) + @@ fun () -> + (* (α^(β^γ)^δ) ==> (α^β^γ^δ) *) + let flatten xs = + if Array.exists ~f:(function Concat _ -> true | _ -> false) xs then + Array.flat_map ~f:(function Concat s -> s | e -> [|e|]) xs + else xs + in + let simp_adjacent e f = + match (e, f) with + (* ⟨n,a⟩[o,k)^⟨n,a⟩[o+k,l) ==> ⟨n,a⟩[o,k+l) when n ≥ o+k+l *) + | ( Extract {seq= Sized {siz= n} as na; off= o; len= k} + , Extract {seq= na'; off= o_k; len= l} ) + when equal na na' && equal o_k (add o k) && partial_ge n (add o_k l) + -> + Some (extract ~seq:na ~off:o ~len:(add k l)) + (* ⟨m,0⟩^⟨n,0⟩ ==> ⟨m+n,0⟩ *) + | Sized {siz= m; seq= a}, Sized {siz= n; seq= a'} + when a == zero && a' == zero -> + Some (sized ~seq:a ~siz:(add m n)) + (* ⟨m,E^⟩^⟨n,E^⟩ ==> ⟨m+n,E^⟩ *) + | Sized {siz= m; seq= Splat _ as a}, Sized {siz= n; seq= a'} + when equal a a' -> + Some (sized ~seq:a ~siz:(add m n)) + | _ -> None + in + let xs = flatten xs in + let xs = Array.reduce_adjacent ~f:simp_adjacent xs in + (if Array.length xs = 1 then xs.(0) else Concat xs) |> check invariant (* uninterpreted *) -let apply sym args = _Apply sym args +let apply f es = + ( match Funsym.eval ~equal ~get_z ~ret_z:_Z ~get_q ~ret_q:_Q f es with + | Some c -> c + | None -> Apply (f, es) ) + |> check invariant (** Traverse *) @@ -505,25 +443,25 @@ let rec map_vars e ~f = | Var _ as v -> (f (Var.of_ v) : Var.t :> t) | Z _ | Q _ -> e | Arith a -> map1 (Arith.map ~f:(map_vars ~f)) e _Arith a - | Splat x -> map1 (map_vars ~f) e _Splat x + | Splat x -> map1 (map_vars ~f) e splat x | Sized {seq; siz} -> - map2 (map_vars ~f) e (fun seq siz -> _Sized ~seq ~siz) seq siz + map2 (map_vars ~f) e (fun seq siz -> sized ~seq ~siz) seq siz | Extract {seq; off; len} -> map3 (map_vars ~f) e - (fun seq off len -> _Extract ~seq ~off ~len) + (fun seq off len -> extract ~seq ~off ~len) seq off len - | Concat xs -> mapN (map_vars ~f) e _Concat xs - | Apply (g, xs) -> mapN (map_vars ~f) e (_Apply g) xs + | Concat xs -> mapN (map_vars ~f) e concat xs + | Apply (g, xs) -> mapN (map_vars ~f) e (apply g) xs let map e ~f = match e with | Var _ | Z _ | Q _ -> e | Arith a -> map1 (Arith.map ~f) e _Arith a - | Splat x -> map1 f e _Splat x - | Sized {seq; siz} -> map2 f e (fun seq siz -> _Sized ~seq ~siz) seq siz + | Splat x -> map1 f e splat x + | Sized {seq; siz} -> map2 f e (fun seq siz -> sized ~seq ~siz) seq siz | Extract {seq; off; len} -> - map3 f e (fun seq off len -> _Extract ~seq ~off ~len) seq off len - | Concat xs -> mapN f e _Concat xs - | Apply (g, xs) -> mapN f e (_Apply g) xs + map3 f e (fun seq off len -> extract ~seq ~off ~len) seq off len + | Concat xs -> mapN f e concat xs + | Apply (g, xs) -> mapN f e (apply g) xs let fold_map e = fold_map_from_map map e diff --git a/sledge/src/fol/trm.mli b/sledge/src/fol/trm.mli index a23a23901..4a16d9e83 100644 --- a/sledge/src/fol/trm.mli +++ b/sledge/src/fol/trm.mli @@ -9,9 +9,11 @@ type arith +(** Terms, built from variables and applications of function symbols from + various theories. Denote functions from structures to values. *) type t = private (* variables *) - | Var of {id: int; name: string} + | Var of {id: int; name: string [@ignore]} (* arithmetic *) | Z of Z.t | Q of Q.t @@ -25,24 +27,15 @@ type t = private | Apply of Funsym.t * t array [@@deriving compare, equal, sexp] +(** Arithmetic terms *) module Arith : Arithmetic.S with type trm := t with type t = arith -module Var : sig - type trm := t - - include Var_intf.S with type t = private trm - - val of_ : trm -> t - val of_trm : trm -> t option -end - module Set : sig include Set.S with type elt := t val t_of_sexp : Sexp.t -> t val pp : t pp val pp_diff : (t * t) pp - val of_vars : Var.Set.t -> t end module Map : sig @@ -51,6 +44,16 @@ module Map : sig val t_of_sexp : (Sexp.t -> 'a) -> Sexp.t -> 'a t end +(** Variable terms, represented as a subtype of general terms *) +module Var : sig + type trm := t + + include + Var_intf.S with type t = private trm with type Set.t = private Set.t + + val of_trm : trm -> t option +end + val ppx : Var.strength -> t pp val pp : t pp val pp_diff : (t * t) pp