You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
470 lines
15 KiB
470 lines
15 KiB
5 months ago
|
"""Module with functions operating on IndexedBase, Indexed and Idx objects
|
||
|
|
||
|
- Check shape conformance
|
||
|
- Determine indices in resulting expression
|
||
|
|
||
|
etc.
|
||
|
|
||
|
Methods in this module could be implemented by calling methods on Expr
|
||
|
objects instead. When things stabilize this could be a useful
|
||
|
refactoring.
|
||
|
"""
|
||
|
|
||
|
from functools import reduce
|
||
|
|
||
|
from sympy.core.function import Function
|
||
|
from sympy.functions import exp, Piecewise
|
||
|
from sympy.tensor.indexed import Idx, Indexed
|
||
|
from sympy.utilities import sift
|
||
|
|
||
|
from collections import OrderedDict
|
||
|
|
||
|
class IndexConformanceException(Exception):
    """Raised when the terms of a sum do not share the same outer indices."""
|
||
|
|
||
|
def _unique_and_repeated(inds):
|
||
|
"""
|
||
|
Returns the unique and repeated indices. Also note, from the examples given below
|
||
|
that the order of indices is maintained as given in the input.
|
||
|
|
||
|
Examples
|
||
|
========
|
||
|
|
||
|
>>> from sympy.tensor.index_methods import _unique_and_repeated
|
||
|
>>> _unique_and_repeated([2, 3, 1, 3, 0, 4, 0])
|
||
|
([2, 1, 4], [3, 0])
|
||
|
"""
|
||
|
uniq = OrderedDict()
|
||
|
for i in inds:
|
||
|
if i in uniq:
|
||
|
uniq[i] = 0
|
||
|
else:
|
||
|
uniq[i] = 1
|
||
|
return sift(uniq, lambda x: uniq[x], binary=True)
|
||
|
|
||
|
def _remove_repeated(inds):
    """
    Drop every object that occurs more than once in a sequence.

    Returns a set holding the objects that occur exactly once, together with
    a tuple of the objects that were removed because they were repeated.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _remove_repeated
    >>> l1 = [1, 2, 3, 2]
    >>> _remove_repeated(l1)
    ({1, 3}, (2,))

    """
    kept, dropped = _unique_and_repeated(inds)
    return set(kept), tuple(dropped)
|
||
|
|
||
|
|
||
|
def _get_indices_Mul(expr, return_dummies=False):
    """Determine the outer indices of a Mul object.

    The indices of all factors are collected; an index that shows up more
    than once is a dummy (summation) index and is not part of the outer
    indices.  If ``return_dummies`` is true, the tuple of dummy indices is
    returned as a third item.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _get_indices_Mul
    >>> from sympy.tensor.indexed import IndexedBase, Idx
    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> x = IndexedBase('x')
    >>> y = IndexedBase('y')
    >>> _get_indices_Mul(x[i, k]*y[j, k])
    ({i, j}, {})
    >>> _get_indices_Mul(x[i, k]*y[j, k], return_dummies=True)
    ({i, j}, {}, (k,))

    """
    per_factor = [get_indices(factor) for factor in expr.args]
    index_sets, syms = zip(*per_factor)

    # Concatenate the outer indices of every factor, in factor order, so
    # that indices shared between factors show up as repetitions.
    collected = []
    for index_set in index_sets:
        collected.extend(index_set)
    outer, dummies = _remove_repeated(collected)

    # Merge the symmetry dicts of the factors; entries for the same pair
    # are multiplied together.
    symmetry = {}
    for sym in syms:
        for pair, value in sym.items():
            if pair in symmetry:
                symmetry[pair] *= value
            else:
                symmetry[pair] = value

    if return_dummies:
        return outer, symmetry, dummies
    return outer, symmetry
