You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

454 lines
14 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

(*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*)
(** Arithmetic terms *)
include Arithmetic_intf
module Int = struct
include Int
include Comparer.Make (Int)
end
module Q = struct
include Q
include Comparer.Make (Q)
end
type ('trm, 'compare_trm) mono = ('trm, int, 'compare_trm) Multiset.t
[@@deriving compare, equal, sexp]
type 'compare_trm compare_mono =
('compare_trm, Int.compare) Multiset.compare
[@@deriving compare, equal, sexp]
type ('trm, 'compare_trm) t =
(('trm, 'compare_trm) mono, Q.t, 'compare_trm compare_mono) Multiset.t
[@@deriving compare, equal, sexp]
module Make (Trm0 : sig
type t [@@deriving equal, sexp]
include Comparer.S with type t := t
end) =
struct
module Prod = Multiset.Make (Trm0) (Int)
module Mono = struct
type t = Prod.t [@@deriving compare, equal, sexp_of]
let num_den m = Prod.partition m ~f:(fun _ i -> i >= 0)
let ppx pp_trm ppf power_product =
let pp_factor ppf (indet, exponent) =
if exponent = 1 then pp_trm ppf indet
else Format.fprintf ppf "%a^%i" pp_trm indet exponent
in
let pp_num ppf num =
if Prod.is_empty num then Trace.pp_styled `Magenta "1" ppf
else Prod.pp "@ @<2>× " pp_factor ppf num
in
let pp_den ppf den =
if not (Prod.is_empty den) then
Format.fprintf ppf "@ / %a"
(Prod.pp "@ / " pp_factor)
(Prod.map_counts ~f:Int.neg den)
in
let num, den = num_den power_product in
if Prod.is_singleton num && Prod.is_empty den then
Format.fprintf ppf "@[<2>%a@]" pp_num num
else Format.fprintf ppf "@[<2>(%a%a)@]" pp_num num pp_den den
(** [one] is the empty product Πᵢ₌₁⁰ xᵢ^pᵢ *)
let one = Prod.empty
let equal_one = Prod.is_empty
(** [of_ x₁ p₁] is the singleton product Πᵢ₌₁¹ x₁^p₁ *)
let of_ x p = Prod.of_ x p
(** [pow (Πᵢ₌₁ⁿ xᵢ^pᵢ) p] is Πᵢ₌₁ⁿ xᵢ^(pᵢ×p) *)
let pow mono = function
| 0 -> Prod.empty
| 1 -> mono
| power -> Prod.map_counts ~f:(Int.mul power) mono
let mul x y = Prod.union x y
(** [get_trm m] is [Some x] iff [equal m (of_ x 1)] *)
let get_trm mono =
match Prod.only_elt mono with Some (trm, 1) -> Some trm | _ -> None
(* traverse *)
let trms mono =
Iter.from_iter (fun f -> Prod.iter mono ~f:(fun trm _ -> f trm))
end
module Sum = Multiset.Make (Prod) (Q)
module Poly = Sum
include Poly
module S0 = struct
let ppx pp_trm ppf poly =
if Sum.is_empty poly then Trace.pp_styled `Magenta "0" ppf
else
let pp_coeff_mono ppf (m, c) =
if Mono.equal_one m then Trace.pp_styled `Magenta "%a" ppf Q.pp c
else if Q.equal Q.one c then
Format.fprintf ppf "%a" (Mono.ppx pp_trm) m
else Format.fprintf ppf "%a@<1>×%a" Q.pp c (Mono.ppx pp_trm) m
in
if Sum.is_singleton poly then
Format.fprintf ppf "@[<2>%a@]" (Sum.pp "@ + " pp_coeff_mono) poly
else
Format.fprintf ppf "@[<2>(%a)@]"
(Sum.pp "@ + " pp_coeff_mono)
poly
let trms poly =
Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ ->
Prod.iter mono ~f:(fun trm _ -> f trm) ) )
(* core invariant *)
let mono_invariant mono =
let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in
Prod.iter mono ~f:(fun _ power ->
(* powers are non-zero *)
assert (not (Int.equal Int.zero power)) )
let invariant poly =
let@ () = Invariant.invariant [%here] poly [%sexp_of: t] in
Sum.iter poly ~f:(fun mono coeff ->
(* coefficients are non-zero *)
assert (not (Q.equal Q.zero coeff)) ;
mono_invariant mono )
(* embed a term into a polynomial *)
let trm trm = Sum.of_ (Mono.of_ trm 1) Q.one
(* constants *)
let const q = Sum.of_ Mono.one q |> check invariant
let zero = const Q.zero |> check (fun p -> assert (Sum.is_empty p))
let one = const Q.one
(* core constructors *)
let neg poly = Sum.map_counts ~f:Q.neg poly |> check invariant
let add p q = Sum.union p q |> check invariant
let sub p q = add p (neg q)
let mulc coeff poly =
( if Q.equal Q.one coeff then poly
else if Q.equal Q.zero coeff then zero
else Sum.map_counts ~f:(Q.mul coeff) poly )
|> check invariant
(* transform *)
let split_const poly =
match Sum.find_and_remove Mono.one poly with
| Some c, p_c -> (p_c, c)
| None, _ -> (poly, Q.zero)
let partition_sign poly =
Sum.partition_map poly ~f:(fun _ coeff ->
if Q.sign coeff >= 0 then Left coeff else Right (Q.neg coeff) )
(* projections and embeddings *)
type kind = Trm of Trm0.t | Const of Q.t | Interpreted | Uninterpreted
let classify poly =
match Sum.classify poly with
| Zero2 -> Const Q.zero
| One2 (mono, coeff) -> (
match Prod.classify mono with
| Zero2 -> Const coeff
| One2 (trm, 1) ->
if Q.equal Q.one coeff then Trm trm else Interpreted
| _ -> Uninterpreted )
| Many2 -> Interpreted
let is_uninterpreted poly =
match Sum.only_elt poly with
| Some (mono, _) -> (
match Prod.classify mono with
| Zero2 -> false
| One2 (_, 1) -> false
| _ -> true )
| None -> false
let get_const poly =
match Sum.classify poly with
| Zero2 -> Some Q.zero
| One2 (mono, coeff) when Mono.equal_one mono -> Some coeff
| _ -> None
let get_mono poly =
match Sum.only_elt poly with
| Some (mono, coeff) when Q.equal Q.one coeff -> Some mono
| _ -> None
(** Project out the term embedded into a polynomial, if possible *)
let get_trm poly =
match get_mono poly with
| Some mono -> Mono.get_trm mono
| None -> None
end
include S0
module Embed
(Var : Var_intf.S)
(Trm : TRM with type t = Trm0.t with type var := Var.t)
(Embed : EMBEDDING with type trm := Trm0.t with type t := t) =
struct
module Mono = struct
include Mono
let pp = ppx Trm.pp
let vars p = Iter.flat_map ~f:Trm.vars (trms p)
let fv p = Var.Set.of_iter (vars p)
end
include Poly
include S0
(** hide S0.trm and S0.trms that ignore the embedding, shadowed below *)
let[@warning "-32"] trm, trms = ((), ())
let pp = ppx Trm.pp
(** Embed a polynomial into a term *)
let trm_of_poly = Embed.to_trm
(** Embed a monomial into a term, flattening if possible *)
let trm_of_mono mono =
match Mono.get_trm mono with
| Some trm -> trm
| None -> trm_of_poly (Sum.of_ mono Q.one)
(* traverse *)
let vars poly = Iter.flat_map ~f:Trm.vars (S0.trms poly)
let monos poly =
Iter.from_iter (fun f ->
Sum.iter poly ~f:(fun mono _ ->
if not (Mono.equal_one mono) then f mono ) )
let trms poly =
match get_mono poly with
| Some mono -> Mono.trms mono
| None -> Iter.map ~f:trm_of_mono (monos poly)
let mono_invariant mono =
mono_invariant mono ;
let@ () = Invariant.invariant [%here] mono [%sexp_of: Mono.t] in
Prod.iter mono ~f:(fun base _ ->
match Embed.get_arith base with
| None -> ()
| Some poly -> (
match Sum.classify poly with
| Many2 -> ()
| Zero2 | One2 _ ->
(* polynomial factors are not constant or singleton, which
should have been flattened into the parent monomial *)
assert false ) ) ;
match Mono.get_trm mono with
| None -> ()
| Some trm -> (
match Embed.get_arith trm with
| None -> ()
| Some _ ->
(* singleton monomials are not polynomials, which should have
been flattened into the parent polynomial *)
assert false )
let invariant poly =
invariant poly ;
let@ () = Invariant.invariant [%here] poly [%sexp_of: t] in
Sum.iter poly ~f:(fun mono _ -> mono_invariant mono)
(** Terms of a polynomial: product of a coefficient and a monomial *)
module CM = struct
type t = Q.t * Prod.t
let mul (c1, m1) (c2, m2) = (Q.mul c1 c2, Mono.mul m1 m2)
(** Monomials [Mono.t] have [trm] indeterminates, which include, via
[get_arith], polynomials [t] over monomials themselves. To avoid
redundant representations, singleton polynomials are flattened. *)
let of_trm : ?power:int -> Trm.t -> t =
fun ?(power = 1) base ->
[%trace]
~call:(fun {pf} -> pf "@ %a^%i" Trm.pp base power)
~retn:(fun {pf} (c, m) -> pf "%a×%a" Q.pp c Mono.pp m)
@@ fun () ->
match Embed.get_arith base with
| Some poly -> (
match Sum.classify poly with
(* 0 ^ p₁ ==> 0 × 1 *)
| Zero2 -> (Q.zero, Mono.one)
(* (Σᵢ₌₁¹ cᵢ × Xᵢ) ^ p₁ ==> cᵢ^p₁ × Πⱼ₌₁¹ Xⱼ^pⱼ *)
| One2 (mono, coeff) -> (Q.pow coeff power, Mono.pow mono power)
(* (Σᵢ₌₁ⁿ cᵢ × Xᵢ) ^ p₁ ==> 1 × Πⱼ₌₁¹ (Σᵢ₌₁ⁿ cᵢ × Xᵢ)^pⱼ *)
| Many2 -> (Q.one, Mono.of_ base power) )
(* X₁ ^ p₁ ==> 1 × Πⱼ₌₁¹ Xⱼ^pⱼ *)
| None -> (Q.one, Mono.of_ base power)
(** Polynomials [t] have [trm] indeterminates, which, via [get_arith],
include polynomials themselves. To avoid redundant
representations, singleton monomials are flattened. Also, constant
multiplication is not interpreted in [Prod], so constant
polynomials are multiplied by their coefficients directly. *)
let to_poly : t -> Poly.t =
fun (coeff, mono) ->
[%trace]
~call:(fun {pf} -> pf "@ %a×%a" Q.pp coeff Mono.pp mono)
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
( match Mono.get_trm mono with
| Some trm -> (
match Embed.get_arith trm with
(* c × (Σᵢ₌₁ⁿ cᵢ × Xᵢ) ==> Σᵢ₌₁ⁿ c×cᵢ × Xᵢ *)
| Some poly -> mulc coeff poly
(* c₁ × X₁ ==> Σᵢ₌₁¹ cᵢ × Xᵢ *)
| None -> Sum.of_ mono coeff )
(* c₁ × (Πⱼ₌₁ᵐ X₁ⱼ^p₁ⱼ) ==> Σᵢ₌₁¹ cᵢ × (Πⱼ₌₁ᵐ Xᵢⱼ^pᵢⱼ) *)
| None -> Sum.of_ mono coeff )
|> check invariant
end
(** Embed a term into a polynomial, by projecting a polynomial out of
the term if possible *)
let trm trm =
( match Embed.get_arith trm with
| Some poly -> poly
| None -> S0.trm trm )
|> check (fun poly ->
assert (equal poly (CM.to_poly (CM.of_trm trm))) )
(* constructors over indeterminates *)
let mul e1 e2 = CM.to_poly (CM.mul (CM.of_trm e1) (CM.of_trm e2))
let div n d =
CM.to_poly (CM.mul (CM.of_trm n) (CM.of_trm d ~power:(-1)))
let pow base power = CM.to_poly (CM.of_trm base ~power)
(** map over [trms] *)
let map poly ~f =
[%trace]
~call:(fun {pf} -> pf "@ %a" pp poly)
~retn:(fun {pf} -> pf "%a" pp)
@@ fun () ->
( match get_mono poly with
| Some mono ->
let mono', (coeff, mono_delta) =
Prod.fold mono
(mono, (Q.one, Mono.one))
~f:(fun base power (mono', delta) ->
let base' = f base in
if base' == base then (mono', delta)
else
( Prod.remove base mono'
, CM.mul delta (CM.of_trm ~power base') ) )
in
if mono' == mono then poly
else CM.to_poly (coeff, Mono.mul mono' mono_delta)
| None ->
let poly', delta =
Sum.fold poly (poly, Sum.empty)
~f:(fun mono coeff (poly', delta) ->
if Mono.equal_one mono then (poly', delta)
else
let e = trm_of_mono mono in
let e' = f e in
if e == e' then (poly', delta)
else
( Sum.remove mono poly'
, Sum.union delta (mulc coeff (trm e')) ) )
in
Sum.union poly' delta )
|> check invariant
(* solve *)
let exists_fv_in vs poly =
Iter.exists ~f:(fun v -> Var.Set.mem v vs) (vars poly)
(** [solve_for_mono r c m p] solves [0 = r + (c×m) + p] as [m = q]
([Some (m, q)]) such that [r + (c×m) + p = m - q] *)
let solve_for_mono rejected_poly coeff mono poly =
[%trace]
~call:(fun {pf} ->
pf "@ 0 = %a + (%a×%a) + %a" pp rejected_poly Q.pp coeff Mono.pp
mono pp poly )
~retn:(fun {pf} s ->
pf "%a"
(Option.pp "%a" (fun fs (v, q) ->
Format.fprintf fs "%a ↦ %a" pp v pp q ))
s )
@@ fun () ->
if Mono.equal_one mono || exists_fv_in (Mono.fv mono) poly then None
else
Some
( Sum.of_ mono Q.one
, mulc (Q.inv (Q.neg coeff)) (Sum.union rejected_poly poly) )
(** [solve_poly r p] solves [0 = r + p] as [m = q] ([Some (m, q)]) such
that [r + p = m - q] *)
let rec solve_poly rejected poly =
[%trace]
~call:(fun {pf} -> pf "@ 0 = (%a) + (%a)" pp rejected pp poly)
~retn:(fun {pf} s ->
pf "%a"
(Option.pp "%a" (fun fs (v, q) ->
Format.fprintf fs "%a ↦ %a" pp v pp q ))
s )
@@ fun () ->
let* mono, coeff, poly = Sum.pop_min_elt poly in
match solve_for_mono rejected coeff mono poly with
| Some _ as soln -> soln
| None -> solve_poly (Sum.add mono coeff rejected) poly
(** solve [0 = e] *)
let solve_zero_eq ?for_ e =
[%trace]
~call:(fun {pf} ->
pf "@ 0 = %a%a" Trm.pp e (Option.pp " for %a" Trm.pp) for_ )
~retn:(fun {pf} s ->
pf "%a"
(Option.pp "%a" (fun fs (c, r) ->
Format.fprintf fs "%a ↦ %a" pp c pp r ))
s ;
match (for_, s) with
| Some f, Some (c, _) -> assert (equal (trm f) c)
| _ -> () )
@@ fun () ->
let a = trm e in
match for_ with
| None -> solve_poly Sum.empty a
| Some for_ -> (
let for_poly = trm for_ in
match get_mono for_poly with
| Some m ->
let c, p = Sum.find_and_remove m a in
let* c = c in
solve_for_mono Sum.empty c m p
| _ -> None )
end
[@@inline]
end
[@@inline]