[sledge] Reorder Arithmetic definitions

Summary:
No functional change, apart from a minor change in invariant checking
for core constructors. Only to reduce diffs of upcoming changes.

Reviewed By: jvillard

Differential Revision: D26250518

fbshipit-source-id: a731e4fd6
master
Josh Berdine 4 years ago committed by Facebook GitHub Bot
parent c7c06addfd
commit 24ca0666d3

@ -53,8 +53,6 @@ struct
Format.fprintf ppf "@[<2>%a@]" pp_num num
else Format.fprintf ppf "@[<2>(%a%a)@]" pp_num num pp_den den
let pp = ppx Trm.pp
(** [one] is the empty product Πᵢ₌₁⁰ xᵢ^pᵢ *)
let one = Prod.empty
@ -79,11 +77,6 @@ struct
let trms mono =
Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm))
(* query *)
let vars p = Iter.flat_map ~f:Trm.vars (trms p)
let fv p = Var.Set.of_iter (vars p)
end
module Sum = struct
@ -94,10 +87,7 @@ struct
module Poly = Sum
include Poly
module Make (Embed : EMBEDDING with type trm := Trm.t and type t := t) =
struct
include Poly
module S0 = struct
let ppx pp_trm ppf poly =
if Sum.is_empty poly then Trace.pp_styled `Magenta "0" ppf
else
@ -114,31 +104,13 @@ struct
(Sum.pp "@ + " pp_coeff_mono)
poly
let pp = ppx Trm.pp
(* core invariant *)
let mono_invariant mono =
let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in
Prod.iter mono ~f:(fun base power ->
Prod.iter mono ~f:(fun _ power ->
(* powers are non-zero *)
assert (not (Int.equal Int.zero power)) ;
match Embed.get_arith base with
| None -> ()
| Some poly -> (
match Sum.classify poly with
| `Many -> ()
| `Zero | `One _ ->
(* polynomial factors are not constant or singleton, which
should have been flattened into the parent monomial *)
assert false ) ) ;
match Mono.get_trm mono with
| None -> ()
| Some trm -> (
match Embed.get_arith trm with
| None -> ()
| Some _ ->
(* singleton monomials are not polynomials, which should have
been flattened into the parent polynomial *)
assert false )
assert (not (Int.equal Int.zero power)) )
let invariant poly =
let@ () = Invariant.invariant [%here] poly [%sexp_of: t] in
@ -164,6 +136,17 @@ struct
else Sum.map_counts ~f:(Q.mul coeff) poly )
|> check invariant
(* transform *)
let split_const poly =
match Sum.find_and_remove Mono.one poly with
| Some (c, p_c) -> (p_c, c)
| None -> (poly, Q.zero)
let partition_sign poly =
Sum.partition_map poly ~f:(fun _ coeff ->
if Q.sign coeff >= 0 then Left coeff else Right (Q.neg coeff) )
(* projections and embeddings *)
type kind = Trm of Trm.t | Const of Q.t | Interpreted | Uninterpreted
@ -199,6 +182,78 @@ struct
| Some (mono, coeff) when Q.equal Q.one coeff -> Some mono
| _ -> None
(** Project out the term embedded into a polynomial, if possible *)
let get_trm poly =
match get_mono poly with
| Some mono -> Mono.get_trm mono
| None -> None
end
module Make (Embed : EMBEDDING with type trm := Trm.t and type t := t) =
struct
module Mono = struct
include Mono
let pp = ppx Trm.pp
let vars p = Iter.flat_map ~f:Trm.vars (trms p)
let fv p = Var.Set.of_iter (vars p)
end
include Poly
include S0
let pp = ppx Trm.pp
(** Embed a monomial into a term, flattening if possible *)
let trm_of_mono mono =
match Mono.get_trm mono with
| Some trm -> trm
| None -> Embed.to_trm (Sum.of_ mono Q.one)
(* traverse *)
let monos poly =
Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ ->
if not (Mono.equal_one mono) then f mono ) )
let trms poly =
match get_mono poly with
| Some mono -> Mono.trms mono
| None -> Iter.map ~f:trm_of_mono (monos poly)
let vars p = Iter.flat_map ~f:Trm.vars (trms p)
(* invariant *)
let mono_invariant mono =
mono_invariant mono ;
let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in
Prod.iter mono ~f:(fun base _ ->
match Embed.get_arith base with
| None -> ()
| Some poly -> (
match Sum.classify poly with
| `Many -> ()
| `Zero | `One _ ->
(* polynomial factors are not constant or singleton, which
should have been flattened into the parent monomial *)
assert false ) ) ;
match Mono.get_trm mono with
| None -> ()
| Some trm -> (
match Embed.get_arith trm with
| None -> ()
| Some _ ->
(* singleton monomials are not polynomials, which should have
been flattened into the parent polynomial *)
assert false )
let invariant poly =
invariant poly ;
let@ () = Invariant.invariant [%here] poly [%sexp_of: t] in
Sum.iter poly ~f:(fun mono _ -> mono_invariant mono)
(** Terms of a polynomial: product of a coefficient and a monomial *)
module CM = struct
type t = Q.t * Prod.t
@ -249,12 +304,6 @@ struct
|> check invariant
end
(** Embed a monomial into a term, flattening if possible *)
let trm_of_mono mono =
match Mono.get_trm mono with
| Some trm -> trm
| None -> Embed.to_trm (Sum.of_ mono Q.one)
(** Embed a term into a polynomial, by projecting a polynomial out of
the term if possible *)
let trm trm =
@ -264,12 +313,6 @@ struct
|> check (fun poly ->
assert (equal poly (CM.to_poly (CM.of_trm trm))) )
(** Project out the term embedded into a polynomial, if possible *)
let get_trm poly =
match get_mono poly with
| Some mono -> Mono.get_trm mono
| None -> None
(* constructors over indeterminates *)
let mul e1 e2 = CM.to_poly (CM.mul (CM.of_trm e1) (CM.of_trm e2))
@ -279,29 +322,6 @@ struct
let pow base power = CM.to_poly (CM.of_trm base ~power)
(* transform *)
let split_const poly =
match Sum.find_and_remove Mono.one poly with
| Some (c, p_c) -> (p_c, c)
| None -> (poly, Q.zero)
let partition_sign poly =
Sum.partition_map poly ~f:(fun _ coeff ->
if Q.sign coeff >= 0 then Left coeff else Right (Q.neg coeff) )
(* traverse *)
let monos poly =
Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ ->
if not (Mono.equal_one mono) then f mono ) )
let trms poly =
match get_mono poly with
| Some mono -> Mono.trms mono
| None -> Iter.map ~f:trm_of_mono (monos poly)
(* map over [trms] *)
let map poly ~f =
[%trace]
@ -338,10 +358,6 @@ struct
Sum.union poly' delta )
|> check invariant
(* query *)
let vars p = Iter.flat_map ~f:Trm.vars (trms p)
(* solve *)
let exists_fv_in vs poly =

@ -7,26 +7,73 @@
(** Arithmetic terms *)
(** An embedding of arithmetic terms [t] into indeterminates [trm]. *)
module type EMBEDDING = sig
type t
type trm
val to_trm : t -> trm
(** Embedding from [t] to [trm]: [to_trm a] is arithmetic term [a]
embedded in an indeterminate term. *)
val get_arith : trm -> t option
(** Partial projection from [trm] to [t]: [get_arith x] is [Some a] if
[x = to_trm a]. This is used to flatten indeterminates that are
actually arithmetic for the client, thereby enabling arithmetic
operations to be interpreted more often. *)
end
(** Indeterminate terms, treated as atomic / variables except when they can
be flattened using [EMBEDDING.get_arith]. *)
module type INDETERMINATE = sig
type t [@@deriving compare, equal, sexp]
include Comparer.S with type t := t
type var
val pp : t pp
val vars : t -> var iter
end
module type S = sig
type trm
type t [@@deriving compare, equal, sexp]
val ppx : trm pp -> t pp
(** Construct and Destruct atomic terms *)
(** Construct *)
val trm : trm -> t
(** [trm x] represents the indeterminate term [x] *)
val const : Q.t -> t
(** [const q] represents the constant [q] *)
val get_const : t -> Q.t option
(** [get_const a] is [Some q] iff [equal a (const q)] *)
val neg : t -> t
val add : t -> t -> t
val sub : t -> t -> t
val mulc : Q.t -> t -> t
val trm : trm -> t
(** [trm x] represents the indeterminate term [x] *)
(** Transform *)
val split_const : t -> t * Q.t
(** Splits an arithmetic term into the sum of its constant and
non-constant parts. That is, [split_const a] is [(b, c)] such that
[a = b + c] and the absolute value of [c] is maximal. *)
val partition_sign : t -> t * t
(** [partition_sign a] is [(p, n)] such that [a] = [p - n] and all
coefficients in [p] and [n] are non-negative. *)
(** Destruct and Query *)
val get_trm : t -> trm option
(** [get_trm a] is [Some x] iff [equal a (trm x)] *)
val get_const : t -> Q.t option
(** [get_const a] is [Some q] iff [equal a (const q)] *)
type kind = Trm of trm | Const of Q.t | Interpreted | Uninterpreted
val classify : t -> kind
@ -37,27 +84,14 @@ module type S = sig
val is_uninterpreted : t -> bool
(** [is_uninterpreted a] iff [classify a = Uninterpreted] *)
(** Construct compound terms *)
(** Construct nonlinear arithmetic terms using the embedding, enabling
interpreting associativity, commutatitivity, and unit laws, but not
the full nonlinear arithmetic theory. *)
val neg : t -> t
val add : t -> t -> t
val sub : t -> t -> t
val mulc : Q.t -> t -> t
val mul : trm -> trm -> t
val div : trm -> trm -> t
val pow : trm -> int -> t
(** Transform *)
val split_const : t -> t * Q.t
(** Splits an arithmetic term into the sum of its constant and
non-constant parts. That is, [split_const a] is [(b, c)] such that
[a = b + c] and the absolute value of [c] is maximal. *)
val partition_sign : t -> t * t
(** [partition_sign a] is [(p, n)] such that [a] = [p - n] and all
coefficients in [p] and [n] are non-negative. *)
(** Traverse *)
val trms : t -> trm iter
@ -69,6 +103,8 @@ module type S = sig
term is a monomial, [trms (Π X^p)] is the sequence
of factors [X] for each [j]. *)
(** Transform *)
val map : t -> f:(trm -> trm) -> t
(** Map over the {!trms}. *)
@ -80,35 +116,6 @@ module type S = sig
is passed, then the subterm [e] must be [for_]. *)
end
(** Indeterminate terms, treated as atomic / variables except when they can
be flattened using [EMBEDDING.get_arith]. *)
module type INDETERMINATE = sig
type t [@@deriving compare, equal, sexp]
include Comparer.S with type t := t
type var
val pp : t pp
val vars : t -> var iter
end
(** An embedding of arithmetic terms [t] into indeterminates [trm]. *)
module type EMBEDDING = sig
type trm
type t
val to_trm : t -> trm
(** Embedding from [t] to [trm]: [to_trm a] is arithmetic term [a]
embedded in an indeterminate term. *)
val get_arith : trm -> t option
(** Partial projection from [trm] to [t]: [get_arith x] is [Some a] if
[x = to_trm a]. This is used to flatten indeterminates that are
actually arithmetic for the client, thereby enabling arithmetic
operations to be interpreted more often. *)
end
(** A type [t] representing arithmetic terms over indeterminates [trm]
together with a functor [Make] that takes an [EMBEDDING] of arithmetic
terms [t] into indeterminates [trm] and produces an implementation of

Loading…
Cancel
Save