You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

156 lines
4.7 KiB

from functools import reduce
import operator
from sympy.core import Basic, sympify
from sympy.core.add import add, Add, _could_extract_minus_sign
from sympy.core.sorting import default_sort_key
from sympy.functions import adjoint
from sympy.matrices.matrices import MatrixBase
from sympy.matrices.expressions.transpose import transpose
from sympy.strategies import (rm_id, unpack, flatten, sort, condition,
exhaust, do_one, glom)
from sympy.matrices.expressions.matexpr import MatrixExpr
from sympy.matrices.expressions.special import ZeroMatrix, GenericZeroMatrix
from sympy.matrices.expressions._shape import validate_matadd_integer as validate
from sympy.utilities.iterables import sift
from sympy.utilities.exceptions import sympy_deprecation_warning
# XXX: MatAdd should perhaps not subclass directly from Add
class MatAdd(MatrixExpr, Add):
    """A Sum of Matrix Expressions

    MatAdd inherits from and operates like SymPy Add

    Examples
    ========

    >>> from sympy import MatAdd, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 5)
    >>> B = MatrixSymbol('B', 5, 5)
    >>> C = MatrixSymbol('C', 5, 5)
    >>> MatAdd(A, B, C)
    A + B + C
    """
    is_MatAdd = True

    # Additive identity: returned for an empty sum and stripped from the
    # argument list in ``__new__``.
    identity = GenericZeroMatrix()

    def __new__(cls, *args, evaluate=False, check=None, _sympify=True):
        """Build a matrix sum from ``args``.

        Parameters
        ==========

        args : matrix expressions
            The terms of the sum. Terms equal to ``cls.identity`` are
            dropped. If no term is left, ``cls.identity`` is returned.
        evaluate : bool, optional
            If true, the new object is passed through ``cls._evaluate``
            before it is returned.
        check : bool or None, optional
            Deprecated. Any value other than ``None`` emits a deprecation
            warning; ``False`` additionally skips the shape validation.
        _sympify : bool, optional
            If true, every remaining argument is passed through ``sympify``.

        Raises
        ======

        TypeError
            If a remaining argument is not a ``MatrixExpr``.
        """
        if not args:
            return cls.identity

        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericZeroMatrix().shape
        args = [arg for arg in args if cls.identity != arg]
        if not args:
            # Every term was the identity.  Without this guard an
            # argument-less MatAdd would be created and ``validate`` would
            # be called with no matrices at all.
            return cls.identity

        if _sympify:
            args = [sympify(arg) for arg in args]

        if not all(isinstance(arg, MatrixExpr) for arg in args):
            raise TypeError("Mix of Matrix and Scalar symbols")

        obj = Basic.__new__(cls, *args)

        if check is not None:
            sympy_deprecation_warning(
                "Passing check to MatAdd is deprecated and the check argument will be removed in a future version.",
                deprecated_since_version="1.11",
                active_deprecations_target='remove-check-argument-from-matrix-operations')

        # Shapes are validated unless the caller explicitly passed
        # ``check=False``; the default ``None`` therefore still validates.
        if check is not False:
            validate(*args)

        if evaluate:
            obj = cls._evaluate(obj)

        return obj

    @classmethod
    def _evaluate(cls, expr):
        """Return ``expr`` after applying the module-level ``canonicalize``."""
        return canonicalize(expr)

    @property
    def shape(self):
        """Shape of the sum, read from the first argument."""
        return self.args[0].shape

    def could_extract_minus_sign(self):
        """Delegate to ``_could_extract_minus_sign`` from ``sympy.core.add``."""
        return _could_extract_minus_sign(self)

    def expand(self, **kwargs):
        """Expand via the parent classes, then re-run ``_evaluate`` on the result."""
        expanded = super().expand(**kwargs)
        return self._evaluate(expanded)

    def _entry(self, i, j, **kwargs):
        """Return the ``(i, j)`` entry as the ``Add`` of every argument's entry."""
        return Add(*[arg._entry(i, j, **kwargs) for arg in self.args])

    def _eval_transpose(self):
        """Transpose each argument, rebuild the sum and evaluate it with ``doit``."""
        return MatAdd(*[transpose(arg) for arg in self.args]).doit()

    def _eval_adjoint(self):
        """Take the adjoint of each argument, rebuild the sum and ``doit`` it."""
        return MatAdd(*[adjoint(arg) for arg in self.args]).doit()

    def _eval_trace(self):
        """Return the evaluated ``Add`` of the traces of all arguments."""
        # Imported here rather than at module level, as in the original
        # (presumably to avoid an import cycle with ``.trace`` -- confirm).
        from .trace import trace
        return Add(*[trace(arg) for arg in self.args]).doit()

    def doit(self, **hints):
        """Evaluate the sum and return its canonical form.

        With the hint ``deep`` true (the default) ``doit(**hints)`` is first
        called on every argument; otherwise the arguments are used as they
        are.  The rebuilt ``MatAdd`` is then passed through ``canonicalize``.
        """
        deep = hints.get('deep', True)
        if deep:
            args = [arg.doit(**hints) for arg in self.args]
        else:
            args = self.args
        return canonicalize(MatAdd(*args))

    def _eval_derivative_matrix_lines(self, x):
        """Concatenate the derivative matrix lines of all arguments into one list."""
        per_arg_lines = [arg._eval_derivative_matrix_lines(x) for arg in self.args]
        return [line for lines in per_arg_lines for line in lines]
