You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

73 lines
2.1 KiB

from contextlib import contextmanager
from threading import local
from sympy.core.function import expand_mul
class DotProdSimpState(local):
    """Per-thread holder for the dotprodsimp intermediate-simplification flag.

    Subclassing ``threading.local`` means ``__init__`` is re-run the first
    time each thread touches an instance. Every thread therefore starts out
    with its own ``state`` attribute equal to None, and assignments made in
    one thread are invisible to the others.
    """

    def __init__(self):
        # None stands for "no thread-wide preference"; True / False are set
        # through the ``dotprodsimp`` context manager.
        self.state = None


# Module-level singleton consulted by the helpers below; its ``state``
# attribute is private to each thread.
_dotprodsimp_state = DotProdSimpState()
@contextmanager
def dotprodsimp(x):
    """Temporarily set the current thread's dotprodsimp state to ``x``.

    The previous value of ``_dotprodsimp_state.state`` is remembered on entry
    and written back on exit, even when the ``with`` body raises. Nested uses
    therefore unwind correctly. ``x`` is presumably True, False or None,
    which are the values ``_get_intermediate_simp`` distinguishes.
    """
    previous_setting = _dotprodsimp_state.state
    try:
        _dotprodsimp_state.state = x
        yield
    finally:
        # Restore unconditionally so an exception cannot leak the override.
        _dotprodsimp_state.state = previous_setting
def _dotprodsimp(expr, withsimp=False):
    """Forward ``expr`` to ``sympy.simplify.simplify.dotprodsimp``.

    The import is performed at call time rather than module load time to
    avoid a circular import. ``withsimp`` is passed through unchanged.
    """
    from sympy.simplify.simplify import dotprodsimp as simplify_dot_products
    result = simplify_dot_products(expr, withsimp=withsimp)
    return result
def _get_intermediate_simp(deffunc=lambda x: x, offfunc=lambda x: x,
        onfunc=_dotprodsimp, dotprodsimp=None):
    """Support function for controlling intermediate simplification. Returns a
    simplification function according to the per-call and thread-local
    dotprodsimp settings.

    ``deffunc``     - Function to be used by default.
    ``offfunc``     - Function to be used if dotprodsimp has been turned off.
    ``onfunc``      - Function to be used if dotprodsimp has been turned on.
    ``dotprodsimp`` - True, False or None: the per-call setting.

    The per-call ``dotprodsimp`` argument and the thread-local
    ``_dotprodsimp_state.state`` (set with the ``dotprodsimp`` context
    manager) are combined as follows:

    * if either of them is False, ``offfunc`` is returned;
    * otherwise, if either of them is True, ``onfunc`` is returned;
    * otherwise (both None, or any non-bool value), ``deffunc`` is returned.

    The thread-local setting therefore does not simply override the
    argument. An explicit False from either source always wins over a True
    from the other.
    """
    # Read the thread-local flag once; only this thread can change it.
    global_state = _dotprodsimp_state.state

    # Identity checks on purpose: only the singletons True/False count, so
    # e.g. 0 or 1 fall through to ``deffunc``.
    if dotprodsimp is False or global_state is False:
        return offfunc
    if dotprodsimp is True or global_state is True:
        return onfunc

    return deffunc  # neither source expressed a boolean preference
def _get_intermediate_simp_bool(default=False, dotprodsimp=None):
    """Boolean counterpart of ``_get_intermediate_simp``.

    Instead of simplification functions it yields False when intermediate
    simplification is switched off, True when it is switched on, and
    ``default`` when neither the ``dotprodsimp`` argument nor the
    thread-local state expresses a preference.
    """
    return _get_intermediate_simp(
        deffunc=default,
        offfunc=False,
        onfunc=True,
        dotprodsimp=dotprodsimp,
    )
def _iszero(x):
    """Zero test based solely on the ``is_zero`` attribute of ``x``.

    Returns ``x.is_zero`` as-is. For SymPy objects this is presumably a
    fuzzy value (True, False, or None when undecidable). Objects without
    such an attribute, e.g. plain Python ints, give None.
    """
    try:
        return x.is_zero
    except AttributeError:
        return None
def _is_zero_after_expand_mul(x):
    """Zero test that only applies ``expand_mul`` before comparing with 0.

    This is cheap compared to full simplification and suitable for
    polynomials and rational functions, whose zero-ness is exposed by
    distributing products.
    """
    expanded = expand_mul(x)
    return expanded == 0
def _simplify(expr):
    """Run SymPy's general ``simplify`` on ``expr``.

    The import is deferred to call time to avoid a circular import between
    the matrices and simplify modules.
    """
    from sympy.simplify.simplify import simplify as sympy_simplify
    return sympy_simplify(expr)