You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

362 lines
12 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

(*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*)
(** Arithmetic terms *)
open Ses.Var_intf
include Arithmetic_intf
module Representation
(Var : VAR)
(Trm : INDETERMINATE with type var := Var.t) =
struct
(** Power products: multisets of indeterminate terms, where the [int]
    multiplicity of each term is its exponent. *)
module Prod = struct
include Multiset.Make
(Int)
(struct
type t = Trm.trm [@@deriving compare, equal, sexp]
end)
(* the multiset [t_of_sexp] is parameterized by an element parser;
   specialize it to terms *)
let t_of_sexp = t_of_sexp Trm.trm_of_sexp
end
(** Monomials: power products Πᵢ xᵢ^pᵢ of indeterminate terms [xᵢ] with
    integer exponents [pᵢ]; negative exponents denote division. *)
module Mono = struct
  type t = Prod.t [@@deriving compare, equal, sexp]

  (** [num_den m] splits [m] into its numerator (the factors with
      non-negative exponents) and its denominator (the factors with
      negative exponents). *)
  let num_den mono = Prod.partition mono ~f:(fun _ exponent -> exponent >= 0)

  (** [ppx strength ppf m] prints [m] as a product of the numerator factors
      followed by [/ y^q] for each denominator factor. An empty numerator
      prints as [1], and exponents equal to 1 are omitted. The output is
      parenthesized unless [m] is a single numerator factor. *)
  let ppx strength ppf power_product =
    let pp_factor ppf (indet, exponent) =
      match exponent with
      | 1 -> Trm.ppx strength ppf indet
      | _ -> Format.fprintf ppf "%a^%i" (Trm.ppx strength) indet exponent
    in
    let pp_num ppf num =
      if not (Prod.is_empty num) then Prod.pp "@ @<2>× " pp_factor ppf num
      else Trace.pp_styled `Magenta "1" ppf
    in
    let pp_den ppf den =
      if Prod.is_empty den then ()
      else
        Format.fprintf ppf "@ / %a"
          (Prod.pp "@ / " pp_factor)
          (Prod.map_counts ~f:Int.neg den)
    in
    let num, den = num_den power_product in
    match (Prod.is_singleton num, Prod.is_empty den) with
    | true, true -> Format.fprintf ppf "@[<2>%a@]" pp_num num
    | _ -> Format.fprintf ppf "@[<2>(%a%a)@]" pp_num num pp_den den

  (** [one] is the empty product Πᵢ₌₁⁰ xᵢ^pᵢ *)
  let one = Prod.empty

  let equal_one = Prod.is_empty

  (** [of_ x p] is the singleton product x^p *)
  let of_ indet power = Prod.of_ indet power

  (** [pow m p] scales every exponent of [m] by [p]:
      (Πᵢ₌₁ⁿ xᵢ^pᵢ)^p = Πᵢ₌₁ⁿ xᵢ^(pᵢ×p) *)
  let pow mono power =
    match power with
    | 0 -> one
    | 1 -> mono
    | _ -> Prod.map_counts ~f:(Int.mul power) mono

  (** [mul m1 m2] is the product of [m1] and [m2], computed as the multiset
      union of their factors *)
  let mul mono1 mono2 = Prod.union mono1 mono2

  (** [get_trm m] is [Some x] iff [equal m (of_ x 1)] *)
  let get_trm mono =
    match Prod.only_elt mono with
    | Some (trm, 1) -> Some trm
    | Some _ | None -> None

  (* traverse *)

  (** the indeterminates of a monomial, ignoring their exponents *)
  let trms mono =
    Iter.from_iter (fun yield -> Prod.iter mono ~f:(fun trm _ -> yield trm))

  (* query *)

  let vars mono = trms mono |> Iter.flat_map ~f:Trm.vars

  let fv mono = vars mono |> Var.Set.of_iter
end
(** Polynomials as multisets of monomials, where the rational ([Q.t])
    multiplicity of each monomial is its coefficient. *)
module Sum = struct
include Multiset.Make (Q) (Mono)
(* specialize the element parser of the multiset [t_of_sexp] to monomials *)
let t_of_sexp = t_of_sexp Mono.t_of_sexp
end
(** The polynomial representation type and its indeterminate type, also
    included into the result of [Make]. *)
module Poly = struct
type t = Sum.t [@@deriving compare, equal, sexp]
type trm = Trm.trm
end
include Poly
module Make (Embed : EMBEDDING with type trm := Trm.trm and type t := t) =
struct
include Poly

(** [ppx strength ppf poly] prints [poly] as a [+]-separated sum of
    coefficient-monomial terms. A constant term prints as its coefficient,
    and a unit coefficient is omitted. The zero polynomial prints as [0],
    and sums of more than one term are parenthesized. *)
let ppx strength ppf poly =
  let pp_term ppf (mono, coeff) =
    match (Mono.equal_one mono, Q.equal Q.one coeff) with
    | true, _ -> Trace.pp_styled `Magenta "%a" ppf Q.pp coeff
    | false, true -> Format.fprintf ppf "%a" (Mono.ppx strength) mono
    | false, false ->
        Format.fprintf ppf "%a@<1>×%a" Q.pp coeff (Mono.ppx strength) mono
  in
  if Sum.is_empty poly then Trace.pp_styled `Magenta "0" ppf
  else if Sum.is_singleton poly then
    Format.fprintf ppf "@[<2>%a@]" (Sum.pp "@ + " pp_term) poly
  else Format.fprintf ppf "@[<2>(%a)@]" (Sum.pp "@ + " pp_term) poly

(** [pp] is [ppx] with no strength information for any term *)
let pp ppf poly = ppx (fun _ -> None) ppf poly
(** [mono_invariant mono] checks the representation invariants of a
    monomial, failing an [assert] if any is violated:
    - every exponent is non-zero;
    - no factor is a polynomial that is constant or has a single term;
    - if [mono] is a single factor with exponent 1, that factor is not a
      polynomial. *)
let mono_invariant mono =
(* the checks run under [Invariant.invariant], which is given [mono] and its
   sexp printer (presumably to report the offending value on failure) *)
let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in
Prod.iter mono ~f:(fun base power ->
(* powers are non-zero *)
assert (not (Int.equal Int.zero power)) ;
match Embed.get_arith base with
| None -> ()
| Some poly -> (
match Sum.classify poly with
| `Many -> ()
| `Zero | `One _ ->
(* polynomial factors are not constant or singleton, which
should have been flattened into the parent monomial *)
assert false ) ) ;
match Mono.get_trm mono with
| None -> ()
| Some trm -> (
match Embed.get_arith trm with
| None -> ()
| Some _ ->
(* singleton monomials are not polynomials, which should have
been flattened into the parent polynomial *)
assert false )
(** [invariant poly] checks that every coefficient of [poly] is non-zero
    and that every monomial of [poly] satisfies [mono_invariant]. *)
