You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
304 lines
7.9 KiB
304 lines
7.9 KiB
5 months ago
|
from sympy.core import S
|
||
|
from sympy.core.sympify import _sympify
|
||
|
from sympy.functions import KroneckerDelta
|
||
|
|
||
|
from .matexpr import MatrixExpr
|
||
|
from .special import ZeroMatrix, Identity, OneMatrix
|
||
|
|
||
|
|
||
|
class PermutationMatrix(MatrixExpr):
    """Matrix expression for the permutation matrix of a permutation.

    Row ``i`` of the matrix holds a single one, in column ``perm(i)``;
    every other entry is zero.

    Parameters
    ==========

    perm : Permutation
        The permutation that defines the matrix. The matrix is square
        and its size equals the size of the permutation.

        See :class:`sympy.combinatorics.permutations.Permutation` for
        details on how to construct a permutation object.

    Examples
    ========

    >>> from sympy import Matrix, PermutationMatrix
    >>> from sympy.combinatorics import Permutation

    Creating a permutation matrix:

    >>> p = Permutation(1, 2, 0)
    >>> P = PermutationMatrix(p)
    >>> P = P.as_explicit()
    >>> P
    Matrix([
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0]])

    Permuting a matrix row and column:

    >>> M = Matrix([0, 1, 2])
    >>> Matrix(P*M)
    Matrix([
    [1],
    [2],
    [0]])

    >>> Matrix(M.T*P)
    Matrix([[2, 0, 1]])

    See Also
    ========

    sympy.combinatorics.permutations.Permutation
    """

    def __new__(cls, perm):
        from sympy.combinatorics.permutations import Permutation

        perm = _sympify(perm)
        if isinstance(perm, Permutation):
            return super().__new__(cls, perm)
        raise ValueError(
            "{} must be a SymPy Permutation instance.".format(perm))

    @property
    def shape(self):
        n = self.args[0].size
        return n, n

    @property
    def is_Identity(self):
        # The matrix is the identity exactly when the permutation is.
        return self.args[0].is_Identity

    def doit(self, **hints):
        """Collapse to ``Identity`` for the identity permutation."""
        return Identity(self.rows) if self.is_Identity else self

    def _entry(self, i, j, **kwargs):
        # Entry (i, j) is 1 precisely when j is the image of i.
        image = self.args[0].apply(i)
        return KroneckerDelta(image, j)

    def _eval_power(self, exp):
        powered = self.args[0] ** exp
        return PermutationMatrix(powered).doit()

    def _eval_inverse(self):
        return PermutationMatrix(self.args[0] ** -1)

    # A permutation matrix is real and orthogonal, so its transpose and
    # its adjoint both coincide with its inverse.
    _eval_transpose = _eval_adjoint = _eval_inverse

    def _eval_determinant(self):
        """Return the sign of the underlying permutation as a SymPy number."""
        sign = self.args[0].signature()
        if sign == 1:
            return S.One
        if sign == -1:
            return S.NegativeOne
        raise NotImplementedError

    def _eval_rewrite_as_BlockDiagMatrix(self, *args, **kwargs):
        """Split the matrix into independent diagonal blocks.

        Cycles are visited in the order given by ``full_cyclic_form``.
        A block is closed as soon as the cycles gathered so far cover a
        contiguous run of indices ``start, ..., start + count - 1``.
        That is, the block closes when the largest index seen equals
        ``start + count - 1``. The cycles of each finished block are
        shifted down by ``start``, so that they form a permutation of
        ``0, ..., count - 1``. Each block becomes one
        ``PermutationMatrix`` in the result.
        """
        from sympy.combinatorics.permutations import Permutation
        from .blockmatrix import BlockDiagMatrix

        blocks = []
        start = 0      # first index not yet assigned to a finished block
        pending = []   # cycles gathered for the block under construction
        count = 0      # number of indices covered by ``pending``
        largest = -1   # largest index appearing in ``pending``

        # NOTE(review): assumes full_cyclic_form lists every index
        # (singletons included) with cycles ordered by their smallest
        # element -- confirm against Permutation.full_cyclic_form.
        for cycle in self.args[0].full_cyclic_form:
            pending.append(cycle)
            count += len(cycle)
            largest = max(largest, max(cycle))
            if largest + 1 != start + count:
                # The gathered cycles still leave a gap; keep collecting.
                continue
            shifted = [[k - start for k in cyc] for cyc in pending]
            blocks.append(PermutationMatrix(Permutation(shifted)))
            start += count
            pending, count, largest = [], 0, -1

        return BlockDiagMatrix(*blocks)
