You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
103 lines
3.0 KiB
103 lines
3.0 KiB
5 months ago
|
from sympy.core.relational import Eq
|
||
|
from sympy.core.expr import Expr
|
||
|
from sympy.core.numbers import Integer
|
||
|
from sympy.logic.boolalg import Boolean, And
|
||
|
from sympy.matrices.expressions.matexpr import MatrixExpr
|
||
|
from sympy.matrices.common import ShapeError
|
||
|
from typing import Union
|
||
|
|
||
|
|
||
|
def is_matadd_valid(*args: MatrixExpr) -> Boolean:
    """Return the symbolic condition under which ``MatAdd`` or
    ``HadamardProduct`` of the given matrices makes sense.

    Elementwise operations require every operand to have the same shape, so
    the condition equates the row counts of each pair of neighbouring
    arguments and likewise their column counts. With fewer than two
    arguments there is nothing to compare and an empty ``And`` (i.e. true)
    is returned.

    Parameters
    ==========

    args
        The list of arguments of matrices to be tested for.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_matadd_valid

    >>> m, n, p, q = symbols('m n p q')
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', p, q)
    >>> is_matadd_valid(A, B)
    Eq(m, p) & Eq(n, q)
    """
    shapes = [arg.shape for arg in args]
    # Pair each shape with its successor; zip stops early, so zero or one
    # argument produces no pairs instead of failing to unpack an empty zip.
    neighbours = list(zip(shapes, shapes[1:]))
    return And(
        *(Eq(left[0], right[0]) for left, right in neighbours),
        *(Eq(left[1], right[1]) for left, right in neighbours),
    )
|
||
|
|
||
|
|
||
|
def is_matmul_valid(*args: Union[MatrixExpr, Expr]) -> Boolean:
    """Return the symbolic condition under which ``MatMul`` of the given
    arguments makes sense.

    Only ``MatrixExpr`` arguments take part in the check; scalar factors
    are skipped. For each pair of consecutive matrices, the column count of
    the left one must equal the row count of the right one. When fewer than
    two matrices are present (including the all-scalar case), an empty
    ``And`` (i.e. true) is returned.

    Parameters
    ==========

    args
        The list of arguments of matrices and scalar expressions to be tested
        for.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_matmul_valid

    >>> m, n, p, q = symbols('m n p q')
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', p, q)
    >>> is_matmul_valid(A, B)
    Eq(n, p)
    """
    shapes = [arg.shape for arg in args if isinstance(arg, MatrixExpr)]
    # Consecutive (left, right) matrix shapes; an empty list of shapes (only
    # scalars were passed) yields no pairs rather than an unpacking error.
    return And(*(Eq(left[1], right[0]) for left, right in zip(shapes, shapes[1:])))
|
||
|
|
||
|
|
||
|
def is_square(arg: MatrixExpr, /) -> Boolean:
    """Return the symbolic condition under which ``arg`` is a square matrix,
    namely that its row count equals its column count.

    Parameters
    ==========

    arg
        The matrix expression whose dimensions are compared.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_square

    >>> m, n = symbols('m n')
    >>> A = MatrixSymbol('A', m, n)
    >>> is_square(A)
    Eq(m, n)
    """
    n_rows = arg.rows
    n_cols = arg.cols
    return Eq(n_rows, n_cols)
|
||
|
|
||
|
|
||
|
def validate_matadd_integer(*args: MatrixExpr) -> None:
    """Validate matrix shape for addition only for integer values.

    Among the row counts of all arguments, those that are concrete integers
    (``int`` or SymPy ``Integer``) must agree; the same holds for the column
    counts. Symbolic dimensions are ignored here.

    Raises
    ======

    ShapeError
        If two concrete row counts, or two concrete column counts, differ.
        The message shows the full tuple of row (or column) dimensions.
    """
    if not args:
        # Nothing to compare; avoids failing to unpack an empty zip below.
        return
    rows, cols = zip(*(x.shape for x in args))
    # The same check applies to both axes, so run it once per dimension tuple.
    for dims in (rows, cols):
        concrete = {d for d in dims if isinstance(d, (int, Integer))}
        if len(concrete) > 1:
            raise ShapeError(f"Matrices have mismatching shape: {dims}")
|
||
|
|
||
|
|
||
|
def validate_matmul_integer(*args: MatrixExpr) -> None:
    """Check inner dimensions of a matrix product, but only where both are
    concrete integers.

    For every pair of neighbouring matrices, the left one's column count is
    compared with the right one's row count. The comparison is made only when
    both are ``int`` or SymPy ``Integer``; symbolic sizes are skipped.

    Raises
    ======

    ShapeError
        If two adjacent integer inner dimensions differ.
    """
    integer_types = (int, Integer)
    for left, right in zip(args, args[1:]):
        inner_left, inner_right = left.cols, right.rows
        both_concrete = (isinstance(inner_left, integer_types)
                         and isinstance(inner_right, integer_types))
        if both_concrete and inner_left != inner_right:
            raise ShapeError("Matrices are not aligned", inner_left, inner_right)
|