let invariant poly =
let@ () = Invariant.invariant [%here] poly [%sexp_of: t] in
Sum.iter poly ~f:(fun mono coeff ->
(* coefficients are non-zero *)
assert (not (Q.equal Q.zero coeff)) ;
mono_invariant mono )
(* constants *)

(** [const q] is the constant polynomial [q × 1] *)
let const q = check invariant (Sum.of_ Mono.one q)

(** [zero] is the constant 0, represented as the empty sum *)
let zero =
  let p = const Q.zero in
  check (fun p -> assert (Sum.is_empty p)) p

(* core constructors *)

(** [neg p] negates every coefficient of [p] *)
let neg poly = check invariant (Sum.map_counts ~f:Q.neg poly)

(** [add p q] is [p + q], combining coefficients of common monomials *)
let add p q = check invariant (Sum.union p q)

(** [sub p q] is [p + (-q)] *)
let sub p q = neg q |> add p

(** [mulc c p] scales every coefficient of [p] by [c]. It returns [zero]
    when [c] is 0 and [p] itself when [c] is 1. *)
let mulc coeff poly =
  let scaled =
    if Q.equal Q.zero coeff then zero
    else if Q.equal Q.one coeff then poly
    else Sum.map_counts ~f:(Q.mul coeff) poly
  in
  check invariant scaled