|
||
|
|
||
|
|
||
|
class MatrixPermute(MatrixExpr):
    r"""Symbolic representation for permuting matrix rows or columns.

    Parameters
    ==========

    mat : Matrix
        The matrix to permute. It must be a SymPy matrix (``is_Matrix``),
        otherwise a ``ValueError`` is raised.

    perm : Permutation, PermutationMatrix
        The permutation to use for permuting the matrix.
        Its size may differ from the size of the matrix along ``axis``.
        In that case it is resized to match, and a ``ValueError`` is
        raised when the resize is not possible.

    axis : 0 or 1
        The axis to permute along.
        If `0`, it will permute the matrix rows.
        If `1`, it will permute the matrix columns.

    Notes
    =====

    This follows the same notation used in
    :meth:`sympy.matrices.common.MatrixCommon.permute`.

    Examples
    ========

    >>> from sympy import Matrix, MatrixPermute
    >>> from sympy.combinatorics import Permutation

    Permuting the matrix rows:

    >>> p = Permutation(1, 2, 0)
    >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    >>> B = MatrixPermute(A, p, axis=0)
    >>> B.as_explicit()
    Matrix([
    [4, 5, 6],
    [7, 8, 9],
    [1, 2, 3]])

    Permuting the matrix columns:

    >>> B = MatrixPermute(A, p, axis=1)
    >>> B.as_explicit()
    Matrix([
    [2, 3, 1],
    [5, 6, 4],
    [8, 9, 7]])

    See Also
    ========

    sympy.matrices.common.MatrixCommon.permute
    """
    def __new__(cls, mat, perm, axis=S.Zero):
        from sympy.combinatorics.permutations import Permutation

        mat = _sympify(mat)
        if not mat.is_Matrix:
            # Report the argument that actually failed validation.
            raise ValueError(
                "{} must be a SymPy matrix instance.".format(mat))

        perm = _sympify(perm)
        # Accept a PermutationMatrix by unwrapping its Permutation.
        if isinstance(perm, PermutationMatrix):
            perm = perm.args[0]

        if not isinstance(perm, Permutation):
            raise ValueError(
                "{} must be a SymPy Permutation or a PermutationMatrix "
                "instance".format(perm))

        axis = _sympify(axis)
        if axis not in (0, 1):
            raise ValueError("The axis must be 0 or 1.")

        # The permutation must act on exactly as many indices as the
        # matrix has along the chosen axis; try to resize it otherwise.
        mat_size = mat.shape[axis]
        if mat_size != perm.size:
            try:
                perm = perm.resize(mat_size)
            except ValueError as exc:
                raise ValueError(
                    "Size does not match between the permutation {} "
                    "and the matrix {} threaded over the axis {} "
                    "and cannot be converted."
                    .format(perm, mat, axis)) from exc

        return super().__new__(cls, mat, perm, axis)

    def doit(self, deep=True, **hints):
        """Simplify the permutation expression where possible.

        - An identity permutation leaves ``mat`` unchanged.
        - Permuting an identity matrix gives a ``PermutationMatrix``.
          Rows use ``perm``; columns use ``perm**-1``.
        - ``ZeroMatrix`` and ``OneMatrix`` are invariant under permutation.
        - Nested permutations along the same axis merge into one,
          using ``perm * inner_perm``.
        """
        mat, perm, axis = self.args

        if deep:
            mat = mat.doit(deep=deep, **hints)
            perm = perm.doit(deep=deep, **hints)

        if perm.is_Identity:
            return mat

        if mat.is_Identity:
            if axis is S.Zero:
                return PermutationMatrix(perm)
            elif axis is S.One:
                return PermutationMatrix(perm**-1)

        if isinstance(mat, (ZeroMatrix, OneMatrix)):
            return mat

        if isinstance(mat, MatrixPermute) and mat.args[2] == axis:
            return MatrixPermute(mat.args[0], perm * mat.args[1], axis)

        return self

    @property
    def shape(self):
        # Permuting rows or columns never changes the shape.
        return self.args[0].shape

    def _entry(self, i, j, **kwargs):
        mat, perm, axis = self.args

        if axis == 0:
            return mat[perm.apply(i), j]
        elif axis == 1:
            return mat[i, perm.apply(j)]

    def _eval_rewrite_as_MatMul(self, *args, **kwargs):
        """Rewrite as a product with a ``PermutationMatrix``.

        Row permutation (axis 0) multiplies from the left by
        ``PermutationMatrix(perm)``. Column permutation (axis 1)
        multiplies from the right by ``PermutationMatrix(perm**-1)``.
        With ``deep=True`` (the default), ``mat`` is rewritten first.
        """
        from .matmul import MatMul

        mat, perm, axis = self.args

        deep = kwargs.get("deep", True)

        if deep:
            mat = mat.rewrite(MatMul)

        if axis == 0:
            return MatMul(PermutationMatrix(perm), mat)
        elif axis == 1:
            return MatMul(mat, PermutationMatrix(perm**-1))
|