# Register MatAdd on the ``add`` dispatcher (imported from sympy.core.add)
# as the handler class for the (Add, MatAdd) pair of classes.
# NOTE(review): presumably this makes dispatched sums that mix Add and
# MatAdd terms come out as MatAdd -- confirm against sympy.core.add.
add.register_handlerclass((Add, MatAdd), MatAdd)
def factor_of(arg):
    """Return the first item of ``arg.as_coeff_mmul()``.

    Going by the method name this is the coefficient of the term.
    """
    parts = arg.as_coeff_mmul()
    return parts[0]


def matrix_of(arg):
    """Return ``unpack`` applied to the second item of ``arg.as_coeff_mmul()``."""
    parts = arg.as_coeff_mmul()
    return unpack(parts[1])
def combine(cnt, mat):
    """Return ``mat`` unchanged when ``cnt == 1``, otherwise ``cnt * mat``.

    Passed to ``glom`` in the canonicalization rules of this module.
    """
    return mat if cnt == 1 else cnt * mat
def merge_explicit(matadd):
    """ Merge explicit MatrixBase arguments

    The arguments of ``matadd`` that are ``MatrixBase`` instances are added
    together into one matrix, which is placed after the remaining arguments
    in a new ``MatAdd``.  If fewer than two such arguments are present,
    ``matadd`` itself is returned.

    Examples
    ========

    >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint
    >>> from sympy.matrices.expressions.matadd import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = eye(2)
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatAdd(A, B, C)
    >>> pprint(X)
        [1  0]   [1  2]
    A + [    ] + [    ]
        [0  1]   [3  4]
    >>> pprint(merge_explicit(X))
        [2  2]
    A + [    ]
        [3  5]
    """
    by_kind = sift(matadd.args, lambda term: isinstance(term, MatrixBase))
    explicit = by_kind[True]

    # Nothing to merge unless at least two explicit matrices are present.
    if len(explicit) <= 1:
        return matadd

    merged = reduce(operator.add, explicit)
    return MatAdd(*(by_kind[False] + [merged]))
# Individual rewrite steps used by ``canonicalize``.  The names describe the
# intent of each step; see ``sympy.strategies`` for the exact semantics of
# the combinators (rm_id, glom, sort, exhaust, condition, do_one).
_drop_zero_terms = rm_id(lambda term: term == 0 or isinstance(term, ZeroMatrix))
_collect_like_terms = glom(matrix_of, factor_of, combine)
_order_terms = sort(default_sort_key)

# The order of the steps is significant and is the same as before.
rules = (
    _drop_zero_terms,
    unpack,
    flatten,
    _collect_like_terms,
    merge_explicit,
    _order_terms,
)


def _is_matadd(expr):
    """Tell whether ``expr`` is a ``MatAdd``; guards the rule application."""
    return isinstance(expr, MatAdd)


# Apply the rules, guarded by ``_is_matadd``, via ``exhaust``.
canonicalize = exhaust(condition(_is_matadd, do_one(*rules)))