|
|
(*
|
|
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
|
*
|
|
|
* This source code is licensed under the MIT license found in the
|
|
|
* LICENSE file in the root directory of this source tree.
|
|
|
*)
|
|
|
|
|
|
(** Arithmetic terms *)
|
|
|
|
|
|
open Ses.Var_intf
|
|
|
include Arithmetic_intf
|
|
|
|
|
|
(** Representation of polynomials with rational coefficients over
    indeterminates of type [Trm.trm], whose variables are [Var.t]. *)
module Representation
    (Var : VAR)
    (Trm : INDETERMINATE with type var := Var.t) =
struct
  (** Power products: multisets of indeterminates whose integer counts are
      the exponents of the corresponding factors. *)
  module Prod = struct
    include Multiset.Make
              (Int)
              (struct
                type t = Trm.trm [@@deriving compare, equal, sexp]
              end)

    let t_of_sexp = t_of_sexp Trm.trm_of_sexp
  end

  (** Monomials Πᵢ xᵢ^pᵢ, represented as power products. Negative powers
      denote division (see [num_den] and the printer). *)
  module Mono = struct
    type t = Prod.t [@@deriving compare, equal, sexp]

    (** [num_den m] splits [m] into the factors with non-negative powers
        (the numerator) and those with negative powers (the denominator). *)
    let num_den m = Prod.partition m ~f:(fun _ i -> i >= 0)

    (** Print a monomial as a product of its numerator factors followed by
        [/]-separated denominator factors (with their powers negated back to
        positive). Parentheses are omitted when there is exactly one
        numerator factor and no denominator. *)
    let ppx strength ppf power_product =
      let pp_factor ppf (indet, exponent) =
        if exponent = 1 then (Trm.ppx strength) ppf indet
        else Format.fprintf ppf "%a^%i" (Trm.ppx strength) indet exponent
      in
      let pp_num ppf num =
        if Prod.is_empty num then Trace.pp_styled `Magenta "1" ppf
        else Prod.pp "@ @<2>× " pp_factor ppf num
      in
      let pp_den ppf den =
        if not (Prod.is_empty den) then
          Format.fprintf ppf "@ / %a"
            (Prod.pp "@ / " pp_factor)
            (Prod.map_counts ~f:Int.neg den)
      in
      let num, den = num_den power_product in
      if Prod.is_singleton num && Prod.is_empty den then
        Format.fprintf ppf "@[<2>%a@]" pp_num num
      else Format.fprintf ppf "@[<2>(%a%a)@]" pp_num num pp_den den

    (** [one] is the empty product Πᵢ₌₁⁰ xᵢ^pᵢ *)
    let one = Prod.empty

    (** [equal_one m] holds iff [m] is the empty product [one]. *)
    let equal_one = Prod.is_empty

    (** [of_ x₁ p₁] is the singleton product Πᵢ₌₁¹ x₁^p₁ *)
    let of_ x p = Prod.of_ x p

    (** [pow (Πᵢ₌₁ⁿ xᵢ^pᵢ) p] is Πᵢ₌₁ⁿ xᵢ^(pᵢ×p) *)
    let pow mono = function
      | 0 -> Prod.empty
      | 1 -> mono
      | power -> Prod.map_counts ~f:(Int.mul power) mono

    (** [mul x y] is the product of [x] and [y], computed as the multiset
        union of their factors. *)
    let mul x y = Prod.union x y

    (** [get_trm m] is [Some x] iff [equal m (of_ x 1)] *)
    let get_trm mono =
      match Prod.only_elt mono with Some (trm, 1) -> Some trm | _ -> None

    (* traverse *)

    (** The indeterminates of [mono], one per factor, ignoring powers. *)
    let trms mono =
      Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm))

    (* query *)

    (** The variables of the indeterminates of [p], possibly repeated. *)
    let vars p = Iter.flat_map ~f:Trm.vars (trms p)

    (** The set of variables occurring in [p]. *)
    let fv p = Var.Set.of_iter (vars p)
  end

  (** Sums: multisets of monomials whose counts are rational coefficients. *)
  module Sum = struct
    include Multiset.Make (Q) (Mono)

    let t_of_sexp = t_of_sexp Mono.t_of_sexp
  end

  (** Polynomials Σᵢ cᵢ × mᵢ, represented as sums. *)
  module Poly = struct
    type t = Sum.t [@@deriving compare, equal, sexp]
    type trm = Trm.trm
  end

  include Poly

  (** Operations on polynomials, given an [Embed]ding that can project a
      polynomial out of an indeterminate ([Embed.get_arith]). *)
  module Make (Embed : EMBEDDING with type trm := Trm.trm and type t := t) =
  struct
    include Poly

    (** Print a polynomial as a [+]-separated sum of [coefficient × monomial]
        terms. Unit coefficients are omitted and constant terms are printed
        as the bare coefficient. Parentheses are omitted for a single term,
        and the zero polynomial prints as [0]. *)
    let ppx strength ppf poly =
      if Sum.is_empty poly then Trace.pp_styled `Magenta "0" ppf
      else
        let pp_coeff_mono ppf (m, c) =
          if Mono.equal_one m then Trace.pp_styled `Magenta "%a" ppf Q.pp c
          else if Q.equal Q.one c then
            Format.fprintf ppf "%a" (Mono.ppx strength) m
          else Format.fprintf ppf "%a@<1>×%a" Q.pp c (Mono.ppx strength) m
        in
        if Sum.is_singleton poly then
          Format.fprintf ppf "@[<2>%a@]" (Sum.pp "@ + " pp_coeff_mono) poly
        else
          Format.fprintf ppf "@[<2>(%a)@]"
            (Sum.pp "@ + " pp_coeff_mono)
            poly

    let pp = ppx (fun _ -> None)

    (** Check the normal-form invariant of a monomial:
        - every power is non-zero;
        - every factor that embeds a polynomial embeds one with several
          monomials;
        - a monomial consisting of a single indeterminate to the power 1 is
          not itself an embedded polynomial. *)
    let mono_invariant mono =
      let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in
      Prod.iter mono ~f:(fun base power ->
          (* powers are non-zero *)
          assert (not (Int.equal Int.zero power)) ;
          match Embed.get_arith base with
          | None -> ()
          | Some poly -> (
            match Sum.classify poly with
            | `Many -> ()
            | `Zero | `One _ ->
                (* polynomial factors are not constant or singleton, which
                   should have been flattened into the parent monomial *)
                assert false ) ) ;
      match Mono.get_trm mono with
      | None -> ()
      | Some trm -> (
        match Embed.get_arith trm with
        | None -> ()
        | Some _ ->
            (* singleton monomials are not polynomials, which should have
               been flattened into the parent polynomial *)
            assert false )

    (** Check the normal-form invariant of a polynomial: every coefficient
        is non-zero and every monomial satisfies [mono_invariant]. *)
    let invariant poly =
      let@ () = Invariant.invariant [%here] poly [%sexp_of: t] in
      Sum.iter poly ~f:(fun mono coeff ->
          (* coefficients are non-zero *)
          assert (not (Q.equal Q.zero coeff)) ;
          mono_invariant mono )

    (* constants *)

    (** [const q] is the constant polynomial [q × 1]. *)
    let const q = Sum.of_ Mono.one q |> check invariant

    (** The zero polynomial, represented as the empty sum. *)
    let zero = const Q.zero |> check (fun p -> assert (Sum.is_empty p))

    (* core constructors *)

    (** [neg p] negates every coefficient of [p]. *)
    let neg poly = Sum.map_counts ~f:Q.neg poly |> check invariant

    (** [add p q] is the sum [p + q]. *)
    let add p q = Sum.union p q |> check invariant

    (** [sub p q] is the difference [p - q]. *)
    let sub p q = add p (neg q)

    (** [mulc c p] multiplies every coefficient of [p] by [c], returning [p]
        itself when [c = 1] and [zero] when [c = 0]. *)
    let mulc coeff poly =
      ( if Q.equal Q.one coeff then poly
      else if Q.equal Q.zero coeff then zero
      else Sum.map_counts ~f:(Q.mul coeff) poly )
      |> check invariant

    (* projections and embeddings *)

    type view = Trm of trm | Const of Q.t | Compound

    (** [classify p] is [Const c] if [p] is the constant [c], [Trm x] if [p]
        is [1 × x¹], and [Compound] otherwise. *)
    let classify poly =
      match Sum.classify poly with
      | `Zero -> Const Q.zero
      | `One (mono, coeff) -> (
        match Prod.classify mono with
        | `Zero -> Const coeff
        | `One (trm, 1) when Q.equal Q.one coeff -> Trm trm
        | _ -> Compound )
      | `Many -> Compound

    (** [get_const p] is [Some c] iff [p] is the constant polynomial [c]. *)
    let get_const poly =
      match Sum.classify poly with
      | `Zero -> Some Q.zero
      | `One (mono, coeff) when Mono.equal_one mono -> Some coeff
      | _ -> None

    (** [get_mono p] is [Some m] iff [p] is the single term [1 × m]. *)
    let get_mono poly =
      match Sum.only_elt poly with
      | Some (mono, coeff) when Q.equal Q.one coeff -> Some mono
      | _ -> None

    (** Terms of a polynomial: product of a coefficient and a monomial *)
    module CM = struct
      type t = Q.t * Prod.t

      let one = (Q.one, Mono.one)
      let mul (c1, m1) (c2, m2) = (Q.mul c1 c2, Mono.mul m1 m2)

      (** Monomials [Mono.t] have [trm] indeterminates, which include, via
          [get_arith], polynomials [t] over monomials themselves. To avoid
          redundant representations, singleton polynomials are flattened. *)
      let of_trm : ?power:int -> trm -> t =
       fun ?(power = 1) base ->
        match Embed.get_arith base with
        | Some poly -> (
          match Sum.classify poly with
          (* 0 ^ p₁ ==> 0 × 1 *)
          | `Zero -> (Q.zero, Mono.one)
          (* (Σᵢ₌₁¹ cᵢ × Xᵢ) ^ p₁ ==> cᵢ^p₁ × Πⱼ₌₁¹ Xⱼ^pⱼ *)
          | `One (mono, coeff) -> (Q.pow coeff power, Mono.pow mono power)
          (* (Σᵢ₌₁ⁿ cᵢ × Xᵢ) ^ p₁ ==> 1 × Πⱼ₌₁¹ (Σᵢ₌₁ⁿ cᵢ × Xᵢ)^pⱼ *)
          | `Many -> (Q.one, Mono.of_ base power) )
        (* X₁ ^ p₁ ==> 1 × Πⱼ₌₁¹ Xⱼ^pⱼ *)
        | None -> (Q.one, Mono.of_ base power)

      (** Polynomials [t] have [trm] indeterminates, which, via [get_arith],
          include polynomials themselves. To avoid redundant
          representations, singleton monomials are flattened. Also, constant
          multiplication is not interpreted in [Prod], so constant
          polynomials are multiplied by their coefficients directly. *)
      let to_poly : t -> Poly.t =
       fun (coeff, mono) ->
        ( match Mono.get_trm mono with
        | Some trm -> (
          match Embed.get_arith trm with
          (* c × (Σᵢ₌₁ⁿ cᵢ × Xᵢ) ==> Σᵢ₌₁ⁿ c×cᵢ × Xᵢ *)
          | Some poly -> mulc coeff poly
          (* c₁ × X₁ ==> Σᵢ₌₁¹ cᵢ × Xᵢ *)
          | None -> Sum.of_ mono coeff )
        (* c₁ × (Πⱼ₌₁ᵐ X₁ⱼ^p₁ⱼ) ==> Σᵢ₌₁¹ cᵢ × (Πⱼ₌₁ᵐ Xᵢⱼ^pᵢⱼ) *)
        | None -> Sum.of_ mono coeff )
        |> check invariant
    end

    (** Embed a term into a polynomial, by projecting a polynomial out of
        the term if possible *)
    let trm trm =
      ( match Embed.get_arith trm with
      | Some poly -> poly
      | None -> Sum.of_ (Mono.of_ trm 1) Q.one )
      |> check (fun poly ->
             assert (equal poly (CM.to_poly (CM.of_trm trm))) )

    (** Project out the term embedded into a polynomial, if possible *)
    let get_trm poly =
      match get_mono poly with
      | Some mono -> Mono.get_trm mono
      | None -> None

    (* constructors over indeterminates *)

    (** [mul e1 e2] is the normalized product [e1 × e2]. *)
    let mul e1 e2 = CM.to_poly (CM.mul (CM.of_trm e1) (CM.of_trm e2))

    (** [div n d] is the normalized quotient [n × d^-1]. *)
    let div n d =
      CM.to_poly (CM.mul (CM.of_trm n) (CM.of_trm d ~power:(-1)))

    (** [pow base power] is the normalized power [base^power]. *)
    let pow base power = CM.to_poly (CM.of_trm base ~power)

    (* transform *)

    (** [split_const p] is [(p', c)], where [c] is the constant term of [p]
        ([Q.zero] if there is none) and [p'] is [p] without it. *)
    let split_const poly =
      match Sum.find_and_remove Mono.one poly with
      | Some (c, p_c) -> (p_c, c)
      | None -> (poly, Q.zero)

    (** [partition_sign p] partitions the monomials of [p] by the sign of
        their coefficients. Non-negative ones keep their coefficient
        ([Left]) and negative ones have it negated ([Right]), so both parts
        have non-negative coefficients. *)
    let partition_sign poly =
      Sum.partition_map poly ~f:(fun _ coeff ->
          if Q.sign coeff >= 0 then Left coeff else Right (Q.neg coeff) )

    (** [map poly ~f] applies [f] to every indeterminate of every monomial
        of [poly] and renormalizes the result. An indeterminate counts as
        changed only if [f] returns a physically different term. Monomials
        with no changed factor are left in place rather than being removed
        and re-inserted. *)
    let map poly ~f =
      [%trace]
        ~call:(fun {pf} -> pf "%a" pp poly)
        ~retn:(fun {pf} -> pf "%a" pp)
      @@ fun () ->
      (* [p] holds the monomials of [poly] left unchanged by [f], and [p']
         accumulates the normalized images of the changed ones *)
      let p, p' = (poly, Sum.empty) in
      let p, p' =
        Sum.fold poly (p, p') ~f:(fun mono coeff (p, p') ->
            (* [m] holds the unchanged factors of [mono], and [cm']
               accumulates the product of the images of the changed ones *)
            let m, cm' = (mono, CM.one) in
            let m, cm' =
              Prod.fold mono (m, cm') ~f:(fun trm power (m, cm') ->
                  let trm' = f trm in
                  if trm == trm' then (m, cm')
                  else (Prod.remove trm m, CM.mul cm' (CM.of_trm trm' ~power)) )
            in
            (* [CM.mul] always allocates a fresh pair, so [cm'] is still
               physically [CM.one] iff no factor of [mono] changed; [mono]
               then already sits in [p] with coefficient [coeff] *)
            if cm' == CM.one then (p, p')
            else
              ( Sum.remove mono p
              , Sum.union p' (CM.to_poly (CM.mul (coeff, m) cm')) ) )
      in
      Sum.union p p' |> check invariant

    (* traverse *)

    (** The monomials of [poly]. *)
    let monos poly =
      Iter.from_iter (fun f -> Sum.iter poly ~f:(fun mono _ -> f mono))

    (** The indeterminates of all monomials of [poly], possibly repeated. *)
    let trms poly = Iter.flat_map ~f:Mono.trms (monos poly)

    type product = Prod.t

    let fold_factors = Prod.fold
    let fold_monomials = Sum.fold

    (* query *)

    (** The variables of the indeterminates of [p], possibly repeated. *)
    let vars p = Iter.flat_map ~f:Trm.vars (trms p)

    (* solve *)

    (** [exists_fv_in vs poly] holds iff some variable of [poly] is in [vs]. *)
    let exists_fv_in vs poly =
      Iter.exists ~f:(fun v -> Var.Set.mem v vs) (vars poly)

    (** [solve_for_mono r c m p] solves [0 = r + (c×m) + p] as [m = q]
        ([Some (m, q)]) such that [r + (c×m) + p = c×(m - q)], that is
        [q = -(r + p) / c]. It fails ([None]) if [m] is the unit monomial
        or if a variable of [m] occurs in [p].

        NOTE(review): only [p] is checked for the variables of [m], not
        [r], although [q] includes [r]. Confirm this asymmetry is
        intended. *)
    let solve_for_mono rejected_poly coeff mono poly =
      if Mono.equal_one mono || exists_fv_in (Mono.fv mono) poly then None
      else
        Some
          ( Sum.of_ mono Q.one
          , mulc (Q.inv (Q.neg coeff)) (Sum.union rejected_poly poly) )

    (** [solve_poly r p] solves [0 = r + p] as [m = q] ([Some (m, q)]) such
        that [r + p = c×(m - q)], where [c] is the coefficient of the
        monomial solved for. Monomials of [p] are tried in increasing order,
        and those that cannot be solved for are moved into [r]. *)
    let rec solve_poly rejected poly =
      [%trace]
        ~call:(fun {pf} -> pf "0 = (%a) + (%a)" pp rejected pp poly)
        ~retn:(fun {pf} s ->
          pf "%a"
            (Option.pp "%a" (fun fs (v, q) ->
                 Format.fprintf fs "%a ↦ %a" pp v pp q ))
            s )
      @@ fun () ->
      let* mono, coeff, poly = Sum.pop_min_elt poly in
      match solve_for_mono rejected coeff mono poly with
      | Some _ as soln -> soln
      | None -> solve_poly (Sum.add mono coeff rejected) poly

    (** [solve_zero_eq ?for_ e] solves [0 = e] as [Some (c, r)] meaning
        [c ↦ r]. Without [for_], any solvable monomial of [e] may be chosen.
        With [for_], the solution is for the monomial of [trm for_], which
        must be [1 × m] for some monomial [m] that occurs in [e]. Yields
        [None] if [e] is not a polynomial ([Embed.get_arith e] is [None]) or
        if no solution is found. *)
    let solve_zero_eq ?for_ e =
      [%trace]
        ~call:(fun {pf} ->
          pf "0 = %a%a" Trm.pp e (Option.pp " for %a" Trm.pp) for_ )
        ~retn:(fun {pf} s ->
          pf "%a"
            (Option.pp "%a" (fun fs (c, r) ->
                 Format.fprintf fs "%a ↦ %a" pp c pp r ))
            s ;
          match (for_, s) with
          | Some f, Some (c, _) -> assert (equal (trm f) c)
          | _ -> () )
      @@ fun () ->
      let* a = Embed.get_arith e in
      match for_ with
      | None -> solve_poly Sum.empty a
      | Some for_ ->
          (* embed [for_] with [trm] rather than [Embed.get_arith]: the
             latter is [None] for indeterminates that are not themselves
             polynomials, which would make it impossible to solve for
             them *)
          let* m = get_mono (trm for_) in
          let* c, p = Sum.find_and_remove m a in
          solve_for_mono Sum.empty c m p
  end
end
|