from sympy.core import S, Pow
from sympy.core.function import (Derivative, AppliedUndef, diff)
from sympy.core.relational import Equality, Eq
from sympy.core.symbol import Dummy
from sympy.core.sympify import sympify

from sympy.logic.boolalg import BooleanAtom
from sympy.functions import exp
from sympy.series import Order
from sympy.simplify.simplify import simplify, posify, besselsimp
from sympy.simplify.trigsimp import trigsimp
from sympy.simplify.sqrtdenest import sqrtdenest
from sympy.solvers import solve
from sympy.solvers.deutils import _preprocess, ode_order
from sympy.utilities.iterables import iterable, is_sequence


def sub_func_doit(eq, func, new):
    r"""
    When replacing the func with something else, we usually want the
    derivative evaluated, so this function helps in making that happen.

    Parameters
    ==========

    eq
        Expression in which the replacement is performed.
    func
        The applied function to be replaced.
    new
        The expression substituted for ``func``.

    Returns
    =======

    ``eq`` with ``func`` replaced by ``new`` and every ``Derivative`` atom
    replaced by its evaluated counterpart.

    Examples
    ========

    >>> from sympy import Derivative, symbols, Function
    >>> from sympy.solvers.ode.subscheck import sub_func_doit
    >>> x, z = symbols('x, z')
    >>> y = Function('y')

    >>> sub_func_doit(3*Derivative(y(x), x) - 1, y(x), x)
    2

    >>> sub_func_doit(x*Derivative(y(x), x) - y(x)**2 + y(x), y(x),
    ... 1/(x*(z + 1/x)))
    x*(-1/(x**2*(z + 1/x)) + 1/(x**3*(z + 1/x)**2)) + 1/(x*(z + 1/x))
    ...- 1/(x**2*(z + 1/x)**2)
    """
    reps = {func: new}
    for d in eq.atoms(Derivative):
        if d.expr == func:
            # Derivative taken directly of ``func``: differentiate the
            # replacement with the same variables and counts.
            reps[d] = new.diff(*d.variable_count)
        else:
            # ``func`` is nested deeper inside the derivative: substitute
            # first, then evaluate only the outermost derivative.
            reps[d] = d.xreplace({func: new}).doit(deep=False)
    return eq.xreplace(reps)