|
||
|
|
||
|
|
||
|
def _get_indices_Pow(expr):
    """Determine outer indices of a power or an exponential.

    A power is considered a universal function, so that the indices of a Pow is
    just the collection of indices present in the expression.  This may be
    viewed as a bit inconsistent in the special case:

        x[i]**2 = x[i]*x[i]                                                      (1)

    The above expression could have been interpreted as the contraction of x[i]
    with itself, but we choose instead to interpret it as a function

        lambda y: y**2

    applied to each element of x (a universal function in numpy terms).  In
    order to allow an interpretation of (1) as a contraction, we need
    contravariant and covariant Idx subclasses.  (FIXME: this is not yet
    implemented)

    Expressions in the base or exponent are subject to contraction as usual,
    but an index that is present in the exponent, will not be considered
    contractable with its own base.  Note however, that indices in the same
    exponent can be contracted with each other.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _get_indices_Pow
    >>> from sympy import Pow, exp, IndexedBase, Idx
    >>> A = IndexedBase('A')
    >>> x = IndexedBase('x')
    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> _get_indices_Pow(exp(A[i, j]*x[j]))
    ({i}, {})
    >>> _get_indices_Pow(Pow(x[i], x[i]))
    ({i}, {})
    >>> _get_indices_Pow(Pow(A[i, j]*x[j], x[i]))
    ({i}, {})

    """
    base, exponent = expr.as_base_exp()
    # The symmetry information of base and exponent is not used.
    base_inds, _ = get_indices(base)
    exponent_inds, _ = get_indices(exponent)

    # FIXME: symmetries from power needs to check special cases, else nothing
    return base_inds | exponent_inds, {}
|
||
|
|
||
|
|
||
|
def _get_indices_Add(expr):
    """Determine outer indices of an Add object.

    In a sum, each term must have the same set of outer indices.  A valid
    expression could be

        x(i)*y(j) - x(j)*y(i)

    But we do not allow expressions like:

        x(i)*y(j) - z(j)*z(j)

    Terms without any outer indices (scalars) are allowed next to indexed
    terms.  An IndexConformanceException is raised if the indexed terms
    disagree.

    FIXME:  Add support for Numpy broadcasting

    Examples
    ========

    >>> from sympy.tensor.index_methods import _get_indices_Add
    >>> from sympy.tensor.indexed import IndexedBase, Idx
    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> x = IndexedBase('x')
    >>> y = IndexedBase('y')
    >>> _get_indices_Add(x[i] + x[k]*y[i, k])
    ({i}, {})

    """
    per_term = [get_indices(term) for term in expr.args]
    index_sets, syms = zip(*per_term)

    # allow broadcast of scalars
    non_scalars = [inds for inds in index_sets if inds != set()]
    if not non_scalars:
        return set(), {}

    reference = non_scalars[0]
    if any(other != reference for other in non_scalars[1:]):
        raise IndexConformanceException("Indices are not consistent: %s" % expr)

    if reduce(lambda x, y: x != y or y, syms):
        # FIXME: search for symmetries
        symmetries = {}
    else:
        symmetries = syms[0]

    return reference, symmetries
