|
|
|
@ -21,6 +21,16 @@ class ExpandDims(Expander):
|
|
|
|
|
"""ExpandDims expander"""
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
对输入数据进行维度扩展。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 扩展后的数据。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
input_x = self.inputs[0]
|
|
|
|
|
shape = self.infer_shape(input_x.shape, self.attrs['axis'])
|
|
|
|
|
result = graph_builder.emit('Reshape', [input_x], attrs={'shape': shape})
|
|
|
|
@ -29,8 +39,35 @@ class ExpandDims(Expander):
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def infer_shape(shape, axis):
|
|
|
|
|
"""
|
|
|
|
|
根据给定的轴位置推断新的形状。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
shape (list or tuple): 原始形状,表示一个多维数组的尺寸。
|
|
|
|
|
axis (int, list or tuple): 指定要插入新维度的轴位置。如果为整数,表示在指定位置插入一个维度;如果为列表或元组,则按顺序在指定位置插入多个维度。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
list: 插入新维度后的新形状。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: 如果axis的值或类型不符合要求时抛出。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""infer shape for expand_dims"""
|
|
|
|
|
def insert_axis(shape, axis):
|
|
|
|
|
"""
|
|
|
|
|
在指定轴上插入一个新的维度。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
shape (list): 原始数组的形状,类型为列表。
|
|
|
|
|
axis (int): 要插入新维度的轴的位置。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
list: 插入新维度后的数组形状。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: 如果axis的类型不是int,或者axis的值不在合法范围内,将抛出异常。
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(axis, int) or axis > len(shape) or axis < -len(shape) - 1:
|
|
|
|
|
raise ValueError("For 'ExpandDims', value of attr 'axis' should be of type int and in the range [{}, "
|
|
|
|
|
"{}], but got {} with type {}".format(-len(shape) - 1, len(shape), axis, type(axis)))
|
|
|
|
|