You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
115 lines
3.3 KiB
115 lines
3.3 KiB
5 months ago
|
from sympy.matrices.expressions.matexpr import MatrixExpr
|
||
|
from sympy.core.basic import Basic
|
||
|
from sympy.core.containers import Tuple
|
||
|
from sympy.functions.elementary.integers import floor
|
||
|
|
||
|
def normalize(i, parentsize):
    """ Turn an index specification into a ``(start, stop, step)`` triple

    ``i`` may be a ``slice``, a ``tuple``/``list``/``Tuple`` holding two or
    three entries, or any other object, which is then treated as a single
    index selecting the range ``i`` to ``i + 1`` with step 1.

    ``parentsize`` is the extent of the parent along the axis being indexed.
    It is added to indices that compare as negative and is used as the stop
    when none is given.

    Missing values are filled in as follows: a falsy start becomes 0, a stop
    of ``None`` becomes ``parentsize``, a falsy or absent step becomes 1.

    Raises ``IndexError`` when ``(stop - start) * step`` is known to be less
    than 1.
    """
    # Every comparison below is written ``(cond) == True`` rather than being
    # truth-tested directly, so a comparison whose result is not plainly True
    # (for instance one that stays unevaluated) simply does not trigger the
    # branch.
    if isinstance(i, slice):
        bounds = [i.start, i.stop, i.step]
    elif isinstance(i, (tuple, list, Tuple)):
        bounds = list(i)
    else:
        # A single index: wrap a negative one around, then select one element.
        index = i
        if (index < 0) == True:
            index += parentsize
        bounds = [index, index + 1, 1]

    if len(bounds) == 2:
        # Only (start, stop) supplied; the step defaults to 1.
        bounds.append(1)

    start, stop, step = bounds

    start = start or 0
    if stop is None:
        stop = parentsize

    if (start < 0) == True:
        start += parentsize
    if (stop < 0) == True:
        stop += parentsize

    step = step or 1

    # Reject a range that is known to select nothing.
    if ((stop - start) * step < 1) == True:
        raise IndexError()

    return (start, stop, step)
|
||
|
|
||
|
class MatrixSlice(MatrixExpr):
    """ A MatrixSlice of a Matrix Expression

    The stored arguments are the parent expression followed by the row slice
    and the column slice, each kept as a ``(start, stop, step)`` Tuple as
    produced by ``normalize``.

    Examples
    ========

    >>> from sympy import MatrixSlice, ImmutableMatrix
    >>> M = ImmutableMatrix(4, 4, range(16))
    >>> M
    Matrix([
    [ 0,  1,  2,  3],
    [ 4,  5,  6,  7],
    [ 8,  9, 10, 11],
    [12, 13, 14, 15]])

    >>> B = MatrixSlice(M, (0, 2), (2, 4))
    >>> ImmutableMatrix(B)
    Matrix([
    [2, 3],
    [6, 7]])
    """

    @property
    def parent(self):
        # First argument: the expression being sliced.
        return self.args[0]

    @property
    def rowslice(self):
        # Second argument: (start, stop, step) along the rows.
        return self.args[1]

    @property
    def colslice(self):
        # Third argument: (start, stop, step) along the columns.
        return self.args[2]

    def __new__(cls, parent, rowslice, colslice):
        rowslice = normalize(rowslice, parent.shape[0])
        colslice = normalize(colslice, parent.shape[1])
        if len(rowslice) != 3 or len(colslice) != 3:
            raise IndexError()

        # Per axis: a start known to be negative, or a stop known to exceed
        # the parent's extent, is out of range.  ``== True`` keeps a
        # comparison that is not plainly True from triggering the error.
        axes = ((rowslice, parent.shape[0]), (colslice, parent.shape[1]))
        for axis_slice, extent in axes:
            if ((0 > axis_slice[0]) == True or
                    (extent < axis_slice[1]) == True):
                raise IndexError()

        # A slice of a slice is collapsed into one slice of the inner parent.
        if isinstance(parent, MatrixSlice):
            return mat_slice_of_slice(parent, rowslice, colslice)

        return Basic.__new__(cls, parent, Tuple(*rowslice), Tuple(*colslice))

    @property
    def shape(self):
        # NOTE(review): when (stop - start) is not a multiple of the step,
        # floor() gives one fewer than the length Python slicing would give
        # (e.g. 0:5:2 -> 2 here, 3 for range(0, 5, 2)) -- confirm intended.
        row_step = self.rowslice[2]
        n_rows = self.rowslice[1] - self.rowslice[0]
        if row_step == 1:
            pass
        else:
            n_rows = floor(n_rows/row_step)

        col_step = self.colslice[2]
        n_cols = self.colslice[1] - self.colslice[0]
        if col_step == 1:
            pass
        else:
            n_cols = floor(n_cols/col_step)

        return n_rows, n_cols

    def _entry(self, i, j, **kwargs):
        # Map slice-local indices onto the parent: scale by step, shift by
        # start.
        parent_row = i*self.rowslice[2] + self.rowslice[0]
        parent_col = j*self.colslice[2] + self.colslice[0]
        return self.parent._entry(parent_row, parent_col, **kwargs)

    @property
    def on_diag(self):
        # True when the row slice and the column slice are equal.
        return self.rowslice == self.colslice
|
||
|
|
||
|
|
||
|
def slice_of_slice(s, t):
    """ Compose two ``(start, stop, step)`` triples into a single one

    ``s`` is the outer slice, expressed in the coordinates of the original
    object.  ``t`` is a slice taken from the result of ``s``, so its indices
    count elements of ``s``.  The returned triple addresses the original
    object directly:

    * start = ``s.start + t.start * s.step``
    * stop  = ``s.start + s.step * t.stop``
    * step  = ``s.step * t.step``

    Raises ``IndexError`` when the composed stop is known to exceed the stop
    of ``s``.

    >>> slice_of_slice((1, 9, 2), (1, 3, 1))
    (3, 7, 2)
    """
    start1, stop1, step1 = s
    start2, stop2, step2 = t

    start = start1 + start2*step1
    step = step1 * step2
    stop = start1 + step1*stop2

    # Compare against True rather than truth-testing the comparison, as the
    # other bounds checks in this file do: with a symbolic extent the
    # comparison may be undecidable, and truth-testing it would raise
    # TypeError instead of letting the slice through.
    if (stop > stop1) == True:
        raise IndexError("slice stop exceeds the stop of the enclosing slice")

    return start, stop, step
|
||
|
|
||
|
|
||
|
def mat_slice_of_slice(parent, rowslice, colslice):
    """ Flatten a slice taken from a MatrixSlice into one MatrixSlice

    ``parent`` is an existing MatrixSlice; ``rowslice`` and ``colslice`` are
    ``(start, stop, step)`` triples relative to it.  Each is composed with
    the corresponding slice of ``parent`` via ``slice_of_slice`` and the
    result is applied directly to ``parent.parent``.

    >>> from sympy import MatrixSymbol
    >>> X = MatrixSymbol('X', 10, 10)
    >>> X[:, 1:5][5:8, :]
    X[5:8, 1:5]
    >>> X[1:9:2, 2:6][1:3, 2]
    X[3:7:2, 4:5]
    """
    merged = (
        slice_of_slice(parent.rowslice, rowslice),
        slice_of_slice(parent.colslice, colslice),
    )
    return MatrixSlice(parent.parent, *merged)
|