(* projections and embeddings *)

(** A coarse view of a polynomial: a lone term [x] (that is, [1 × x¹]), a
    constant, or anything else. *)
type view = Trm of trm | Const of Q.t | Compound

let classify poly =
  match Sum.classify poly with
  | `Many -> Compound
  | `Zero -> Const Q.zero
  | `One (mono, coeff) -> (
    match Prod.classify mono with
    | `Zero -> Const coeff
    | `One (trm, 1) when Q.equal Q.one coeff -> Trm trm
    | _ -> Compound )

(** [get_const p] is [Some c] iff [p] is the constant polynomial [c] *)
let get_const poly =
  match Sum.classify poly with
  | `Zero -> Some Q.zero
  | `One (mono, coeff) -> if Mono.equal_one mono then Some coeff else None
  | _ -> None

(** [get_mono p] is [Some m] iff [p] is the single term [1 × m] *)
let get_mono poly =
  match Sum.only_elt poly with
  | None -> None
  | Some (mono, coeff) -> if Q.equal Q.one coeff then Some mono else None
(** Terms of a polynomial: product of a coefficient and a monomial *)
module CM = struct
  type t = Q.t * Prod.t

  let one = (Q.one, Mono.one)

  let mul (coeff1, mono1) (coeff2, mono2) =
    (Q.mul coeff1 coeff2, Mono.mul mono1 mono2)

  (** [of_trm ?power base] is [base^power] (default [power = 1]) as a
      coefficient-monomial pair. Monomials [Mono.t] have [trm]
      indeterminates, which, via [get_arith], include polynomials [t] over
      monomials themselves. To avoid redundant representations, polynomials
      that are zero or have a single term are flattened rather than kept
      as factors.

      NOTE(review): [0^p] yields [0 × 1] for every [p], including
      [p <= 0], where [0^0] is conventionally 1 and [0^-1] is undefined.
      Confirm this is intended. *)
  let of_trm : ?power:int -> trm -> t =
   fun ?(power = 1) base ->
    (* keep [base^power] as an opaque factor of a unit-coefficient term *)
    let as_factor () = (Q.one, Mono.of_ base power) in
    match Embed.get_arith base with
    (* X ^ p ==> 1 × X^p *)
    | None -> as_factor ()
    | Some poly -> (
      match Sum.classify poly with
      (* 0 ^ p ==> 0 × 1 *)
      | `Zero -> (Q.zero, Mono.one)
      (* (c × M) ^ p ==> c^p × M^p *)
      | `One (mono, coeff) -> (Q.pow coeff power, Mono.pow mono power)
      (* (Σᵢ₌₁ⁿ cᵢ × Xᵢ) ^ p ==> 1 × (Σᵢ₌₁ⁿ cᵢ × Xᵢ)^p *)
      | `Many -> as_factor () )

  (** [to_poly (c, m)] is the polynomial [c × m]. Polynomials [t] have
      [trm] indeterminates, which, via [get_arith], include polynomials
      themselves. To avoid redundant representations, a monomial consisting
      of a single polynomial with exponent 1 is flattened into that
      polynomial. Constant multiplication is not interpreted in [Prod], so
      such a polynomial is scaled by [c] directly. *)
  let to_poly : t -> Poly.t =
   fun (coeff, mono) ->
    let embedded_poly =
      match Mono.get_trm mono with
      | Some trm -> Embed.get_arith trm
      | None -> None
    in
    ( match embedded_poly with
    (* c × (Σᵢ₌₁ⁿ cᵢ × Xᵢ) ==> Σᵢ₌₁ⁿ c×cᵢ × Xᵢ *)
    | Some poly -> mulc coeff poly
    (* c × (Πⱼ₌₁ᵐ Xⱼ^pⱼ) ==> Σᵢ₌₁¹ c × (Πⱼ₌₁ᵐ Xⱼ^pⱼ) *)
    | None -> Sum.of_ mono coeff )
    |> check invariant
