_extends\graph_kernel\expanders\expand_dims.py

branch-yixin
yixin 2 months ago
parent 979f67f6fa
commit 60f231dc1c

@ -21,6 +21,16 @@ class ExpandDims(Expander):
"""ExpandDims expander"""
def _expand(self, graph_builder):
"""
对输入数据进行维度扩展
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展后的数据
"""
input_x = self.inputs[0]
shape = self.infer_shape(input_x.shape, self.attrs['axis'])
result = graph_builder.emit('Reshape', [input_x], attrs={'shape': shape})
@ -29,8 +39,35 @@ class ExpandDims(Expander):
@staticmethod
def infer_shape(shape, axis):
"""
根据给定的轴位置推断新的形状
Args:
shape (list or tuple): 原始形状表示一个多维数组的尺寸
axis (int, list or tuple): 指定要插入新维度的轴位置如果为整数表示在指定位置插入一个维度如果为列表或元组则按顺序在指定位置插入多个维度
Returns:
list: 插入新维度后的新形状
Raises:
ValueError: 如果axis的值或类型不符合要求时抛出
"""
"""infer shape for expand_dims"""
def insert_axis(shape, axis):
"""
在指定轴上插入一个新的维度
Args:
shape (list): 原始数组的形状类型为列表
axis (int): 要插入新维度的轴的位置
Returns:
list: 插入新维度后的数组形状
Raises:
ValueError: 如果axis的类型不是int或者axis的值不在合法范围内将抛出异常
"""
if not isinstance(axis, int) or axis > len(shape) or axis < -len(shape) - 1:
raise ValueError("For 'ExpandDims', value of attr 'axis' should be of type int and in the range [{}, "
"{}], but got {} with type {}".format(-len(shape) - 1, len(shape), axis, type(axis)))

Loading…
Cancel
Save