You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

78 lines
2.0 KiB

5 months ago
"""
We have a few different kind of Matrices
Matrix, ImmutableMatrix, MatrixExpr
Here we test the extent to which they cooperate
"""
from sympy.core.symbol import symbols
from sympy.matrices import (Matrix, MatrixSymbol, eye, Identity,
ImmutableMatrix)
from sympy.matrices.expressions import MatrixExpr, MatAdd
from sympy.matrices.common import classof
from sympy.testing.pytest import raises
# Shared operands for the tests below.
# Symbolic operands: a 3x3 matrix symbol and a 3x1 vector symbol.
SM, SV = MatrixSymbol('X', 3, 3), MatrixSymbol('v', 3, 1)
# The same explicit 3x3 entries, once mutable and once immutable.
MM = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
IM = ImmutableMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# Three representations of the 3x3 identity: mutable, immutable, symbolic.
meye, imeye, ideye = eye(3), ImmutableMatrix(eye(3)), Identity(3)
# Scalar symbols.
a, b, c = symbols('a,b,c')
def test_IM_MM():
    """Mixing Matrix and ImmutableMatrix operands yields an ImmutableMatrix."""
    for mixed in (MM + IM, IM + MM, 2*IM + MM):
        assert isinstance(mixed, ImmutableMatrix)
    # Same entries, different mutability: still equal.
    assert MM.equals(IM)
def test_ME_MM():
    """Adding an explicit Matrix to a matrix expression gives a matrix expression."""
    id_plus_mm = Identity(3) + MM
    assert isinstance(id_plus_mm, MatrixExpr)
    # MatrixSymbol + Matrix is a MatAdd regardless of operand order.
    assert isinstance(SM + MM, MatAdd)
    assert isinstance(MM + SM, MatAdd)
    # Entry (1, 1): 1 from the identity plus 5 from MM.
    assert id_plus_mm[1, 1] == 6
def test_equality():
    """Every pair of identity representations compares equal via ``equals``.

    Checks the symbolic ``Identity(3)``, the mutable ``eye(3)`` and the
    immutable ``ImmutableMatrix(eye(3))`` against each other (including
    themselves). Uses the module-level fixtures rather than rebinding
    ``a, b, c`` locally, which would shadow the module-level scalar symbols.
    """
    variants = (ideye, meye, imeye)
    for x in variants:
        for y in variants:
            assert x.equals(y)
def test_matrix_symbol_MM():
    """Indexing eye(3) + X on the diagonal gives 1 + X[i, i]."""
    sym = MatrixSymbol('X', 3, 3)
    summed = eye(3) + sym
    assert summed[1, 1] == 1 + sym[1, 1]
def test_matrix_symbol_vector_matrix_multiplication():
    """Matrix-times-vector-symbol agrees across mutability and transposition.

    Compares MM*v, IM*v, (v.T*MM.T).T and (v.T*IM.T).T, each against the
    previous one.
    """
    products = [
        MM * SV,
        IM * SV,
        (SV.T * MM.T).T,
        (SV.T * IM.T).T,
    ]
    for prev, cur in zip(products, products[1:]):
        assert prev == cur
def test_indexing_interactions():
    """Indexing scaled, summed and multiplied mixed operands gives scalar expressions."""
    assert (a * IM)[1, 1] == 5*a
    assert (SM + IM)[1, 1] == SM[1, 1] + IM[1, 1]
    # Entry (1, 1) of a product: row 1 of SM dotted with column 1 of IM.
    expected = sum(SM[1, k]*IM[k, 1] for k in range(3))
    assert (SM * IM)[1, 1] == expected
def test_classof():
    """classof prefers ImmutableMatrix whenever either operand is immutable."""
    mutable = Matrix(3, 3, range(9))
    immutable = ImmutableMatrix(3, 3, range(9))
    symbol = MatrixSymbol('C', 3, 3)
    assert classof(mutable, mutable) == Matrix
    for lhs, rhs in ((immutable, immutable), (mutable, immutable),
                     (immutable, mutable)):
        assert classof(lhs, rhs) == ImmutableMatrix
    # Pairing an explicit Matrix with a MatrixSymbol raises TypeError.
    raises(TypeError, lambda: classof(mutable, symbol))