You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

2172 lines
74 KiB

"""Integration method that emulates by-hand techniques.
This module also provides functionality to get the steps used to evaluate a
particular integral, in the ``integral_steps`` function. This will return
nested ``Rule`` s representing the integration rules used.
Each ``Rule`` class represents a (maybe parametrized) integration rule, e.g.
``SinRule`` for integrating ``sin(x)`` and ``ReciprocalSqrtQuadraticRule``
for integrating ``1/sqrt(a+b*x+c*x**2)``. The ``eval`` method returns the
integration result.
The ``manualintegrate`` function computes the integral by calling ``eval``
on the rule returned by ``integral_steps``.
The integrator can be extended with new heuristics and evaluation
techniques. To do so, extend the ``Rule`` class, implement ``eval`` method,
then write a function that accepts an ``IntegralInfo`` object and returns
either a ``Rule`` instance or ``None``. If the new technique requires a new
match, add the key and call to the antiderivative function to integral_steps.
To enable simple substitutions, add the match to find_substitutions.
"""
from __future__ import annotations
from typing import NamedTuple, Type, Callable, Sequence
from abc import ABC, abstractmethod
from dataclasses import dataclass
from collections import defaultdict
from collections.abc import Mapping
from sympy.core.add import Add
from sympy.core.cache import cacheit
from sympy.core.containers import Dict
from sympy.core.expr import Expr
from sympy.core.function import Derivative
from sympy.core.logic import fuzzy_not
from sympy.core.mul import Mul
from sympy.core.numbers import Integer, Number, E
from sympy.core.power import Pow
from sympy.core.relational import Eq, Ne, Boolean
from sympy.core.singleton import S
from sympy.core.symbol import Dummy, Symbol, Wild
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.exponential import exp, log
from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, csch,
cosh, coth, sech, sinh, tanh, asinh)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (TrigonometricFunction,
cos, sin, tan, cot, csc, sec, acos, asin, atan, acot, acsc, asec)
from sympy.functions.special.delta_functions import Heaviside, DiracDelta
from sympy.functions.special.error_functions import (erf, erfi, fresnelc,
fresnels, Ci, Chi, Si, Shi, Ei, li)
from sympy.functions.special.gamma_functions import uppergamma
from sympy.functions.special.elliptic_integrals import elliptic_e, elliptic_f
from sympy.functions.special.polynomials import (chebyshevt, chebyshevu,
legendre, hermite, laguerre, assoc_laguerre, gegenbauer, jacobi,
OrthogonalPolynomial)
from sympy.functions.special.zeta_functions import polylog
from .integrals import Integral
from sympy.logic.boolalg import And
from sympy.ntheory.factor_ import primefactors
from sympy.polys.polytools import degree, lcm_list, gcd_list, Poly
from sympy.simplify.radsimp import fraction
from sympy.simplify.simplify import simplify
from sympy.solvers.solvers import solve
from sympy.strategies.core import switch, do_one, null_safe, condition
from sympy.utilities.iterables import iterable
from sympy.utilities.misc import debug
@dataclass
class Rule(ABC):
integrand: Expr
variable: Symbol
@abstractmethod
def eval(self) -> Expr:
pass
@abstractmethod
def contains_dont_know(self) -> bool:
pass
@dataclass
class AtomicRule(Rule, ABC):
"""A simple rule that does not depend on other rules"""
def contains_dont_know(self) -> bool:
return False
@dataclass
class ConstantRule(AtomicRule):
"""integrate(a, x) -> a*x"""
def eval(self) -> Expr:
return self.integrand * self.variable
@dataclass
class ConstantTimesRule(Rule):
"""integrate(a*f(x), x) -> a*integrate(f(x), x)"""
constant: Expr
other: Expr
substep: Rule
def eval(self) -> Expr:
return self.constant * self.substep.eval()
def contains_dont_know(self) -> bool:
return self.substep.contains_dont_know()
@dataclass
class PowerRule(AtomicRule):
"""integrate(x**a, x)"""
base: Expr
exp: Expr
def eval(self) -> Expr:
return Piecewise(
((self.base**(self.exp + 1))/(self.exp + 1), Ne(self.exp, -1)),
(log(self.base), True),
)
@dataclass
class NestedPowRule(AtomicRule):
"""integrate((x**a)**b, x)"""
base: Expr
exp: Expr
def eval(self) -> Expr:
m = self.base * self.integrand
return Piecewise((m / (self.exp + 1), Ne(self.exp, -1)),
(m * log(self.base), True))
@dataclass
class AddRule(Rule):
"""integrate(f(x) + g(x), x) -> integrate(f(x), x) + integrate(g(x), x)"""
substeps: list[Rule]
def eval(self) -> Expr:
return Add(*(substep.eval() for substep in self.substeps))
def contains_dont_know(self) -> bool:
return any(substep.contains_dont_know() for substep in self.substeps)
@dataclass
class URule(Rule):
"""integrate(f(g(x))*g'(x), x) -> integrate(f(u), u), u = g(x)"""
u_var: Symbol
u_func: Expr
substep: Rule
def eval(self) -> Expr:
result = self.substep.eval()
if self.u_func.is_Pow:
base, exp_ = self.u_func.as_base_exp()
if exp_ == -1:
# avoid needless -log(1/x) from substitution
result = result.subs(log(self.u_var), -log(base))
return result.subs(self.u_var, self.u_func)
def contains_dont_know(self) -> bool:
return self.substep.contains_dont_know()
@dataclass
class PartsRule(Rule):
"""integrate(u(x)*v'(x), x) -> u(x)*v(x) - integrate(u'(x)*v(x), x)"""
u: Symbol
dv: Expr
v_step: Rule
second_step: Rule | None # None when is a substep of CyclicPartsRule
def eval(self) -> Expr:
assert self.second_step is not None
v = self.v_step.eval()
return self.u * v - self.second_step.eval()
def contains_dont_know(self) -> bool:
return self.v_step.contains_dont_know() or (
self.second_step is not None and self.second_step.contains_dont_know())
@dataclass
class CyclicPartsRule(Rule):
"""Apply PartsRule multiple times to integrate exp(x)*sin(x)"""
parts_rules: list[PartsRule]
coefficient: Expr
def eval(self) -> Expr:
result = []
sign = 1
for rule in self.parts_rules:
result.append(sign * rule.u * rule.v_step.eval())
sign *= -1
return Add(*result) / (1 - self.coefficient)
def contains_dont_know(self) -> bool:
return any(substep.contains_dont_know() for substep in self.parts_rules)
@dataclass
class TrigRule(AtomicRule, ABC):
pass
@dataclass
class SinRule(TrigRule):
"""integrate(sin(x), x) -> -cos(x)"""
def eval(self) -> Expr:
return -cos(self.variable)
@dataclass
class CosRule(TrigRule):
"""integrate(cos(x), x) -> sin(x)"""
def eval(self) -> Expr:
return sin(self.variable)
@dataclass
class SecTanRule(TrigRule):
"""integrate(sec(x)*tan(x), x) -> sec(x)"""
def eval(self) -> Expr:
return sec(self.variable)
@dataclass
class CscCotRule(TrigRule):
"""integrate(csc(x)*cot(x), x) -> -csc(x)"""
def eval(self) -> Expr:
return -csc(self.variable)
@dataclass
class Sec2Rule(TrigRule):
"""integrate(sec(x)**2, x) -> tan(x)"""
def eval(self) -> Expr:
return tan(self.variable)
@dataclass
class Csc2Rule(TrigRule):
"""integrate(csc(x)**2, x) -> -cot(x)"""
def eval(self) -> Expr:
return -cot(self.variable)
@dataclass
class HyperbolicRule(AtomicRule, ABC):
pass
@dataclass
class SinhRule(HyperbolicRule):
"""integrate(sinh(x), x) -> cosh(x)"""
def eval(self) -> Expr:
return cosh(self.variable)
@dataclass
class CoshRule(HyperbolicRule):
"""integrate(cosh(x), x) -> sinh(x)"""
def eval(self):
return sinh(self.variable)
@dataclass
class ExpRule(AtomicRule):
"""integrate(a**x, x) -> a**x/ln(a)"""
base: Expr
exp: Expr
def eval(self) -> Expr:
return self.integrand / log(self.base)
@dataclass
class ReciprocalRule(AtomicRule):
"""integrate(1/x, x) -> ln(x)"""
base: Expr
def eval(self) -> Expr:
return log(self.base)
@dataclass
class ArcsinRule(AtomicRule):
"""integrate(1/sqrt(1-x**2), x) -> asin(x)"""
def eval(self) -> Expr:
return asin(self.variable)
@dataclass
class ArcsinhRule(AtomicRule):
"""integrate(1/sqrt(1+x**2), x) -> asin(x)"""
def eval(self) -> Expr:
return asinh(self.variable)
@dataclass
class ReciprocalSqrtQuadraticRule(AtomicRule):
"""integrate(1/sqrt(a+b*x+c*x**2), x) -> log(2*sqrt(c)*sqrt(a+b*x+c*x**2)+b+2*c*x)/sqrt(c)"""
a: Expr
b: Expr
c: Expr
def eval(self) -> Expr:
a, b, c, x = self.a, self.b, self.c, self.variable
return log(2*sqrt(c)*sqrt(a+b*x+c*x**2)+b+2*c*x)/sqrt(c)
@dataclass
class SqrtQuadraticDenomRule(AtomicRule):
"""integrate(poly(x)/sqrt(a+b*x+c*x**2), x)"""
a: Expr
b: Expr
c: Expr
coeffs: list[Expr]
def eval(self) -> Expr:
a, b, c, coeffs, x = self.a, self.b, self.c, self.coeffs.copy(), self.variable
# Integrate poly/sqrt(a+b*x+c*x**2) using recursion.
# coeffs are coefficients of the polynomial.
# Let I_n = x**n/sqrt(a+b*x+c*x**2), then
# I_n = A * x**(n-1)*sqrt(a+b*x+c*x**2) - B * I_{n-1} - C * I_{n-2}
# where A = 1/(n*c), B = (2*n-1)*b/(2*n*c), C = (n-1)*a/(n*c)
# See https://github.com/sympy/sympy/pull/23608 for proof.
result_coeffs = []
coeffs = coeffs.copy()
for i in range(len(coeffs)-2):
n = len(coeffs)-1-i
coeff = coeffs[i]/(c*n)
result_coeffs.append(coeff)
coeffs[i+1] -= (2*n-1)*b/2*coeff
coeffs[i+2] -= (n-1)*a*coeff
d, e = coeffs[-1], coeffs[-2]
s = sqrt(a+b*x+c*x**2)
constant = d-b*e/(2*c)
if constant == 0:
I0 = 0
else:
step = inverse_trig_rule(IntegralInfo(1/s, x), degenerate=False)
I0 = constant*step.eval()
return Add(*(result_coeffs[i]*x**(len(coeffs)-2-i)
for i in range(len(result_coeffs))), e/c)*s + I0
@dataclass
class SqrtQuadraticRule(AtomicRule):
"""integrate(sqrt(a+b*x+c*x**2), x)"""
a: Expr
b: Expr
c: Expr
def eval(self) -> Expr:
step = sqrt_quadratic_rule(IntegralInfo(self.integrand, self.variable), degenerate=False)
return step.eval()
@dataclass
class AlternativeRule(Rule):
"""Multiple ways to do integration."""
alternatives: list[Rule]
def eval(self) -> Expr:
return self.alternatives[0].eval()
def contains_dont_know(self) -> bool:
return any(substep.contains_dont_know() for substep in self.alternatives)
@dataclass
class DontKnowRule(Rule):
"""Leave the integral as is."""
def eval(self) -> Expr:
return Integral(self.integrand, self.variable)
def contains_dont_know(self) -> bool:
return True
@dataclass
class DerivativeRule(AtomicRule):
"""integrate(f'(x), x) -> f(x)"""
def eval(self) -> Expr:
assert isinstance(self.integrand, Derivative)
variable_count = list(self.integrand.variable_count)
for i, (var, count) in enumerate(variable_count):
if var == self.variable:
variable_count[i] = (var, count - 1)
break
return Derivative(self.integrand.expr, *variable_count)
@dataclass
class RewriteRule(Rule):
"""Rewrite integrand to another form that is easier to handle."""
rewritten: Expr
substep: Rule
def eval(self) -> Expr:
return self.substep.eval()
def contains_dont_know(self) -> bool:
return self.substep.contains_dont_know()
@dataclass
class CompleteSquareRule(RewriteRule):
"""Rewrite a+b*x+c*x**2 to a-b**2/(4*c) + c*(x+b/(2*c))**2"""
pass
@dataclass
class PiecewiseRule(Rule):
subfunctions: Sequence[tuple[Rule, bool | Boolean]]
def eval(self) -> Expr:
return Piecewise(*[(substep.eval(), cond)
for substep, cond in self.subfunctions])
def contains_dont_know(self) -> bool:
return any(substep.contains_dont_know() for substep, _ in self.subfunctions)
@dataclass
class HeavisideRule(Rule):
harg: Expr
ibnd: Expr
substep: Rule
def eval(self) -> Expr:
# If we are integrating over x and the integrand has the form
# Heaviside(m*x+b)*g(x) == Heaviside(harg)*g(symbol)
# then there needs to be continuity at -b/m == ibnd,
# so we subtract the appropriate term.
result = self.substep.eval()
return Heaviside(self.harg) * (result - result.subs(self.variable, self.ibnd))
def contains_dont_know(self) -> bool:
return self.substep.contains_dont_know()
@dataclass
class DiracDeltaRule(AtomicRule):
n: Expr
a: Expr
b: Expr
def eval(self) -> Expr:
n, a, b, x = self.n, self.a, self.b, self.variable
if n == 0:
return Heaviside(a+b*x)/b
return DiracDelta(a+b*x, n-1)/b
@dataclass
class TrigSubstitutionRule(Rule):
theta: Expr
func: Expr
rewritten: Expr
substep: Rule
restriction: bool | Boolean
def eval(self) -> Expr:
theta, func, x = self.theta, self.func, self.variable
func = func.subs(sec(theta), 1/cos(theta))
func = func.subs(csc(theta), 1/sin(theta))
func = func.subs(cot(theta), 1/tan(theta))
trig_function = list(func.find(TrigonometricFunction))
assert len(trig_function) == 1
trig_function = trig_function[0]
relation = solve(x - func, trig_function)
assert len(relation) == 1
numer, denom = fraction(relation[0])
if isinstance(trig_function, sin):
opposite = numer
hypotenuse = denom
adjacent = sqrt(denom**2 - numer**2)
inverse = asin(relation[0])
elif isinstance(trig_function, cos):
adjacent = numer
hypotenuse = denom
opposite = sqrt(denom**2 - numer**2)
inverse = acos(relation[0])
else: # tan
opposite = numer
adjacent = denom
hypotenuse = sqrt(denom**2 + numer**2)
inverse = atan(relation[0])
substitution = [
(sin(theta), opposite/hypotenuse),
(cos(theta), adjacent/hypotenuse),
(tan(theta), opposite/adjacent),
(theta, inverse)
]
return Piecewise(
(self.substep.eval().subs(substitution).trigsimp(), self.restriction)
)
def contains_dont_know(self) -> bool:
return self.substep.contains_dont_know()
@dataclass
class ArctanRule(AtomicRule):
"""integrate(a/(b*x**2+c), x) -> a/b / sqrt(c/b) * atan(x/sqrt(c/b))"""
a: Expr
b: Expr
c: Expr
def eval(self) -> Expr:
a, b, c, x = self.a, self.b, self.c, self.variable
return a/b / sqrt(c/b) * atan(x/sqrt(c/b))
@dataclass
class OrthogonalPolyRule(AtomicRule, ABC):
n: Expr
@dataclass
class JacobiRule(OrthogonalPolyRule):
a: Expr
b: Expr
def eval(self) -> Expr:
n, a, b, x = self.n, self.a, self.b, self.variable
return Piecewise(
(2*jacobi(n + 1, a - 1, b - 1, x)/(n + a + b), Ne(n + a + b, 0)),
(x, Eq(n, 0)),
((a + b + 2)*x**2/4 + (a - b)*x/2, Eq(n, 1)))
@dataclass
class GegenbauerRule(OrthogonalPolyRule):
a: Expr
def eval(self) -> Expr:
n, a, x = self.n, self.a, self.variable
return Piecewise(
(gegenbauer(n + 1, a - 1, x)/(2*(a - 1)), Ne(a, 1)),
(chebyshevt(n + 1, x)/(n + 1), Ne(n, -1)),
(S.Zero, True))
@dataclass
class ChebyshevTRule(OrthogonalPolyRule):
def eval(self) -> Expr:
n, x = self.n, self.variable
return Piecewise(
((chebyshevt(n + 1, x)/(n + 1) -
chebyshevt(n - 1, x)/(n - 1))/2, Ne(Abs(n), 1)),
(x**2/2, True))
@dataclass
class ChebyshevURule(OrthogonalPolyRule):
def eval(self) -> Expr:
n, x = self.n, self.variable
return Piecewise(
(chebyshevt(n + 1, x)/(n + 1), Ne(n, -1)),
(S.Zero, True))
@dataclass
class LegendreRule(OrthogonalPolyRule):
def eval(self) -> Expr:
n, x = self.n, self.variable
return(legendre(n + 1, x) - legendre(n - 1, x))/(2*n + 1)
@dataclass
class HermiteRule(OrthogonalPolyRule):
def eval(self) -> Expr:
n, x = self.n, self.variable
return hermite(n + 1, x)/(2*(n + 1))
@dataclass
class LaguerreRule(OrthogonalPolyRule):
def eval(self) -> Expr:
n, x = self.n, self.variable
return laguerre(n, x) - laguerre(n + 1, x)
@dataclass
class AssocLaguerreRule(OrthogonalPolyRule):
a: Expr
def eval(self) -> Expr:
return -assoc_laguerre(self.n + 1, self.a - 1, self.variable)
@dataclass
class IRule(AtomicRule, ABC):
a: Expr
b: Expr
@dataclass
class CiRule(IRule):
def eval(self) -> Expr:
a, b, x = self.a, self.b, self.variable
return cos(b)*Ci(a*x) - sin(b)*Si(a*x)
@dataclass
class ChiRule(IRule):
def eval(self) -> Expr:
a, b, x = self.a, self.b, self.variable
return cosh(b)*Chi(a*x) + sinh(b)*Shi(a*x)
@dataclass
class EiRule(IRule):
def eval(self) -> Expr:
a, b, x = self.a, self.b, self.variable
return exp(b)*Ei(a*x)
@dataclass
class SiRule(IRule):
def eval(self) -> Expr:
a, b, x = self.a, self.b, self.variable
return sin(b)*Ci(a*x) + cos(b)*Si(a*x)
@dataclass
class ShiRule(IRule):
def eval(self) -> Expr:
a, b, x = self.a, self.b, self.variable
return sinh(b)*Chi(a*x) + cosh(b)*Shi(a*x)
@dataclass
class LiRule(IRule):
def eval(self) -> Expr:
a, b, x = self.a, self.b, self.variable
return li(a*x + b)/a
@dataclass
class ErfRule(AtomicRule):
a: Expr
b: Expr
c: Expr
def eval(self) -> Expr:
a, b, c, x = self.a, self.b, self.c, self.variable
if a.is_extended_real:
return Piecewise(
(sqrt(S.Pi/(-a))/2 * exp(c - b**2/(4*a)) *
erf((-2*a*x - b)/(2*sqrt(-a))), a < 0),
(sqrt(S.Pi/a)/2 * exp(c - b**2/(4*a)) *
erfi((2*a*x + b)/(2*sqrt(a))), True))
return sqrt(S.Pi/a)/2 * exp(c - b**2/(4*a)) * \
erfi((2*a*x + b)/(2*sqrt(a)))
@dataclass
class FresnelCRule(AtomicRule):
a: Expr
b: Expr
c: Expr
def eval(self) -> Expr:
a, b, c, x = self.a, self.b, self.c, self.variable
return sqrt(S.Pi/(2*a)) * (
cos(b**2/(4*a) - c)*fresnelc((2*a*x + b)/sqrt(2*a*S.Pi)) +
sin(b**2/(4*a) - c)*fresnels((2*a*x + b)/sqrt(2*a*S.Pi)))
@dataclass
class FresnelSRule(AtomicRule):
a: Expr
b: Expr
c: Expr
def eval(self) -> Expr:
a, b, c, x = self.a, self.b, self.c, self.variable
return sqrt(S.Pi/(2*a)) * (
cos(b**2/(4*a) - c)*fresnels((2*a*x + b)/sqrt(2*a*S.Pi)) -
sin(b**2/(4*a) - c)*fresnelc((2*a*x + b)/sqrt(2*a*S.Pi)))
@dataclass
class PolylogRule(AtomicRule):
a: Expr
b: Expr
def eval(self) -> Expr:
return polylog(self.b + 1, self.a * self.variable)
@dataclass
class UpperGammaRule(AtomicRule):
a: Expr
e: Expr
def eval(self) -> Expr:
a, e, x = self.a, self.e, self.variable
return x**e * (-a*x)**(-e) * uppergamma(e + 1, -a*x)/a
@dataclass
class EllipticFRule(AtomicRule):
a: Expr
d: Expr
def eval(self) -> Expr:
return elliptic_f(self.variable, self.d/self.a)/sqrt(self.a)
@dataclass
class EllipticERule(AtomicRule):
a: Expr
d: Expr
def eval(self) -> Expr:
return elliptic_e(self.variable, self.d/self.a)*sqrt(self.a)
class IntegralInfo(NamedTuple):
integrand: Expr
symbol: Symbol
def manual_diff(f, symbol):
"""Derivative of f in form expected by find_substitutions
SymPy's derivatives for some trig functions (like cot) are not in a form
that works well with finding substitutions; this replaces the
derivatives for those particular forms with something that works better.
"""
if f.args:
arg = f.args[0]
if isinstance(f, tan):
return arg.diff(symbol) * sec(arg)**2
elif isinstance(f, cot):
return -arg.diff(symbol) * csc(arg)**2
elif isinstance(f, sec):
return arg.diff(symbol) * sec(arg) * tan(arg)
elif isinstance(f, csc):
return -arg.diff(symbol) * csc(arg) * cot(arg)
elif isinstance(f, Add):
return sum([manual_diff(arg, symbol) for arg in f.args])
elif isinstance(f, Mul):
if len(f.args) == 2 and isinstance(f.args[0], Number):
return f.args[0] * manual_diff(f.args[1], symbol)
return f.diff(symbol)
def manual_subs(expr, *args):
"""
A wrapper for `expr.subs(*args)` with additional logic for substitution
of invertible functions.
"""
if len(args) == 1:
sequence = args[0]
if isinstance(sequence, (Dict, Mapping)):
sequence = sequence.items()
elif not iterable(sequence):
raise ValueError("Expected an iterable of (old, new) pairs")
elif len(args) == 2:
sequence = [args]
else:
raise ValueError("subs accepts either 1 or 2 arguments")
new_subs = []
for old, new in sequence:
if isinstance(old, log):
# If log(x) = y, then exp(a*log(x)) = exp(a*y)
# that is, x**a = exp(a*y). Replace nontrivial powers of x
# before subs turns them into `exp(y)**a`, but
# do not replace x itself yet, to avoid `log(exp(y))`.
x0 = old.args[0]
expr = expr.replace(lambda x: x.is_Pow and x.base == x0,
lambda x: exp(x.exp*new))
new_subs.append((x0, exp(new)))
return expr.subs(list(sequence) + new_subs)
# Method based on that on SIN, described in "Symbolic Integration: The
# Stormy Decade"
inverse_trig_functions = (atan, asin, acos, acot, acsc, asec)
def find_substitutions(integrand, symbol, u_var):
results = []
def test_subterm(u, u_diff):
if u_diff == 0:
return False
substituted = integrand / u_diff
debug("substituted: {}, u: {}, u_var: {}".format(substituted, u, u_var))
substituted = manual_subs(substituted, u, u_var).cancel()
if substituted.has_free(symbol):
return False
# avoid increasing the degree of a rational function
if integrand.is_rational_function(symbol) and substituted.is_rational_function(u_var):
deg_before = max([degree(t, symbol) for t in integrand.as_numer_denom()])
deg_after = max([degree(t, u_var) for t in substituted.as_numer_denom()])
if deg_after > deg_before:
return False
return substituted.as_independent(u_var, as_Add=False)
def exp_subterms(term: Expr):
linear_coeffs = []
terms = []
n = Wild('n', properties=[lambda n: n.is_Integer])
for exp_ in term.find(exp):
arg = exp_.args[0]
if symbol not in arg.free_symbols:
continue
match = arg.match(n*symbol)
if match:
linear_coeffs.append(match[n])
else:
terms.append(exp_)
if linear_coeffs:
terms.append(exp(gcd_list(linear_coeffs)*symbol))
return terms
def possible_subterms(term):
if isinstance(term, (TrigonometricFunction, HyperbolicFunction,
*inverse_trig_functions,
exp, log, Heaviside)):
return [term.args[0]]
elif isinstance(term, (chebyshevt, chebyshevu,
legendre, hermite, laguerre)):
return [term.args[1]]
elif isinstance(term, (gegenbauer, assoc_laguerre)):
return [term.args[2]]
elif isinstance(term, jacobi):
return [term.args[3]]
elif isinstance(term, Mul):
r = []
for u in term.args:
r.append(u)
r.extend(possible_subterms(u))
return r
elif isinstance(term, Pow):
r = [arg for arg in term.args if arg.has(symbol)]
if term.exp.is_Integer:
r.extend([term.base**d for d in primefactors(term.exp)
if 1 < d < abs(term.args[1])])
if term.base.is_Add:
r.extend([t for t in possible_subterms(term.base)
if t.is_Pow])
return r
elif isinstance(term, Add):
r = []
for arg in term.args:
r.append(arg)
r.extend(possible_subterms(arg))
return r
return []
for u in list(dict.fromkeys(possible_subterms(integrand) + exp_subterms(integrand))):
if u == symbol:
continue
u_diff = manual_diff(u, symbol)
new_integrand = test_subterm(u, u_diff)
if new_integrand is not False:
constant, new_integrand = new_integrand
if new_integrand == integrand.subs(symbol, u_var):
continue
substitution = (u, constant, new_integrand)
if substitution not in results:
results.append(substitution)
return results
def rewriter(condition, rewrite):
"""Strategy that rewrites an integrand."""
def _rewriter(integral):
integrand, symbol = integral
debug("Integral: {} is rewritten with {} on symbol: {}".format(integrand, rewrite, symbol))
if condition(*integral):
rewritten = rewrite(*integral)
if rewritten != integrand:
substep = integral_steps(rewritten, symbol)
if not isinstance(substep, DontKnowRule) and substep:
return RewriteRule(integrand, symbol, rewritten, substep)
return _rewriter
def proxy_rewriter(condition, rewrite):
"""Strategy that rewrites an integrand based on some other criteria."""
def _proxy_rewriter(criteria):
criteria, integral = criteria
integrand, symbol = integral
debug("Integral: {} is rewritten with {} on symbol: {} and criteria: {}".format(integrand, rewrite, symbol, criteria))
args = criteria + list(integral)
if condition(*args):
rewritten = rewrite(*args)
if rewritten != integrand:
return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol))
return _proxy_rewriter
def multiplexer(conditions):
"""Apply the rule that matches the condition, else None"""
def multiplexer_rl(expr):
for key, rule in conditions.items():
if key(expr):
return rule(expr)
return multiplexer_rl
def alternatives(*rules):
"""Strategy that makes an AlternativeRule out of multiple possible results."""
def _alternatives(integral):
alts = []
count = 0
debug("List of Alternative Rules")
for rule in rules:
count = count + 1
debug("Rule {}: {}".format(count, rule))
result = rule(integral)
if (result and not isinstance(result, DontKnowRule) and
result != integral and result not in alts):
alts.append(result)
if len(alts) == 1:
return alts[0]
elif alts:
doable = [rule for rule in alts if not rule.contains_dont_know()]
if doable:
return AlternativeRule(*integral, doable)
else:
return AlternativeRule(*integral, alts)
return _alternatives
def constant_rule(integral):
return ConstantRule(*integral)
def power_rule(integral):
integrand, symbol = integral
base, expt = integrand.as_base_exp()
if symbol not in expt.free_symbols and isinstance(base, Symbol):
if simplify(expt + 1) == 0:
return ReciprocalRule(integrand, symbol, base)
return PowerRule(integrand, symbol, base, expt)
elif symbol not in base.free_symbols and isinstance(expt, Symbol):
rule = ExpRule(integrand, symbol, base, expt)
if fuzzy_not(log(base).is_zero):
return rule
elif log(base).is_zero:
return ConstantRule(1, symbol)
return PiecewiseRule(integrand, symbol, [
(rule, Ne(log(base), 0)),
(ConstantRule(1, symbol), True)
])
def exp_rule(integral):
integrand, symbol = integral
if isinstance(integrand.args[0], Symbol):
return ExpRule(integrand, symbol, E, integrand.args[0])
def orthogonal_poly_rule(integral):
orthogonal_poly_classes = {
jacobi: JacobiRule,
gegenbauer: GegenbauerRule,
chebyshevt: ChebyshevTRule,
chebyshevu: ChebyshevURule,
legendre: LegendreRule,
hermite: HermiteRule,
laguerre: LaguerreRule,
assoc_laguerre: AssocLaguerreRule
}
orthogonal_poly_var_index = {
jacobi: 3,
gegenbauer: 2,
assoc_laguerre: 2
}
integrand, symbol = integral
for klass in orthogonal_poly_classes:
if isinstance(integrand, klass):
var_index = orthogonal_poly_var_index.get(klass, 1)
if (integrand.args[var_index] is symbol and not
any(v.has(symbol) for v in integrand.args[:var_index])):
return orthogonal_poly_classes[klass](integrand, symbol, *integrand.args[:var_index])
_special_function_patterns: list[tuple[Type, Expr, Callable | None, tuple]] = []
_wilds = []
_symbol = Dummy('x')
def special_function_rule(integral):
integrand, symbol = integral
if not _special_function_patterns:
a = Wild('a', exclude=[_symbol], properties=[lambda x: not x.is_zero])
b = Wild('b', exclude=[_symbol])
c = Wild('c', exclude=[_symbol])
d = Wild('d', exclude=[_symbol], properties=[lambda x: not x.is_zero])
e = Wild('e', exclude=[_symbol], properties=[
lambda x: not (x.is_nonnegative and x.is_integer)])
_wilds.extend((a, b, c, d, e))
# patterns consist of a SymPy class, a wildcard expr, an optional
# condition coded as a lambda (when Wild properties are not enough),
# followed by an applicable rule
linear_pattern = a*_symbol + b
quadratic_pattern = a*_symbol**2 + b*_symbol + c
_special_function_patterns.extend((
(Mul, exp(linear_pattern, evaluate=False)/_symbol, None, EiRule),
(Mul, cos(linear_pattern, evaluate=False)/_symbol, None, CiRule),
(Mul, cosh(linear_pattern, evaluate=False)/_symbol, None, ChiRule),
(Mul, sin(linear_pattern, evaluate=False)/_symbol, None, SiRule),
(Mul, sinh(linear_pattern, evaluate=False)/_symbol, None, ShiRule),
(Pow, 1/log(linear_pattern, evaluate=False), None, LiRule),
(exp, exp(quadratic_pattern, evaluate=False), None, ErfRule),
(sin, sin(quadratic_pattern, evaluate=False), None, FresnelSRule),
(cos, cos(quadratic_pattern, evaluate=False), None, FresnelCRule),
(Mul, _symbol**e*exp(a*_symbol, evaluate=False), None, UpperGammaRule),
(Mul, polylog(b, a*_symbol, evaluate=False)/_symbol, None, PolylogRule),
(Pow, 1/sqrt(a - d*sin(_symbol, evaluate=False)**2),
lambda a, d: a != d, EllipticFRule),
(Pow, sqrt(a - d*sin(_symbol, evaluate=False)**2),
lambda a, d: a != d, EllipticERule),
))
_integrand = integrand.subs(symbol, _symbol)
for type_, pattern, constraint, rule in _special_function_patterns:
if isinstance(_integrand, type_):
match = _integrand.match(pattern)
if match:
wild_vals = tuple(match.get(w) for w in _wilds
if match.get(w) is not None)
if constraint is None or constraint(*wild_vals):
return rule(integrand, symbol, *wild_vals)
def _add_degenerate_step(generic_cond, generic_step: Rule, degenerate_step: Rule | None) -> Rule:
if degenerate_step is None:
return generic_step
if isinstance(generic_step, PiecewiseRule):
subfunctions = [(substep, (cond & generic_cond).simplify())
for substep, cond in generic_step.subfunctions]
else:
subfunctions = [(generic_step, generic_cond)]
if isinstance(degenerate_step, PiecewiseRule):
subfunctions += degenerate_step.subfunctions
else:
subfunctions.append((degenerate_step, S.true))
return PiecewiseRule(generic_step.integrand, generic_step.variable, subfunctions)
def nested_pow_rule(integral: IntegralInfo):
# nested (c*(a+b*x)**d)**e
integrand, x = integral
a_ = Wild('a', exclude=[x])
b_ = Wild('b', exclude=[x, 0])
pattern = a_+b_*x
generic_cond = S.true
class NoMatch(Exception):
pass
def _get_base_exp(expr: Expr) -> tuple[Expr, Expr]:
if not expr.has_free(x):
return S.One, S.Zero
if expr.is_Mul:
_, terms = expr.as_coeff_mul()
if not terms:
return S.One, S.Zero
results = [_get_base_exp(term) for term in terms]
bases = {b for b, _ in results}
bases.discard(S.One)
if len(bases) == 1:
return bases.pop(), Add(*(e for _, e in results))
raise NoMatch
if expr.is_Pow:
b, e = expr.base, expr.exp # type: ignore
if e.has_free(x):
raise NoMatch
base_, sub_exp = _get_base_exp(b)
return base_, sub_exp * e
match = expr.match(pattern)
if match:
a, b = match[a_], match[b_]
base_ = x + a/b
nonlocal generic_cond
generic_cond = Ne(b, 0)
return base_, S.One
raise NoMatch
try:
base, exp_ = _get_base_exp(integrand)
except NoMatch:
return
if generic_cond is S.true:
degenerate_step = None
else:
# equivalent with subs(b, 0) but no need to find b
degenerate_step = ConstantRule(integrand.subs(x, 0), x)
generic_step = NestedPowRule(integrand, x, base, exp_)
return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
def inverse_trig_rule(integral: IntegralInfo, degenerate=True):
"""
Set degenerate=False on recursive call where coefficient of quadratic term
is assumed non-zero.
"""
integrand, symbol = integral
base, exp = integrand.as_base_exp()
a = Wild('a', exclude=[symbol])
b = Wild('b', exclude=[symbol])
c = Wild('c', exclude=[symbol, 0])
match = base.match(a + b*symbol + c*symbol**2)
if not match:
return
def make_inverse_trig(RuleClass, a, sign_a, c, sign_c, h) -> Rule:
u_var = Dummy("u")
rewritten = 1/sqrt(sign_a*a + sign_c*c*(symbol-h)**2) # a>0, c>0
quadratic_base = sqrt(c/a)*(symbol-h)
constant = 1/sqrt(c)
u_func = None
if quadratic_base is not symbol:
u_func = quadratic_base
quadratic_base = u_var
standard_form = 1/sqrt(sign_a + sign_c*quadratic_base**2)
substep = RuleClass(standard_form, quadratic_base)
if constant != 1:
substep = ConstantTimesRule(constant*standard_form, symbol, constant, standard_form, substep)
if u_func is not None:
substep = URule(rewritten, symbol, u_var, u_func, substep)
if h != 0:
substep = CompleteSquareRule(integrand, symbol, rewritten, substep)
return substep
a, b, c = [match.get(i, S.Zero) for i in (a, b, c)]
generic_cond = Ne(c, 0)
if not degenerate or generic_cond is S.true:
degenerate_step = None
elif b.is_zero:
degenerate_step = ConstantRule(a ** exp, symbol)
else:
degenerate_step = sqrt_linear_rule(IntegralInfo((a + b * symbol) ** exp, symbol))
if simplify(2*exp + 1) == 0:
h, k = -b/(2*c), a - b**2/(4*c) # rewrite base to k + c*(symbol-h)**2
non_square_cond = Ne(k, 0)
square_step = None
if non_square_cond is not S.true:
square_step = NestedPowRule(1/sqrt(c*(symbol-h)**2), symbol, symbol-h, S.NegativeOne)
if non_square_cond is S.false:
return square_step
generic_step = ReciprocalSqrtQuadraticRule(integrand, symbol, a, b, c)
step = _add_degenerate_step(non_square_cond, generic_step, square_step)
if k.is_real and c.is_real:
# list of ((rule, base_exp, a, sign_a, b, sign_b), condition)
rules = []
for args, cond in ( # don't apply ArccoshRule to x**2-1
((ArcsinRule, k, 1, -c, -1, h), And(k > 0, c < 0)), # 1-x**2
((ArcsinhRule, k, 1, c, 1, h), And(k > 0, c > 0)), # 1+x**2
):
if cond is S.true:
return make_inverse_trig(*args)
if cond is not S.false:
rules.append((make_inverse_trig(*args), cond))
if rules:
if not k.is_positive: # conditions are not thorough, need fall back rule
rules.append((generic_step, S.true))
step = PiecewiseRule(integrand, symbol, rules)
else:
step = generic_step
return _add_degenerate_step(generic_cond, step, degenerate_step)
if exp == S.Half:
step = SqrtQuadraticRule(integrand, symbol, a, b, c)
return _add_degenerate_step(generic_cond, step, degenerate_step)
def add_rule(integral):
integrand, symbol = integral
results = [integral_steps(g, symbol)
for g in integrand.as_ordered_terms()]
return None if None in results else AddRule(integrand, symbol, results)
def mul_rule(integral: IntegralInfo):
integrand, symbol = integral
# Constant times function case
coeff, f = integrand.as_independent(symbol)
if coeff != 1:
next_step = integral_steps(f, symbol)
if next_step is not None:
return ConstantTimesRule(integrand, symbol, coeff, f, next_step)
def _parts_rule(integrand, symbol) -> tuple[Expr, Expr, Expr, Expr, Rule] | None:
# LIATE rule:
# log, inverse trig, algebraic, trigonometric, exponential
def pull_out_algebraic(integrand):
integrand = integrand.cancel().together()
# iterating over Piecewise args would not work here
algebraic = ([] if isinstance(integrand, Piecewise) or not integrand.is_Mul
else [arg for arg in integrand.args if arg.is_algebraic_expr(symbol)])
if algebraic:
u = Mul(*algebraic)
dv = (integrand / u).cancel()
return u, dv
def pull_out_u(*functions) -> Callable[[Expr], tuple[Expr, Expr] | None]:
def pull_out_u_rl(integrand: Expr) -> tuple[Expr, Expr] | None:
if any(integrand.has(f) for f in functions):
args = [arg for arg in integrand.args
if any(isinstance(arg, cls) for cls in functions)]
if args:
u = Mul(*args)
dv = integrand / u
return u, dv
return None
return pull_out_u_rl
liate_rules = [pull_out_u(log), pull_out_u(*inverse_trig_functions),
pull_out_algebraic, pull_out_u(sin, cos),
pull_out_u(exp)]
dummy = Dummy("temporary")
# we can integrate log(x) and atan(x) by setting dv = 1
if isinstance(integrand, (log, *inverse_trig_functions)):
integrand = dummy * integrand
for index, rule in enumerate(liate_rules):
result = rule(integrand)
if result:
u, dv = result
# Don't pick u to be a constant if possible
if symbol not in u.free_symbols and not u.has(dummy):
return None
u = u.subs(dummy, 1)
dv = dv.subs(dummy, 1)
# Don't pick a non-polynomial algebraic to be differentiated
if rule == pull_out_algebraic and not u.is_polynomial(symbol):
return None
# Don't trade one logarithm for another
if isinstance(u, log):
rec_dv = 1/dv
if (rec_dv.is_polynomial(symbol) and
degree(rec_dv, symbol) == 1):
return None
# Can integrate a polynomial times OrthogonalPolynomial
if rule == pull_out_algebraic:
if dv.is_Derivative or dv.has(TrigonometricFunction) or \
isinstance(dv, OrthogonalPolynomial):
v_step = integral_steps(dv, symbol)
if v_step.contains_dont_know():
return None
else:
du = u.diff(symbol)
v = v_step.eval()
return u, dv, v, du, v_step
# make sure dv is amenable to integration
accept = False
if index < 2: # log and inverse trig are usually worth trying
accept = True
elif (rule == pull_out_algebraic and dv.args and
all(isinstance(a, (sin, cos, exp))
for a in dv.args)):
accept = True
else:
for lrule in liate_rules[index + 1:]:
r = lrule(integrand)
if r and r[0].subs(dummy, 1).equals(dv):
accept = True
break
if accept:
du = u.diff(symbol)
v_step = integral_steps(simplify(dv), symbol)
if not v_step.contains_dont_know():
v = v_step.eval()
return u, dv, v, du, v_step
return None
def parts_rule(integral):
integrand, symbol = integral
constant, integrand = integrand.as_coeff_Mul()
result = _parts_rule(integrand, symbol)
steps = []
if result:
u, dv, v, du, v_step = result
debug("u : {}, dv : {}, v : {}, du : {}, v_step: {}".format(u, dv, v, du, v_step))
steps.append(result)
if isinstance(v, Integral):
return
# Set a limit on the number of times u can be used
if isinstance(u, (sin, cos, exp, sinh, cosh)):
cachekey = u.xreplace({symbol: _cache_dummy})
if _parts_u_cache[cachekey] > 2:
return
_parts_u_cache[cachekey] += 1
# Try cyclic integration by parts a few times
for _ in range(4):
debug("Cyclic integration {} with v: {}, du: {}, integrand: {}".format(_, v, du, integrand))
coefficient = ((v * du) / integrand).cancel()
if coefficient == 1:
break
if symbol not in coefficient.free_symbols:
rule = CyclicPartsRule(integrand, symbol,
[PartsRule(None, None, u, dv, v_step, None)
for (u, dv, v, du, v_step) in steps],
(-1) ** len(steps) * coefficient)
if (constant != 1) and rule:
rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule)
return rule
# _parts_rule is sensitive to constants, factor it out
next_constant, next_integrand = (v * du).as_coeff_Mul()
result = _parts_rule(next_integrand, symbol)
if result:
u, dv, v, du, v_step = result
u *= next_constant
du *= next_constant
steps.append((u, dv, v, du, v_step))
else:
break
def make_second_step(steps, integrand):
if steps:
u, dv, v, du, v_step = steps[0]
return PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du))
return integral_steps(integrand, symbol)
if steps:
u, dv, v, du, v_step = steps[0]
rule = PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du))
if (constant != 1) and rule:
rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule)
return rule
def trig_rule(integral):
integrand, symbol = integral
if integrand == sin(symbol):
return SinRule(integrand, symbol)
if integrand == cos(symbol):
return CosRule(integrand, symbol)
if integrand == sec(symbol)**2:
return Sec2Rule(integrand, symbol)
if integrand == csc(symbol)**2:
return Csc2Rule(integrand, symbol)
if isinstance(integrand, tan):
rewritten = sin(*integrand.args) / cos(*integrand.args)
elif isinstance(integrand, cot):
rewritten = cos(*integrand.args) / sin(*integrand.args)
elif isinstance(integrand, sec):
arg = integrand.args[0]
rewritten = ((sec(arg)**2 + tan(arg) * sec(arg)) /
(sec(arg) + tan(arg)))
elif isinstance(integrand, csc):
arg = integrand.args[0]
rewritten = ((csc(arg)**2 + cot(arg) * csc(arg)) /
(csc(arg) + cot(arg)))
else:
return
return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol))
def trig_product_rule(integral: IntegralInfo):
integrand, symbol = integral
if integrand == sec(symbol) * tan(symbol):
return SecTanRule(integrand, symbol)
if integrand == csc(symbol) * cot(symbol):
return CscCotRule(integrand, symbol)
def quadratic_denom_rule(integral):
integrand, symbol = integral
a = Wild('a', exclude=[symbol])
b = Wild('b', exclude=[symbol])
c = Wild('c', exclude=[symbol])
match = integrand.match(a / (b * symbol ** 2 + c))
if match:
a, b, c = match[a], match[b], match[c]
general_rule = ArctanRule(integrand, symbol, a, b, c)
if b.is_extended_real and c.is_extended_real:
positive_cond = c/b > 0
if positive_cond is S.true:
return general_rule
coeff = a/(2*sqrt(-c)*sqrt(b))
constant = sqrt(-c/b)
r1 = 1/(symbol-constant)
r2 = 1/(symbol+constant)
log_steps = [ReciprocalRule(r1, symbol, symbol-constant),
ConstantTimesRule(-r2, symbol, -1, r2, ReciprocalRule(r2, symbol, symbol+constant))]
rewritten = sub = r1 - r2
negative_step = AddRule(sub, symbol, log_steps)
if coeff != 1:
rewritten = Mul(coeff, sub, evaluate=False)
negative_step = ConstantTimesRule(rewritten, symbol, coeff, sub, negative_step)
negative_step = RewriteRule(integrand, symbol, rewritten, negative_step)
if positive_cond is S.false:
return negative_step
return PiecewiseRule(integrand, symbol, [(general_rule, positive_cond), (negative_step, S.true)])
return general_rule
d = Wild('d', exclude=[symbol])
match2 = integrand.match(a / (b * symbol ** 2 + c * symbol + d))
if match2:
b, c = match2[b], match2[c]
if b.is_zero:
return
u = Dummy('u')
u_func = symbol + c/(2*b)
integrand2 = integrand.subs(symbol, u - c / (2*b))
next_step = integral_steps(integrand2, u)
if next_step:
return URule(integrand2, symbol, u, u_func, next_step)
else:
return
e = Wild('e', exclude=[symbol])
match3 = integrand.match((a* symbol + b) / (c * symbol ** 2 + d * symbol + e))
if match3:
a, b, c, d, e = match3[a], match3[b], match3[c], match3[d], match3[e]
if c.is_zero:
return
denominator = c * symbol**2 + d * symbol + e
const = a/(2*c)
numer1 = (2*c*symbol+d)
numer2 = - const*d + b
u = Dummy('u')
step1 = URule(integrand, symbol,
u, denominator, integral_steps(u**(-1), u))
if const != 1:
step1 = ConstantTimesRule(const*numer1/denominator, symbol,
const, numer1/denominator, step1)
if numer2.is_zero:
return step1
step2 = integral_steps(numer2/denominator, symbol)
substeps = AddRule(integrand, symbol, [step1, step2])
rewriten = const*numer1/denominator+numer2/denominator
return RewriteRule(integrand, symbol, rewriten, substeps)
return
def sqrt_linear_rule(integral: IntegralInfo):
"""
Substitute common (a+b*x)**(1/n)
"""
integrand, x = integral
a = Wild('a', exclude=[x])
b = Wild('b', exclude=[x, 0])
a0 = b0 = 0
bases, qs, bs = [], [], []
for pow_ in integrand.find(Pow): # collect all (a+b*x)**(p/q)
base, exp_ = pow_.base, pow_.exp
if exp_.is_Integer or x not in base.free_symbols: # skip 1/x and sqrt(2)
continue
if not exp_.is_Rational: # exclude x**pi
return
match = base.match(a+b*x)
if not match: # skip non-linear
continue # for sqrt(x+sqrt(x)), although base is non-linear, we can still substitute sqrt(x)
a1, b1 = match[a], match[b]
if a0*b1 != a1*b0 or not (b0/b1).is_nonnegative: # cannot transform sqrt(x) to sqrt(x+1) or sqrt(-x)
return
if b0 == 0 or (b0/b1 > 1) is S.true: # choose the latter of sqrt(2*x) and sqrt(x) as representative
a0, b0 = a1, b1
bases.append(base)
bs.append(b1)
qs.append(exp_.q)
if b0 == 0: # no such pattern found
return
q0: Integer = lcm_list(qs)
u_x = (a0 + b0*x)**(1/q0)
u = Dummy("u")
substituted = integrand.subs({base**(S.One/q): (b/b0)**(S.One/q)*u**(q0/q)
for base, b, q in zip(bases, bs, qs)}).subs(x, (u**q0-a0)/b0)
substep = integral_steps(substituted*u**(q0-1)*q0/b0, u)
if not substep.contains_dont_know():
step: Rule = URule(integrand, x, u, u_x, substep)
generic_cond = Ne(b0, 0)
if generic_cond is not S.true: # possible degenerate case
simplified = integrand.subs({b: 0 for b in bs})
degenerate_step = integral_steps(simplified, x)
step = PiecewiseRule(integrand, x, [(step, generic_cond), (degenerate_step, S.true)])
return step
def sqrt_quadratic_rule(integral: IntegralInfo, degenerate=True):
integrand, x = integral
a = Wild('a', exclude=[x])
b = Wild('b', exclude=[x])
c = Wild('c', exclude=[x, 0])
f = Wild('f')
n = Wild('n', properties=[lambda n: n.is_Integer and n.is_odd])
match = integrand.match(f*sqrt(a+b*x+c*x**2)**n)
if not match:
return
a, b, c, f, n = match[a], match[b], match[c], match[f], match[n]
f_poly = f.as_poly(x)
if f_poly is None:
return
generic_cond = Ne(c, 0)
if not degenerate or generic_cond is S.true:
degenerate_step = None
elif b.is_zero:
degenerate_step = integral_steps(f*sqrt(a)**n, x)
else:
degenerate_step = sqrt_linear_rule(IntegralInfo(f*sqrt(a+b*x)**n, x))
def sqrt_quadratic_denom_rule(numer_poly: Poly, integrand: Expr):
denom = sqrt(a+b*x+c*x**2)
deg = numer_poly.degree()
if deg <= 1:
# integrand == (d+e*x)/sqrt(a+b*x+c*x**2)
e, d = numer_poly.all_coeffs() if deg == 1 else (S.Zero, numer_poly.as_expr())
# rewrite numerator to A*(2*c*x+b) + B
A = e/(2*c)
B = d-A*b
pre_substitute = (2*c*x+b)/denom
constant_step: Rule | None = None
linear_step: Rule | None = None
if A != 0:
u = Dummy("u")
pow_rule = PowerRule(1/sqrt(u), u, u, -S.Half)
linear_step = URule(pre_substitute, x, u, a+b*x+c*x**2, pow_rule)
if A != 1:
linear_step = ConstantTimesRule(A*pre_substitute, x, A, pre_substitute, linear_step)
if B != 0:
constant_step = inverse_trig_rule(IntegralInfo(1/denom, x), degenerate=False)
if B != 1:
constant_step = ConstantTimesRule(B/denom, x, B, 1/denom, constant_step) # type: ignore
if linear_step and constant_step:
add = Add(A*pre_substitute, B/denom, evaluate=False)
step: Rule | None = RewriteRule(integrand, x, add, AddRule(add, x, [linear_step, constant_step]))
else:
step = linear_step or constant_step
else:
coeffs = numer_poly.all_coeffs()
step = SqrtQuadraticDenomRule(integrand, x, a, b, c, coeffs)
return step
if n > 0: # rewrite poly * sqrt(s)**(2*k-1) to poly*s**k / sqrt(s)
numer_poly = f_poly * (a+b*x+c*x**2)**((n+1)/2)
rewritten = numer_poly.as_expr()/sqrt(a+b*x+c*x**2)
substep = sqrt_quadratic_denom_rule(numer_poly, rewritten)
generic_step = RewriteRule(integrand, x, rewritten, substep)
elif n == -1:
generic_step = sqrt_quadratic_denom_rule(f_poly, integrand)
else:
return # todo: handle n < -1 case
return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
def hyperbolic_rule(integral: tuple[Expr, Symbol]):
integrand, symbol = integral
if isinstance(integrand, HyperbolicFunction) and integrand.args[0] == symbol:
if integrand.func == sinh:
return SinhRule(integrand, symbol)
if integrand.func == cosh:
return CoshRule(integrand, symbol)
u = Dummy('u')
if integrand.func == tanh:
rewritten = sinh(symbol)/cosh(symbol)
return RewriteRule(integrand, symbol, rewritten,
URule(rewritten, symbol, u, cosh(symbol), ReciprocalRule(1/u, u, u)))
if integrand.func == coth:
rewritten = cosh(symbol)/sinh(symbol)
return RewriteRule(integrand, symbol, rewritten,
URule(rewritten, symbol, u, sinh(symbol), ReciprocalRule(1/u, u, u)))
else:
rewritten = integrand.rewrite(tanh)
if integrand.func == sech:
return RewriteRule(integrand, symbol, rewritten,
URule(rewritten, symbol, u, tanh(symbol/2),
ArctanRule(2/(u**2 + 1), u, S(2), S.One, S.One)))
if integrand.func == csch:
return RewriteRule(integrand, symbol, rewritten,
URule(rewritten, symbol, u, tanh(symbol/2),
ReciprocalRule(1/u, u, u)))
@cacheit
def make_wilds(symbol):
a = Wild('a', exclude=[symbol])
b = Wild('b', exclude=[symbol])
m = Wild('m', exclude=[symbol], properties=[lambda n: isinstance(n, Integer)])
n = Wild('n', exclude=[symbol], properties=[lambda n: isinstance(n, Integer)])
return a, b, m, n
@cacheit
def sincos_pattern(symbol):
a, b, m, n = make_wilds(symbol)
pattern = sin(a*symbol)**m * cos(b*symbol)**n
return pattern, a, b, m, n
@cacheit
def tansec_pattern(symbol):
a, b, m, n = make_wilds(symbol)
pattern = tan(a*symbol)**m * sec(b*symbol)**n
return pattern, a, b, m, n
@cacheit
def cotcsc_pattern(symbol):
a, b, m, n = make_wilds(symbol)
pattern = cot(a*symbol)**m * csc(b*symbol)**n
return pattern, a, b, m, n
@cacheit
def heaviside_pattern(symbol):
m = Wild('m', exclude=[symbol])
b = Wild('b', exclude=[symbol])
g = Wild('g')
pattern = Heaviside(m*symbol + b) * g
return pattern, m, b, g
def uncurry(func):
def uncurry_rl(args):
return func(*args)
return uncurry_rl
def trig_rewriter(rewrite):
def trig_rewriter_rl(args):
a, b, m, n, integrand, symbol = args
rewritten = rewrite(a, b, m, n, integrand, symbol)
if rewritten != integrand:
return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol))
return trig_rewriter_rl
sincos_botheven_condition = uncurry(
lambda a, b, m, n, i, s: m.is_even and n.is_even and
m.is_nonnegative and n.is_nonnegative)
sincos_botheven = trig_rewriter(
lambda a, b, m, n, i, symbol: ( (((1 - cos(2*a*symbol)) / 2) ** (m / 2)) *
(((1 + cos(2*b*symbol)) / 2) ** (n / 2)) ))
sincos_sinodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd and m >= 3)
sincos_sinodd = trig_rewriter(
lambda a, b, m, n, i, symbol: ( (1 - cos(a*symbol)**2)**((m - 1) / 2) *
sin(a*symbol) *
cos(b*symbol) ** n))
sincos_cosodd_condition = uncurry(lambda a, b, m, n, i, s: n.is_odd and n >= 3)
sincos_cosodd = trig_rewriter(
lambda a, b, m, n, i, symbol: ( (1 - sin(b*symbol)**2)**((n - 1) / 2) *
cos(b*symbol) *
sin(a*symbol) ** m))
tansec_seceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4)
tansec_seceven = trig_rewriter(
lambda a, b, m, n, i, symbol: ( (1 + tan(b*symbol)**2) ** (n/2 - 1) *
sec(b*symbol)**2 *
tan(a*symbol) ** m ))
tansec_tanodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd)
tansec_tanodd = trig_rewriter(
lambda a, b, m, n, i, symbol: ( (sec(a*symbol)**2 - 1) ** ((m - 1) / 2) *
tan(a*symbol) *
sec(b*symbol) ** n ))
tan_tansquared_condition = uncurry(lambda a, b, m, n, i, s: m == 2 and n == 0)
tan_tansquared = trig_rewriter(
lambda a, b, m, n, i, symbol: ( sec(a*symbol)**2 - 1))
cotcsc_csceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4)
cotcsc_csceven = trig_rewriter(
lambda a, b, m, n, i, symbol: ( (1 + cot(b*symbol)**2) ** (n/2 - 1) *
csc(b*symbol)**2 *
cot(a*symbol) ** m ))
cotcsc_cotodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd)
cotcsc_cotodd = trig_rewriter(
lambda a, b, m, n, i, symbol: ( (csc(a*symbol)**2 - 1) ** ((m - 1) / 2) *
cot(a*symbol) *
csc(b*symbol) ** n ))
def trig_sincos_rule(integral):
integrand, symbol = integral
if any(integrand.has(f) for f in (sin, cos)):
pattern, a, b, m, n = sincos_pattern(symbol)
match = integrand.match(pattern)
if not match:
return
return multiplexer({
sincos_botheven_condition: sincos_botheven,
sincos_sinodd_condition: sincos_sinodd,
sincos_cosodd_condition: sincos_cosodd
})(tuple(
[match.get(i, S.Zero) for i in (a, b, m, n)] +
[integrand, symbol]))
def trig_tansec_rule(integral):
integrand, symbol = integral
integrand = integrand.subs({
1 / cos(symbol): sec(symbol)
})
if any(integrand.has(f) for f in (tan, sec)):
pattern, a, b, m, n = tansec_pattern(symbol)
match = integrand.match(pattern)
if not match:
return
return multiplexer({
tansec_tanodd_condition: tansec_tanodd,
tansec_seceven_condition: tansec_seceven,
tan_tansquared_condition: tan_tansquared
})(tuple(
[match.get(i, S.Zero) for i in (a, b, m, n)] +
[integrand, symbol]))
def trig_cotcsc_rule(integral):
integrand, symbol = integral
integrand = integrand.subs({
1 / sin(symbol): csc(symbol),
1 / tan(symbol): cot(symbol),
cos(symbol) / tan(symbol): cot(symbol)
})
if any(integrand.has(f) for f in (cot, csc)):
pattern, a, b, m, n = cotcsc_pattern(symbol)
match = integrand.match(pattern)
if not match:
return
return multiplexer({
cotcsc_cotodd_condition: cotcsc_cotodd,
cotcsc_csceven_condition: cotcsc_csceven
})(tuple(
[match.get(i, S.Zero) for i in (a, b, m, n)] +
[integrand, symbol]))
def trig_sindouble_rule(integral):
integrand, symbol = integral
a = Wild('a', exclude=[sin(2*symbol)])
match = integrand.match(sin(2*symbol)*a)
if match:
sin_double = 2*sin(symbol)*cos(symbol)/sin(2*symbol)
return integral_steps(integrand * sin_double, symbol)
def trig_powers_products_rule(integral):
return do_one(null_safe(trig_sincos_rule),
null_safe(trig_tansec_rule),
null_safe(trig_cotcsc_rule),
null_safe(trig_sindouble_rule))(integral)
def trig_substitution_rule(integral):
integrand, symbol = integral
A = Wild('a', exclude=[0, symbol])
B = Wild('b', exclude=[0, symbol])
theta = Dummy("theta")
target_pattern = A + B*symbol**2
matches = integrand.find(target_pattern)
for expr in matches:
match = expr.match(target_pattern)
a = match.get(A, S.Zero)
b = match.get(B, S.Zero)
a_positive = ((a.is_number and a > 0) or a.is_positive)
b_positive = ((b.is_number and b > 0) or b.is_positive)
a_negative = ((a.is_number and a < 0) or a.is_negative)
b_negative = ((b.is_number and b < 0) or b.is_negative)
x_func = None
if a_positive and b_positive:
# a**2 + b*x**2. Assume sec(theta) > 0, -pi/2 < theta < pi/2
x_func = (sqrt(a)/sqrt(b)) * tan(theta)
# Do not restrict the domain: tan(theta) takes on any real
# value on the interval -pi/2 < theta < pi/2 so x takes on
# any value
restriction = True
elif a_positive and b_negative:
# a**2 - b*x**2. Assume cos(theta) > 0, -pi/2 < theta < pi/2
constant = sqrt(a)/sqrt(-b)
x_func = constant * sin(theta)
restriction = And(symbol > -constant, symbol < constant)
elif a_negative and b_positive:
# b*x**2 - a**2. Assume sin(theta) > 0, 0 < theta < pi
constant = sqrt(-a)/sqrt(b)
x_func = constant * sec(theta)
restriction = And(symbol > -constant, symbol < constant)
if x_func:
# Manually simplify sqrt(trig(theta)**2) to trig(theta)
# Valid due to assumed domain restriction
substitutions = {}
for f in [sin, cos, tan,
sec, csc, cot]:
substitutions[sqrt(f(theta)**2)] = f(theta)
substitutions[sqrt(f(theta)**(-2))] = 1/f(theta)
replaced = integrand.subs(symbol, x_func).trigsimp()
replaced = manual_subs(replaced, substitutions)
if not replaced.has(symbol):
replaced *= manual_diff(x_func, theta)
replaced = replaced.trigsimp()
secants = replaced.find(1/cos(theta))
if secants:
replaced = replaced.xreplace({
1/cos(theta): sec(theta)
})
substep = integral_steps(replaced, theta)
if not substep.contains_dont_know():
return TrigSubstitutionRule(integrand, symbol,
theta, x_func, replaced, substep, restriction)
def heaviside_rule(integral):
integrand, symbol = integral
pattern, m, b, g = heaviside_pattern(symbol)
match = integrand.match(pattern)
if match and 0 != match[g]:
# f = Heaviside(m*x + b)*g
substep = integral_steps(match[g], symbol)
m, b = match[m], match[b]
return HeavisideRule(integrand, symbol, m*symbol + b, -b/m, substep)
def dirac_delta_rule(integral: IntegralInfo):
integrand, x = integral
if len(integrand.args) == 1:
n = S.Zero
else:
n = integrand.args[1]
if not n.is_Integer or n < 0:
return
a, b = Wild('a', exclude=[x]), Wild('b', exclude=[x, 0])
match = integrand.args[0].match(a+b*x)
if not match:
return
a, b = match[a], match[b]
generic_cond = Ne(b, 0)
if generic_cond is S.true:
degenerate_step = None
else:
degenerate_step = ConstantRule(DiracDelta(a, n), x)
generic_step = DiracDeltaRule(integrand, x, n, a, b)
return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
def substitution_rule(integral):
integrand, symbol = integral
u_var = Dummy("u")
substitutions = find_substitutions(integrand, symbol, u_var)
count = 0
if substitutions:
debug("List of Substitution Rules")
ways = []
for u_func, c, substituted in substitutions:
subrule = integral_steps(substituted, u_var)
count = count + 1
debug("Rule {}: {}".format(count, subrule))
if subrule.contains_dont_know():
continue
if simplify(c - 1) != 0:
_, denom = c.as_numer_denom()
if subrule:
subrule = ConstantTimesRule(c * substituted, u_var, c, substituted, subrule)
if denom.free_symbols:
piecewise = []
could_be_zero = []
if isinstance(denom, Mul):
could_be_zero = denom.args
else:
could_be_zero.append(denom)
for expr in could_be_zero:
if not fuzzy_not(expr.is_zero):
substep = integral_steps(manual_subs(integrand, expr, 0), symbol)
if substep:
piecewise.append((
substep,
Eq(expr, 0)
))
piecewise.append((subrule, True))
subrule = PiecewiseRule(substituted, symbol, piecewise)
ways.append(URule(integrand, symbol, u_var, u_func, subrule))
if len(ways) > 1:
return AlternativeRule(integrand, symbol, ways)
elif ways:
return ways[0]
partial_fractions_rule = rewriter(
lambda integrand, symbol: integrand.is_rational_function(),
lambda integrand, symbol: integrand.apart(symbol))
cancel_rule = rewriter(
# lambda integrand, symbol: integrand.is_algebraic_expr(),
# lambda integrand, symbol: isinstance(integrand, Mul),
lambda integrand, symbol: True,
lambda integrand, symbol: integrand.cancel())
distribute_expand_rule = rewriter(
lambda integrand, symbol: (
all(arg.is_Pow or arg.is_polynomial(symbol) for arg in integrand.args)
or isinstance(integrand, Pow)
or isinstance(integrand, Mul)),
lambda integrand, symbol: integrand.expand())
trig_expand_rule = rewriter(
# If there are trig functions with different arguments, expand them
lambda integrand, symbol: (
len({a.args[0] for a in integrand.atoms(TrigonometricFunction)}) > 1),
lambda integrand, symbol: integrand.expand(trig=True))
def derivative_rule(integral):
integrand = integral[0]
diff_variables = integrand.variables
undifferentiated_function = integrand.expr
integrand_variables = undifferentiated_function.free_symbols
if integral.symbol in integrand_variables:
if integral.symbol in diff_variables:
return DerivativeRule(*integral)
else:
return DontKnowRule(integrand, integral.symbol)
else:
return ConstantRule(*integral)
def rewrites_rule(integral):
integrand, symbol = integral
if integrand.match(1/cos(symbol)):
rewritten = integrand.subs(1/cos(symbol), sec(symbol))
return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol))
def fallback_rule(integral):
return DontKnowRule(*integral)
# Cache is used to break cyclic integrals.
# Need to use the same dummy variable in cached expressions for them to match.
# Also record "u" of integration by parts, to avoid infinite repetition.
_integral_cache: dict[Expr, Expr | None] = {}
_parts_u_cache: dict[Expr, int] = defaultdict(int)
_cache_dummy = Dummy("z")
def integral_steps(integrand, symbol, **options):
"""Returns the steps needed to compute an integral.
Explanation
===========
This function attempts to mirror what a student would do by hand as
closely as possible.
SymPy Gamma uses this to provide a step-by-step explanation of an
integral. The code it uses to format the results of this function can be
found at
https://github.com/sympy/sympy_gamma/blob/master/app/logic/intsteps.py.
Examples
========
>>> from sympy import exp, sin
>>> from sympy.integrals.manualintegrate import integral_steps
>>> from sympy.abc import x
>>> print(repr(integral_steps(exp(x) / (1 + exp(2 * x)), x))) \
# doctest: +NORMALIZE_WHITESPACE
URule(integrand=exp(x)/(exp(2*x) + 1), variable=x, u_var=_u, u_func=exp(x),
substep=ArctanRule(integrand=1/(_u**2 + 1), variable=_u, a=1, b=1, c=1))
>>> print(repr(integral_steps(sin(x), x))) \
# doctest: +NORMALIZE_WHITESPACE
SinRule(integrand=sin(x), variable=x)
>>> print(repr(integral_steps((x**2 + 3)**2, x))) \
# doctest: +NORMALIZE_WHITESPACE
RewriteRule(integrand=(x**2 + 3)**2, variable=x, rewritten=x**4 + 6*x**2 + 9,
substep=AddRule(integrand=x**4 + 6*x**2 + 9, variable=x,
substeps=[PowerRule(integrand=x**4, variable=x, base=x, exp=4),
ConstantTimesRule(integrand=6*x**2, variable=x, constant=6, other=x**2,
substep=PowerRule(integrand=x**2, variable=x, base=x, exp=2)),
ConstantRule(integrand=9, variable=x)]))
Returns
=======
rule : Rule
The first step; most rules have substeps that must also be
considered. These substeps can be evaluated using ``manualintegrate``
to obtain a result.
"""
cachekey = integrand.xreplace({symbol: _cache_dummy})
if cachekey in _integral_cache:
if _integral_cache[cachekey] is None:
# Stop this attempt, because it leads around in a loop
return DontKnowRule(integrand, symbol)
else:
# TODO: This is for future development, as currently
# _integral_cache gets no values other than None
return (_integral_cache[cachekey].xreplace(_cache_dummy, symbol),
symbol)
else:
_integral_cache[cachekey] = None
integral = IntegralInfo(integrand, symbol)
def key(integral):
integrand = integral.integrand
if symbol not in integrand.free_symbols:
return Number
for cls in (Symbol, TrigonometricFunction, OrthogonalPolynomial):
if isinstance(integrand, cls):
return cls
return type(integrand)
def integral_is_subclass(*klasses):
def _integral_is_subclass(integral):
k = key(integral)
return k and issubclass(k, klasses)
return _integral_is_subclass
result = do_one(
null_safe(special_function_rule),
null_safe(switch(key, {
Pow: do_one(null_safe(power_rule), null_safe(inverse_trig_rule),
null_safe(sqrt_linear_rule),
null_safe(quadratic_denom_rule)),
Symbol: power_rule,
exp: exp_rule,
Add: add_rule,
Mul: do_one(null_safe(mul_rule), null_safe(trig_product_rule),
null_safe(heaviside_rule), null_safe(quadratic_denom_rule),
null_safe(sqrt_linear_rule),
null_safe(sqrt_quadratic_rule)),
Derivative: derivative_rule,
TrigonometricFunction: trig_rule,
Heaviside: heaviside_rule,
DiracDelta: dirac_delta_rule,
OrthogonalPolynomial: orthogonal_poly_rule,
Number: constant_rule
})),
do_one(
null_safe(trig_rule),
null_safe(hyperbolic_rule),
null_safe(alternatives(
rewrites_rule,
substitution_rule,
condition(
integral_is_subclass(Mul, Pow),
partial_fractions_rule),
condition(
integral_is_subclass(Mul, Pow),
cancel_rule),
condition(
integral_is_subclass(Mul, log,
*inverse_trig_functions),
parts_rule),
condition(
integral_is_subclass(Mul, Pow),
distribute_expand_rule),
trig_powers_products_rule,
trig_expand_rule
)),
null_safe(condition(integral_is_subclass(Mul, Pow), nested_pow_rule)),
null_safe(trig_substitution_rule)
),
fallback_rule)(integral)
del _integral_cache[cachekey]
return result
def manualintegrate(f, var):
"""manualintegrate(f, var)
Explanation
===========
Compute indefinite integral of a single variable using an algorithm that
resembles what a student would do by hand.
Unlike :func:`~.integrate`, var can only be a single symbol.
Examples
========
>>> from sympy import sin, cos, tan, exp, log, integrate
>>> from sympy.integrals.manualintegrate import manualintegrate
>>> from sympy.abc import x
>>> manualintegrate(1 / x, x)
log(x)
>>> integrate(1/x)
log(x)
>>> manualintegrate(log(x), x)
x*log(x) - x
>>> integrate(log(x))
x*log(x) - x
>>> manualintegrate(exp(x) / (1 + exp(2 * x)), x)
atan(exp(x))
>>> integrate(exp(x) / (1 + exp(2 * x)))
RootSum(4*_z**2 + 1, Lambda(_i, _i*log(2*_i + exp(x))))
>>> manualintegrate(cos(x)**4 * sin(x), x)
-cos(x)**5/5
>>> integrate(cos(x)**4 * sin(x), x)
-cos(x)**5/5
>>> manualintegrate(cos(x)**4 * sin(x)**3, x)
cos(x)**7/7 - cos(x)**5/5
>>> integrate(cos(x)**4 * sin(x)**3, x)
cos(x)**7/7 - cos(x)**5/5
>>> manualintegrate(tan(x), x)
-log(cos(x))
>>> integrate(tan(x), x)
-log(cos(x))
See Also
========
sympy.integrals.integrals.integrate
sympy.integrals.integrals.Integral.doit
sympy.integrals.integrals.Integral
"""
result = integral_steps(f, var).eval()
# Clear the cache of u-parts
_parts_u_cache.clear()
# If we got Piecewise with two parts, put generic first
if isinstance(result, Piecewise) and len(result.args) == 2:
cond = result.args[0][1]
if isinstance(cond, Eq) and result.args[1][1] == True:
result = result.func(
(result.args[1][0], Ne(*cond.args)),
(result.args[0][0], True))
return result