You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
73 lines
2.1 KiB
73 lines
2.1 KiB
5 months ago
|
from contextlib import contextmanager
|
||
|
from threading import local
|
||
|
|
||
|
from sympy.core.function import expand_mul
|
||
|
|
||
|
|
||
|
class DotProdSimpState(local):
|
||
|
def __init__(self):
|
||
|
self.state = None
|
||
|
|
||
|
_dotprodsimp_state = DotProdSimpState()
|
||
|
|
||
|
@contextmanager
|
||
|
def dotprodsimp(x):
|
||
|
old = _dotprodsimp_state.state
|
||
|
|
||
|
try:
|
||
|
_dotprodsimp_state.state = x
|
||
|
yield
|
||
|
finally:
|
||
|
_dotprodsimp_state.state = old
|
||
|
|
||
|
|
||
|
def _dotprodsimp(expr, withsimp=False):
|
||
|
"""Wrapper for simplify.dotprodsimp to avoid circular imports."""
|
||
|
from sympy.simplify.simplify import dotprodsimp as dps
|
||
|
return dps(expr, withsimp=withsimp)
|
||
|
|
||
|
|
||
|
def _get_intermediate_simp(deffunc=lambda x: x, offfunc=lambda x: x,
|
||
|
onfunc=_dotprodsimp, dotprodsimp=None):
|
||
|
"""Support function for controlling intermediate simplification. Returns a
|
||
|
simplification function according to the global setting of dotprodsimp
|
||
|
operation.
|
||
|
|
||
|
``deffunc`` - Function to be used by default.
|
||
|
``offfunc`` - Function to be used if dotprodsimp has been turned off.
|
||
|
``onfunc`` - Function to be used if dotprodsimp has been turned on.
|
||
|
``dotprodsimp`` - True, False or None. Will be overridden by global
|
||
|
_dotprodsimp_state.state if that is not None.
|
||
|
"""
|
||
|
|
||
|
if dotprodsimp is False or _dotprodsimp_state.state is False:
|
||
|
return offfunc
|
||
|
if dotprodsimp is True or _dotprodsimp_state.state is True:
|
||
|
return onfunc
|
||
|
|
||
|
return deffunc # None, None
|
||
|
|
||
|
|
||
|
def _get_intermediate_simp_bool(default=False, dotprodsimp=None):
|
||
|
"""Same as ``_get_intermediate_simp`` but returns bools instead of functions
|
||
|
by default."""
|
||
|
|
||
|
return _get_intermediate_simp(default, False, True, dotprodsimp)
|
||
|
|
||
|
|
||
|
def _iszero(x):
|
||
|
"""Returns True if x is zero."""
|
||
|
return getattr(x, 'is_zero', None)
|
||
|
|
||
|
|
||
|
def _is_zero_after_expand_mul(x):
|
||
|
"""Tests by expand_mul only, suitable for polynomials and rational
|
||
|
functions."""
|
||
|
return expand_mul(x) == 0
|
||
|
|
||
|
|
||
|
def _simplify(expr):
|
||
|
""" Wrapper to avoid circular imports. """
|
||
|
from sympy.simplify.simplify import simplify
|
||
|
return simplify(expr)
|