|
||
|
|
||
|
|
||
|
def get_indices(expr):
    """Determine the outer indices of expression ``expr``

    By *outer* we mean indices that are not summation indices.  Returns a set
    and a dict.  The set contains outer indices and the dict contains
    information about index symmetries.

    Raises NotImplementedError for an expression that contains Indexed
    objects but is of a type that has no specialized handling here.

    Examples
    ========

    >>> from sympy.tensor.index_methods import get_indices
    >>> from sympy import symbols
    >>> from sympy.tensor import IndexedBase
    >>> x, y, A = map(IndexedBase, ['x', 'y', 'A'])
    >>> i, j, a, z = symbols('i j a z', integer=True)

    The indices of the total expression is determined, Repeated indices imply a
    summation, for instance the trace of a matrix A:

    >>> get_indices(A[i, i])
    (set(), {})

    In the case of many terms, the terms are required to have identical
    outer indices.  Else an IndexConformanceException is raised.

    >>> get_indices(x[i] + A[i, j]*y[j])
    ({i}, {})

    :Exceptions:

    An IndexConformanceException means that the terms are not compatible, e.g.

    >>> get_indices(x[i] + y[j])                #doctest: +SKIP
            (...)
    IndexConformanceException: Indices are not consistent: x(i) + y(j)

    .. warning::
       The concept of *outer* indices applies recursively, starting on the deepest
       level.  This implies that dummies inside parenthesis are assumed to be
       summed first, so that the following expression is handled gracefully:

       >>> get_indices((x[i] + A[i, j]*y[j])*x[j])
       ({i, j}, {})

       This is correct and may appear convenient, but you need to be careful
       with this as SymPy will happily .expand() the product, if requested.  The
       resulting expression would mix the outer ``j`` with the dummies inside
       the parenthesis, which makes it a different expression.  To be on the
       safe side, it is best to avoid such ambiguities by using unique indices
       for all contractions that should be held separate.

    """
    # We call ourself recursively to determine indices of sub expressions.

    # break recursion
    if isinstance(expr, Indexed):
        inds, _ = _remove_repeated(expr.indices)
        return inds, {}
    if expr is None:
        return set(), {}
    if isinstance(expr, Idx):
        return {expr}, {}
    if expr.is_Atom:
        return set(), {}

    # recurse via specialized functions
    if expr.is_Mul:
        return _get_indices_Mul(expr)
    if expr.is_Add:
        return _get_indices_Add(expr)
    if expr.is_Pow or isinstance(expr, exp):
        return _get_indices_Pow(expr)

    if isinstance(expr, Piecewise):
        # FIXME:  No support for Piecewise yet
        return set(), {}
    if isinstance(expr, Function):
        # Support ufunc like behaviour by returning indices from arguments.
        # Functions do not interpret repeated indices across arguments
        # as summation
        ind0 = set()
        # Start from an empty symmetry dict so that a function applied to
        # no arguments does not leave ``sym`` unbound.  With arguments, the
        # symmetries of the last argument are returned, as before.
        sym = {}
        for arg in expr.args:
            ind, sym = get_indices(arg)
            ind0 |= ind
        return ind0, sym

    # this test is expensive, so it should be at the end
    if not expr.has(Indexed):
        return set(), {}
    raise NotImplementedError(
        "FIXME: No specialized handling of type %s" % type(expr))