end
(** [trm x] embeds the term [x] into a polynomial. If [x] embeds a
    polynomial (per [Embed.get_arith]), that polynomial is projected out.
    Otherwise the result is the single term [1 × x¹]. *)
let trm trm =
  let poly =
    match Embed.get_arith trm with
    | None -> Sum.of_ (Mono.of_ trm 1) Q.one
    | Some poly -> poly
  in
  check (fun p -> assert (equal p (CM.to_poly (CM.of_trm trm)))) poly

(** Project out the term embedded into a polynomial, if possible *)
let get_trm poly =
  match get_mono poly with
  | None -> None
  | Some mono -> Mono.get_trm mono

(* constructors over indeterminates *)

(** [mul e1 e2] is the polynomial [e1 × e2] *)
let mul e1 e2 =
  let cm1 = CM.of_trm e1 and cm2 = CM.of_trm e2 in
  CM.to_poly (CM.mul cm1 cm2)

(** [div n d] is the polynomial [n × d⁻¹] *)
let div n d =
  let num = CM.of_trm n and inv_den = CM.of_trm ~power:(-1) d in
  CM.to_poly (CM.mul num inv_den)

(** [pow base power] is the polynomial [base^power] *)
let pow base power = CM.of_trm ~power base |> CM.to_poly
(* transform *)

