You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

115 lines
3.3 KiB

from sympy.matrices.expressions.matexpr import MatrixExpr
from sympy.core.basic import Basic
from sympy.core.containers import Tuple
from sympy.functions.elementary.integers import floor
def normalize(i, parentsize):
    """Convert an index specification into a ``(start, stop, step)`` triple.

    ``i`` may be a ``slice`` object, a 2- or 3-element tuple/list/``Tuple``,
    or a single scalar index (which becomes the one-element range
    ``(i, i + 1, 1)``).  A missing or falsy start becomes ``0``, a ``None``
    stop becomes ``parentsize`` and a missing or falsy step becomes ``1``.
    Bounds that compare as definitely negative are shifted by
    ``parentsize``.

    Raises ``IndexError`` when ``(stop - start) * step`` is definitely
    smaller than 1.
    """

    def wrap_negative(bound):
        # ``== True`` (rather than a plain truth test) only fires when the
        # comparison result equals True; presumably this is what lets
        # bounds whose sign cannot be decided pass through unchanged.
        if (bound < 0) == True:
            return bound + parentsize
        return bound

    if isinstance(i, slice):
        spec = [i.start, i.stop, i.step]
    elif isinstance(i, (tuple, list, Tuple)):
        spec = list(i)
    else:
        # A scalar index selects exactly one element.
        index = wrap_negative(i)
        spec = [index, index + 1, 1]

    if len(spec) == 2:
        # Only (start, stop) was given; use a unit step.
        spec.append(1)
    start, stop, step = spec

    if not start:
        start = 0
    if stop is None:
        stop = parentsize
    start = wrap_negative(start)
    stop = wrap_negative(stop)
    if not step:
        step = 1

    selects_nothing = (stop - start) * step < 1
    if selects_nothing == True:
        raise IndexError()

    return (start, stop, step)
class MatrixSlice(MatrixExpr):
    """ A MatrixSlice of a Matrix Expression

    Represents the block of ``parent`` selected by a row slice and a column
    slice.  Both slices are passed through ``normalize`` and stored as
    ``(start, stop, step)`` triples.  Slicing something that is already a
    ``MatrixSlice`` is delegated to ``mat_slice_of_slice``.

    Examples
    ========

    >>> from sympy import MatrixSlice, ImmutableMatrix
    >>> M = ImmutableMatrix(4, 4, range(16))
    >>> M
    Matrix([
    [ 0,  1,  2,  3],
    [ 4,  5,  6,  7],
    [ 8,  9, 10, 11],
    [12, 13, 14, 15]])

    >>> B = MatrixSlice(M, (0, 2), (2, 4))
    >>> ImmutableMatrix(B)
    Matrix([
    [2, 3],
    [6, 7]])
    """

    @property
    def parent(self):
        # The expression being sliced.
        return self.args[0]

    @property
    def rowslice(self):
        # (start, stop, step) applied to the parent's rows.
        return self.args[1]

    @property
    def colslice(self):
        # (start, stop, step) applied to the parent's columns.
        return self.args[2]

    def __new__(cls, parent, rowslice, colslice):
        """Build the slice, raising ``IndexError`` for out-of-range bounds."""
        parent_shape = parent.shape
        rowslice = normalize(rowslice, parent_shape[0])
        colslice = normalize(colslice, parent_shape[1])
        if len(rowslice) != 3 or len(colslice) != 3:
            raise IndexError()

        # Reject a slice only when a bound is definitely outside the
        # parent; ``== True`` compares the result to True instead of
        # truth-testing it.
        for bounds, extent in ((rowslice, parent_shape[0]),
                               (colslice, parent_shape[1])):
            if (0 > bounds[0]) == True or (extent < bounds[1]) == True:
                raise IndexError()

        if isinstance(parent, MatrixSlice):
            # Flatten slice-of-slice into one slice of the underlying parent.
            return mat_slice_of_slice(parent, rowslice, colslice)

        row_args = Tuple(*rowslice)
        col_args = Tuple(*colslice)
        return Basic.__new__(cls, parent, row_args, col_args)

    @property
    def shape(self):
        """``(rows, cols)`` of the sliced block."""

        def extent(bounds):
            span = bounds[1] - bounds[0]
            if bounds[2] == 1:
                return span
            return floor(span/bounds[2])

        return extent(self.rowslice), extent(self.colslice)

    def _entry(self, i, j, **kwargs):
        """Return the parent entry that position ``(i, j)`` maps to."""
        row_start, _, row_step = self.rowslice
        col_start, _, col_step = self.colslice
        return self.parent._entry(i*row_step + row_start,
                                  j*col_step + col_start,
                                  **kwargs)

    @property
    def on_diag(self):
        """True when the row slice and the column slice are equal."""
        rows, cols = self.rowslice, self.colslice
        return rows == cols
def slice_of_slice(s, t):
    """Compose two ``(start, stop, step)`` slices into a single slice.

    ``s`` is the outer slice and ``t`` is a slice expressed in the index
    space that ``s`` produces.  The result is the equivalent
    ``(start, stop, step)`` expressed in the original index space.

    Raises ``IndexError`` when the composed stop is definitely beyond the
    stop of ``s``.
    """
    start1, stop1, step1 = s
    start2, stop2, step2 = t

    # Index k of the inner slice is index start1 + k*step1 of the original.
    start = start1 + start2*step1
    step = step1 * step2
    stop = start1 + step1*stop2

    # Compare the result to True instead of truth-testing it, matching the
    # other bounds checks in this file: a comparison that does not evaluate
    # to a definite True (e.g. one whose result cannot be coerced with
    # bool()) is treated as "not out of range" rather than raising.
    if (stop > stop1) == True:
        raise IndexError()

    return start, stop, step
def mat_slice_of_slice(parent, rowslice, colslice):
    """ Collapse nested matrix slices

    ``parent`` is itself a slice; ``rowslice`` and ``colslice`` are
    ``(start, stop, step)`` triples relative to it.  Each is composed with
    the matching slice of ``parent`` via ``slice_of_slice``, and the result
    is one ``MatrixSlice`` taken directly from ``parent.parent``.

    >>> from sympy import MatrixSymbol
    >>> X = MatrixSymbol('X', 10, 10)
    >>> X[:, 1:5][5:8, :]
    X[5:8, 1:5]
    >>> X[1:9:2, 2:6][1:3, 2]
    X[3:7:2, 4:5]
    """
    combined = [
        slice_of_slice(parent.rowslice, rowslice),
        slice_of_slice(parent.colslice, colslice),
    ]
    return MatrixSlice(parent.parent, *combined)