|
||
|
|
||
|
|
||
|
def get_contraction_structure(expr):
    """Determine dummy indices of ``expr`` and describe its structure

    By *dummy* we mean indices that are summation indices.

    The structure of the expression is determined and described as follows:

    1) A conforming summation of Indexed objects is described with a dict where
       the keys are summation indices and the corresponding values are sets
       containing all terms for which the summation applies.  All Add objects
       in the SymPy expression tree are described like this.

    2) For all nodes in the SymPy expression tree that are *not* of type Add, the
       following applies:

       If a node discovers contractions in one of its arguments, the node
       itself will be stored as a key in the dict.  For that key, the
       corresponding value is a list of dicts, each of which is the result of a
       recursive call to get_contraction_structure().  The list contains only
       dicts for the non-trivial deeper contractions, omitting dicts with None
       as the one and only key.

    .. Note:: The presence of expressions among the dictionary keys indicates
       multiple levels of index contractions.  A nested dict displays nested
       contractions and may itself contain dicts from a deeper level.  In
       practical calculations the summation in the deepest nested level must be
       calculated first so that the outer expression can access the resulting
       indexed object.

    Raises NotImplementedError for an expression that contains Indexed
    objects but is of a type that has no specialized handling here.

    Examples
    ========

    >>> from sympy.tensor.index_methods import get_contraction_structure
    >>> from sympy import default_sort_key
    >>> from sympy.tensor import IndexedBase, Idx
    >>> x, y, A = map(IndexedBase, ['x', 'y', 'A'])
    >>> i, j, k, l = map(Idx, ['i', 'j', 'k', 'l'])
    >>> get_contraction_structure(x[i]*y[i] + A[j, j])
    {(i,): {x[i]*y[i]}, (j,): {A[j, j]}}
    >>> get_contraction_structure(x[i]*y[j])
    {None: {x[i]*y[j]}}

    A multiplication of contracted factors results in nested dicts representing
    the internal contractions.

    >>> d = get_contraction_structure(x[i, i]*y[j, j])
    >>> sorted(d.keys(), key=default_sort_key)
    [None, x[i, i]*y[j, j]]

    In this case, the product has no contractions:

    >>> d[None]
    {x[i, i]*y[j, j]}

    Factors are contracted "first":

    >>> sorted(d[x[i, i]*y[j, j]], key=default_sort_key)
    [{(i,): {x[i, i]}}, {(j,): {y[j, j]}}]

    A parenthesized Add object is also returned as a nested dictionary.  The
    term containing the parenthesis is a Mul with a contraction among the
    arguments, so it will be found as a key in the result.  It stores the
    dictionary resulting from a recursive call on the Add expression.

    >>> d = get_contraction_structure(x[i]*(y[i] + A[i, j]*x[j]))
    >>> sorted(d.keys(), key=default_sort_key)
    [(A[i, j]*x[j] + y[i])*x[i], (i,)]
    >>> d[(i,)]
    {(A[i, j]*x[j] + y[i])*x[i]}
    >>> d[x[i]*(A[i, j]*x[j] + y[i])]
    [{None: {y[i]}, (j,): {A[i, j]*x[j]}}]

    Powers with contractions in either base or exponent will also be found as
    keys in the dictionary, mapping to a list of results from recursive calls:

    >>> d = get_contraction_structure(A[j, j]**A[i, i])
    >>> d[None]
    {A[j, j]**A[i, i]}
    >>> nested_contractions = d[A[j, j]**A[i, i]]
    >>> nested_contractions[0]
    {(j,): {A[j, j]}}
    >>> nested_contractions[1]
    {(i,): {A[i, i]}}

    The description of the contraction structure may appear complicated when
    represented with a string in the above examples, but it is easy to iterate
    over:

    >>> from sympy import Expr
    >>> for key in d:
    ...     if isinstance(key, Expr):
    ...         continue
    ...     for term in d[key]:
    ...         if term in d:
    ...             # treat deepest contraction first
    ...             pass
    ...     # treat outermost contractions here

    """

    # We call ourself recursively to inspect sub expressions.

    if isinstance(expr, Indexed):
        junk, key = _remove_repeated(expr.indices)
        # An empty tuple of dummies is reported under the key None.
        return {key or None: {expr}}
    elif expr.is_Atom:
        return {None: {expr}}
    elif expr.is_Mul:
        junk, junk, key = _get_indices_Mul(expr, return_dummies=True)
        result = {key or None: {expr}}
        # recurse on every factor
        nested = _nontrivial_contraction_structures(expr.args)
        if nested:
            result[expr] = nested
        return result
    elif expr.is_Pow or isinstance(expr, exp):
        # recurse in base and exp separately.  If either has internal
        # contractions we must include ourselves as a key in the returned dict
        dicts = _nontrivial_contraction_structures(expr.as_base_exp())
        result = {None: {expr}}
        if dicts:
            result[expr] = dicts
        return result
    elif expr.is_Add:
        # Note: we just collect all terms with identical summation indices, We
        # do nothing to identify equivalent terms here, as this would require
        # substitutions or pattern matching in expressions of unknown
        # complexity.
        result = {}
        for term in expr.args:
            # recurse on every term
            d = get_contraction_structure(term)
            for key in d:
                if key in result:
                    result[key] |= d[key]
                else:
                    result[key] = d[key]
        return result

    elif isinstance(expr, Piecewise):
        # FIXME:  No support for Piecewise yet
        # The value is a set, like in every other branch, so that the
        # result can be merged with ``|=`` when the Piecewise is a term
        # of an Add.
        return {None: {expr}}
    elif isinstance(expr, Function):
        # Collect non-trivial contraction structures in each argument
        # We do not report repeated indices in separate arguments as a
        # contraction
        deeplist = _nontrivial_contraction_structures(expr.args)
        d = {None: {expr}}
        if deeplist:
            d[expr] = deeplist
        return d

    # this test is expensive, so it should be at the end
    elif not expr.has(Indexed):
        return {None: {expr}}
    raise NotImplementedError(
        "FIXME: No specialized handling of type %s" % type(expr))


def _nontrivial_contraction_structures(exprs):
    """Return the contraction structures of ``exprs`` that hold a contraction.

    get_contraction_structure() is called on every item, in order.  A
    structure with None as its one and only key is trivial and is left out
    of the returned list.
    """
    structures = []
    for sub in exprs:
        structure = get_contraction_structure(sub)
        if not (None in structure and len(structure) == 1):
            structures.append(structure)
    return structures
|