(** [split_const p] is [(p', c)] where [p'] is [p] without its constant
    term and [c] is that constant (0 if there is none) *)
let split_const poly =
  match Sum.find_and_remove Mono.one poly with
  | None -> (poly, Q.zero)
  | Some (const_coeff, non_const) -> (non_const, const_coeff)

(** [partition_sign p] splits the monomials of [p] by coefficient sign.
    Those with positive coefficients are tagged [Left] with their
    coefficient. Those with negative coefficients are tagged [Right] with
    the negated (hence positive) coefficient. *)
let partition_sign poly =
  Sum.partition_map poly ~f:(fun _mono coeff ->
      if Q.sign coeff < 0 then Right (Q.neg coeff) else Left coeff )
(** [map poly ~f] replaces each indeterminate [x] of [poly] with [f x] and
    re-normalizes the result. [f] is applied to every factor of every
    monomial. A factor counts as unchanged when [f] returns it physically
    unchanged ([==]). Monomials with no changed factor are kept as they
    are. If no factor changes at all, [poly] itself is returned, so callers
    can detect "no change" with [==]. *)
let map poly ~f =
  [%trace]
    ~call:(fun {pf} -> pf "%a" pp poly)
    ~retn:(fun {pf} -> pf "%a" pp)
  @@ fun () ->
  (* [p]: monomials of [poly] not rewritten; [p']: sum of the rewritten
     monomials; [changed]: whether any monomial was rewritten *)
  let p, p', changed =
    Sum.fold poly (poly, Sum.empty, false)
      ~f:(fun mono coeff (p, p', changed) ->
        (* [m]: factors of [mono] that [f] left unchanged; [cm']: product of
           the rewritten factors, re-embedded via [CM.of_trm] *)
        let m, cm', mono_changed =
          Prod.fold mono (mono, CM.one, false)
            ~f:(fun trm power (m, cm', mono_changed) ->
              let trm' = f trm in
              if trm == trm' then (m, cm', mono_changed)
              else
                ( Prod.remove trm m
                , CM.mul cm' (CM.of_trm trm' ~power)
                , true ) )
        in
        (* an untouched monomial stays in [p] as is, instead of being
           removed and rebuilt into an equal but physically new term *)
        if not mono_changed then (p, p', changed)
        else
          ( Sum.remove mono p
          , Sum.union p' (CM.to_poly (CM.mul (coeff, m) cm'))
          , true ) )
  in
  if not changed then poly else check invariant (Sum.union p p')
(* traverse *)

(** the monomials of a polynomial, ignoring their coefficients *)
let monos poly =
  Iter.from_iter (fun yield -> Sum.iter poly ~f:(fun mono _ -> yield mono))

(** the indeterminates of all monomials of a polynomial *)
let trms poly = monos poly |> Iter.flat_map ~f:Mono.trms

type product = Prod.t

let fold_factors = Prod.fold

let fold_monomials = Sum.fold

(* query *)

let vars poly = trms poly |> Iter.flat_map ~f:Trm.vars

(* solve *)

(** [exists_fv_in vs p] holds iff some variable of [p] is in [vs] *)
let exists_fv_in vs poly =
  vars poly |> Iter.exists ~f:(fun v -> Var.Set.mem v vs)
(** [solve_for_mono r c m p] solves [0 = r + (c×m) + p] for [m] as [m = q]
    ([Some (m, q)]) with [q = -(r + p)/c], so that
    [r + (c×m) + p = c×(m - q)].

    Returns [None] in two cases:
    - [m] is the constant monomial [1];
    - some free variable of [m] occurs in [r] or in [p]. Then [q], which
      is built from both, would mention [m]'s variables and the solution
      could be cyclic, e.g. [x ↦ -(w×x)]. *)
let solve_for_mono rejected_poly coeff mono poly =
  if Mono.equal_one mono then None
  else
    let mono_fv = Mono.fv mono in
    (* occurs check over everything [q] is built from *)
    if exists_fv_in mono_fv poly || exists_fv_in mono_fv rejected_poly then
      None
    else
      Some
        ( Sum.of_ mono Q.one
        , mulc (Q.inv (Q.neg coeff)) (Sum.union rejected_poly poly) )
(** [solve_poly r p] solves [0 = r + p] as [m = q] ([Some (m, q)]) such
    that [r + p = c×(m - q)], where [c] is the coefficient of [m] in [p].
    The monomials of [p] are tried in increasing order, as returned by
    [Sum.pop_min_elt]. Each monomial that [solve_for_mono] cannot solve for
    is moved into the rejected part [r], and the search continues on the
    rest. Returns [None] once [p] is exhausted. *)
let rec solve_poly rejected poly =
[%trace]
~call:(fun {pf} -> pf "0 = (%a) + (%a)" pp rejected pp poly)
~retn:(fun {pf} s ->
pf "%a"
(Option.pp "%a" (fun fs (v, q) ->
Format.fprintf fs "%a ↦ %a" pp v pp q ))
s )
@@ fun () ->
(* [None] when [poly] is empty: no monomial left to solve for *)
let* mono, coeff, poly = Sum.pop_min_elt poly in
match solve_for_mono rejected coeff mono poly with
| Some _ as soln -> soln
(* keep [mono] (with its coefficient) in the rejected part *)
| None -> solve_poly (Sum.add mono coeff rejected) poly
(** [solve_zero_eq ?for_ e] solves [0 = e], where [e] must embed a
    polynomial ([Embed.get_arith]), returning [Some (m, q)] for [m = q].
    Without [for_], any solvable monomial of [e] is chosen (see
    [solve_poly]). With [for_], the solution is for the monomial of
    [trm for_]; [for_] may be any term, including a variable. The result is
    [None] if [for_] is not a single unit-coefficient monomial, does not
    occur in [e], or cannot be isolated (see [solve_for_mono]). *)
let solve_zero_eq ?for_ e =
  [%trace]
    ~call:(fun {pf} ->
      pf "0 = %a%a" Trm.pp e (Option.pp " for %a" Trm.pp) for_ )
    ~retn:(fun {pf} s ->
      pf "%a"
        (Option.pp "%a" (fun fs (c, r) ->
             Format.fprintf fs "%a ↦ %a" pp c pp r ))
        s ;
      match (for_, s) with
      | Some f, Some (c, _) -> assert (equal (trm f) c)
      | _ -> () )
  @@ fun () ->
  let* a = Embed.get_arith e in
  match for_ with
  | None -> solve_poly Sum.empty a
  | Some for_ ->
      (* embed [for_] with [trm], not [Embed.get_arith], so that
         non-arithmetic terms such as variables can be solved for; [trm]
         coincides with [Embed.get_arith] whenever the latter succeeds, and
         matches the [trm f] checked in [retn] above *)
      let* m = get_mono (trm for_) in
      let* c, p = Sum.find_and_remove m a in
      solve_for_mono Sum.empty c m p
end
end