def checkodesol(ode, sol, func=None, order='auto', solve_for_func=True):
    r"""
    Substitutes ``sol`` into ``ode`` and checks that the result is ``0``.

    This works when ``func`` is one function, like `f(x)` or a list of
    functions like `[f(x), g(x)]` when `ode` is a system of ODEs. ``sol`` can
    be a single solution or a list of solutions.  Each solution may be an
    :py:class:`~sympy.core.relational.Equality` that the solution satisfies,
    e.g. ``Eq(f(x), C1), Eq(f(x) + C1, 0)``; or simply an
    :py:class:`~sympy.core.expr.Expr`, e.g. ``f(x) - C1``. In most cases it
    will not be necessary to explicitly identify the function, but if the
    function cannot be inferred from the original equation it can be supplied
    through the ``func`` argument.

    If a sequence of solutions is passed, the same sort of container will be
    used to return the result for each solution.

    It tries the following methods, in order, until it finds zero equivalence:

    1. Substitute the solution for `f` in the original equation.  This only
       works if ``ode`` is solved for `f`.  It will attempt to solve it first
       unless ``solve_for_func == False``.
    2. Take `n` derivatives of the solution, where `n` is the order of
       ``ode``, and check to see if that is equal to the solution.  This only
       works on exact ODEs.
    3. Take the 1st, 2nd, ..., `n`\th derivatives of the solution, each time
       solving for the derivative of `f` of that order (this will always be
       possible because `f` is a linear operator). Then back substitute each
       derivative into ``ode`` in reverse order.

    This function returns a tuple.  The first item in the tuple is ``True`` if
    the substitution results in ``0``, and ``False`` otherwise. The second
    item in the tuple is what the substitution results in.  It should always
    be ``0`` if the first item is ``True``. Sometimes this function will
    return ``False`` even when an expression is identically equal to ``0``.
    This happens when :py:meth:`~sympy.simplify.simplify.simplify` does not
    reduce the expression to ``0``.  If an expression returned by this
    function vanishes identically, then ``sol`` really is a solution to
    the ``ode``.

    If this function seems to hang, it is probably because of a hard
    simplification.

    To use this function to test, test the first item of the tuple.

    Raises
    ======

    ValueError
        If ``func`` is not given and cannot be inferred, or if it is not an
        applied undefined function of exactly one variable.
    NotImplementedError
        If none of the methods above was able to produce a result.

    Examples
    ========

    >>> from sympy import (Eq, Function, checkodesol, symbols,
    ...     Derivative, exp)
    >>> x, C1, C2 = symbols('x,C1,C2')
    >>> f, g = symbols('f g', cls=Function)
    >>> checkodesol(f(x).diff(x), Eq(f(x), C1))
    (True, 0)
    >>> assert checkodesol(f(x).diff(x), C1)[0]
    >>> assert not checkodesol(f(x).diff(x), x)[0]
    >>> checkodesol(f(x).diff(x, 2), x**2)
    (False, 2)

    >>> eqs = [Eq(Derivative(f(x), x), f(x)), Eq(Derivative(g(x), x), g(x))]
    >>> sol = [Eq(f(x), C1*exp(x)), Eq(g(x), C2*exp(x))]
    >>> checkodesol(eqs, sol)
    (True, [0, 0])

    """
    if iterable(ode):
        return checksysodesol(ode, sol, func=func)

    if not isinstance(ode, Equality):
        ode = Eq(ode, 0)
    if func is None:
        try:
            _, func = _preprocess(ode.lhs)
        except ValueError:
            # The ODE did not identify the function; fall back to the
            # applied undefined functions appearing in the solution(s).
            funcs = [s.atoms(AppliedUndef) for s in (
                sol if is_sequence(sol, set) else [sol])]
            funcs = set().union(*funcs)
            if len(funcs) != 1:
                raise ValueError(
                    'must pass func arg to checkodesol for this case.')
            func = funcs.pop()
    if not isinstance(func, AppliedUndef) or len(func.args) != 1:
        raise ValueError(
            "func must be a function of one variable, not %s" % func)
    if is_sequence(sol, set):
        # Forward the resolved ``func`` so every solution is checked against
        # the same function instead of re-inferring it on each call.
        return type(sol)([
            checkodesol(ode, i, func=func, order=order,
                        solve_for_func=solve_for_func)
            for i in sol])

    if not isinstance(sol, Equality):
        sol = Eq(func, sol)
    elif sol.rhs == func:
        sol = sol.reversed

    if order == 'auto':
        order = ode_order(ode, func)
    solved = sol.lhs == func and not sol.rhs.has(func)
    if solve_for_func and not solved:
        rhs = solve(sol, func)
        if rhs:
            eqs = [Eq(func, t) for t in rhs]
            if len(rhs) == 1:
                eqs = eqs[0]
            # Keep using the function resolved above for the explicit
            # solutions; re-inference could fail or pick nothing.
            return checkodesol(ode, eqs, func=func, order=order,
                               solve_for_func=False)

    x = func.args[0]

    # Handle series solutions here
    if sol.has(Order):
        assert sol.lhs == func
        Oterm = sol.rhs.getO()
        solrhs = sol.rhs.removeO()

        Oexpr = Oterm.expr
        assert isinstance(Oexpr, Pow)
        sorder = Oexpr.exp
        assert Oterm == Order(x**sorder)

        odesubs = (ode.lhs-ode.rhs).subs(func, solrhs).doit().expand()

        neworder = Order(x**(sorder - order))
        odesubs = odesubs + neworder
        assert odesubs.getO() == neworder
        residual = odesubs.removeO()

        return (residual == 0, residual)

    # ``s`` stays ``True`` until one of the passes below replaces it with a
    # residual; a falsy residual ends the loop and means success.
    s = True
    testnum = 0
    while s:
        if testnum == 0:
            # First pass, try substituting a solved solution directly into the
            # ODE. This has the highest chance of succeeding.
            ode_diff = ode.lhs - ode.rhs

            if sol.lhs == func:
                s = sub_func_doit(ode_diff, func, sol.rhs)
                s = besselsimp(s)
            else:
                testnum += 1
                continue
            ss = simplify(s.rewrite(exp))
            if ss:
                # with the new numer_denom in power.py, if we do a simple
                # expansion then testnum == 0 verifies all solutions.
                s = ss.expand(force=True)
            else:
                s = 0
            testnum += 1
        elif testnum == 1:
            # Second pass. If we cannot substitute f, try seeing if the nth
            # derivative is equal, this will only work for odes that are exact,
            # by definition.
            s = simplify(
                trigsimp(diff(sol.lhs, x, order) - diff(sol.rhs, x, order)) -
                trigsimp(ode.lhs) + trigsimp(ode.rhs))
            # s2 = simplify(
            #     diff(sol.lhs, x, order) - diff(sol.rhs, x, order) - \
            #     ode.lhs + ode.rhs)
            testnum += 1
        elif testnum == 2:
            # Third pass. Try solving for df/dx and substituting that into the
            # ODE. Thanks to Chris Smith for suggesting this method.  Many of
            # the comments below are his, too.
            # The method:
            # - Take each of 1..n derivatives of the solution.
            # - Solve each nth derivative for d^(n)f/dx^(n)
            #   (the differential of that order)
            # - Back substitute into the ODE in decreasing order
            #   (i.e., n, n-1, ...)
            # - Check the result for zero equivalence
            if sol.lhs == func and not sol.rhs.has(func):
                diffsols = {0: sol.rhs}
            elif sol.rhs == func and not sol.lhs.has(func):
                diffsols = {0: sol.lhs}
            else:
                diffsols = {}
            # From here on ``sol`` is the expression lhs - rhs, not an Eq.
            sol = sol.lhs - sol.rhs
            for i in range(1, order + 1):
                # Differentiation is a linear operator, so there should always
                # be 1 solution. Nonetheless, we test just to make sure.
                # We only need to solve once.  After that, we automatically
                # have the solution to the differential in the order we want.
                if i == 1:
                    ds = sol.diff(x)
                    try:
                        sdf = solve(ds, func.diff(x, i))
                        if not sdf:
                            raise NotImplementedError
                    except NotImplementedError:
                        # Signal failure of this pass via ``testnum``.
                        testnum += 1
                        break
                    else:
                        diffsols[i] = sdf[0]
                else:
                    # This is what the solution says df/dx should be.
                    diffsols[i] = diffsols[i - 1].diff(x)

            # Make sure the above didn't fail.
            if testnum > 2:
                continue
            else:
                # Substitute it into ODE to check for self consistency.
                lhs, rhs = ode.lhs, ode.rhs
                for i in range(order, -1, -1):
                    if i == 0 and 0 not in diffsols:
                        # We can only substitute f(x) if the solution was
                        # solved for f(x).
                        break
                    lhs = sub_func_doit(lhs, func.diff(x, i), diffsols[i])
                    rhs = sub_func_doit(rhs, func.diff(x, i), diffsols[i])
                    ode_or_bool = Eq(lhs, rhs)
                    ode_or_bool = simplify(ode_or_bool)

                    if isinstance(ode_or_bool, (bool, BooleanAtom)):
                        if ode_or_bool:
                            lhs = rhs = S.Zero
                    else:
                        lhs = ode_or_bool.lhs
                        rhs = ode_or_bool.rhs
                # No sense in overworking simplify -- just prove that the
                # numerator goes to zero
                num = trigsimp((lhs - rhs).as_numer_denom()[0])
                # since solutions are obtained using force=True we test
                # using the same level of assumptions
                ## replace function with dummy so assumptions will work
                _func = Dummy('func')
                num = num.subs(func, _func)
                ## posify the expression
                num, reps = posify(num)
                s = simplify(num).xreplace(reps).xreplace({_func: func})
                testnum += 1
        else:
            break

    if not s:
        return (True, s)
    elif s is True:
        # The code above never was able to change s
        raise NotImplementedError("Unable to test if " + str(sol) +
            " is a solution to " + str(ode) + ".")
    else:
        return (False, s)


def checksysodesol(eqs, sols, func=None):
    r"""
    Substitutes corresponding ``sols`` for each functions into each ``eqs`` and
    checks that the result of substitutions for each equation is ``0``. The
    equations and solutions passed can be any iterable.

    This only works when each ``sols`` have one function only, like `x(t)` or
    `y(t)`. For each function, ``sols`` can have a single solution or a list
    of solutions. In most cases it will not be necessary to explicitly
    identify the function, but if the function cannot be inferred from the
    original equation it can be supplied through the ``func`` argument
    (either a single applied function or an iterable of them).

    When a sequence of equations is passed, the same sequence is used to
    return the result for each equation with each function substituted with
    corresponding solutions.

    It tries the following method to find zero equivalence for each equation:

    Substitute the solutions for functions, like `x(t)` and `y(t)` into the
    original equations containing those functions.
    This function returns a tuple.  The first item in the tuple is ``True`` if
    the substitution results for each equation is ``0``, and ``False``
    otherwise. The second item in the tuple is what the substitution results
    in.  Each element of the ``list`` should always be ``0`` corresponding to
    each equation if the first item is ``True``. Note that sometimes this
    function may return ``False``, but with an expression that is identically
    equal to ``0``, instead of returning ``True``.  This is because
    :py:meth:`~sympy.simplify.simplify.simplify` cannot reduce the expression
    to ``0``.  If an expression returned by each function vanishes
    identically, then ``sols`` really is a solution to ``eqs``.

    If this function seems to hang, it is probably because of a difficult
    simplification.

    Raises
    ======

    ValueError
        If any function is not an applied undefined function of one
        variable, if the functions do not all share the same argument, if a
        solution contains more or fewer than one function, or if the number
        of distinct solution left-hand sides does not match the number of
        functions.
    NotImplementedError
        If a solution that is not already solved for its function cannot be
        solved uniquely for it.

    Examples
    ========

    >>> from sympy import Eq, diff, symbols, sin, cos, exp, sqrt, S, Function
    >>> from sympy.solvers.ode.subscheck import checksysodesol
    >>> C1, C2 = symbols('C1:3')
    >>> t = symbols('t')
    >>> x, y = symbols('x, y', cls=Function)
    >>> eq = (Eq(diff(x(t),t), x(t) + y(t) + 17), Eq(diff(y(t),t), -2*x(t) + y(t) + 12))
    >>> sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - S(5)/3),
    ...        Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - S(46)/3)]
    >>> checksysodesol(eq, sol)
    (True, [0, 0])
    >>> eq = (Eq(diff(x(t),t),x(t)*y(t)**4), Eq(diff(y(t),t),y(t)**3))
    >>> sol = [Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), -sqrt(2)*sqrt(-1/(C2 + t))/2),
    ...        Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), sqrt(2)*sqrt(-1/(C2 + t))/2)]
    >>> checksysodesol(eq, sol)
    (True, [0, 0])

    """
    def _sympify(eq):
        return list(map(sympify, eq if iterable(eq) else [eq]))

    # Work with plain expressions: an Equality becomes lhs - rhs.
    eqs = [eq.lhs - eq.rhs if isinstance(eq, Equality) else eq
           for eq in _sympify(eqs)]
    # ``sols`` is traversed several times below, so materialise it once; a
    # one-shot iterator would otherwise be exhausted after the first pass.
    sols = _sympify(sols)

    if func is None:
        # Infer the unknown functions from the derivatives in the equations.
        found = set()
        for eq in eqs:
            for deriv in eq.atoms(Derivative):
                found |= deriv.atoms(AppliedUndef)
        funcs = list(found)
    else:
        # Accept a single function or an iterable of functions, dropping
        # duplicates while keeping the caller's order.
        supplied = func if iterable(func) else [func]
        funcs = list(dict.fromkeys(supplied))

    # Every function must be an applied undefined function of one variable
    # and all of them must depend on the same variable.
    invalid = [fn for fn in funcs
               if not (isinstance(fn, AppliedUndef) and len(fn.args) == 1)]
    if invalid or len({fn.args for fn in funcs}) > 1:
        raise ValueError("func must be a function of one variable, not %s"
                         % (invalid[0] if invalid else funcs))

    for sol in sols:
        if len(sol.atoms(AppliedUndef)) != 1:
            raise ValueError("solutions should have one function only")
    if len(funcs) != len({sol.lhs for sol in sols}):
        raise ValueError("number of solutions provided does not match the number of equations")

    # Map each function to the expression that replaces it.  When several
    # solutions are given for one function, the last one wins.
    dictsol = {}
    for sol in sols:
        fn = list(sol.atoms(AppliedUndef))[0]
        if sol.rhs == fn:
            sol = sol.reversed
        if sol.lhs == fn and not sol.rhs.has(fn):
            rhs = sol.rhs
        else:
            # ``solve`` returns a list; a substitution needs exactly one
            # expression, so anything else cannot be checked here.
            candidates = solve(sol, fn)
            if len(candidates) != 1:
                raise NotImplementedError(
                    "Unable to solve %s uniquely for %s." % (sol, fn))
            rhs = candidates[0]
        dictsol[fn] = rhs

    checkeq = []
    for eq in eqs:
        for fn in funcs:
            eq = sub_func_doit(eq, fn, dictsol[fn])
        ss = simplify(eq)
        if ss != 0:
            eq = ss.expand(force=True)
            if eq != 0:
                eq = sqrtdenest(eq).simplify()
        else:
            eq = 0
        checkeq.append(eq)

    # Success means at least one residual and every residual collapsing to
    # the same zero value.
    distinct = set(checkeq)
    all_zero = len(distinct) == 1 and distinct.pop() == 0
    return (all_zero, checkeq)