diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/expand_dims.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/expand_dims.py index 7403f119..ff2b46e5 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/expand_dims.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/expand_dims.py @@ -21,6 +21,16 @@ class ExpandDims(Expander): """ExpandDims expander""" def _expand(self, graph_builder): + """ + 对输入数据进行维度扩展。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展后的数据。 + + """ input_x = self.inputs[0] shape = self.infer_shape(input_x.shape, self.attrs['axis']) result = graph_builder.emit('Reshape', [input_x], attrs={'shape': shape}) @@ -29,8 +39,35 @@ class ExpandDims(Expander): @staticmethod def infer_shape(shape, axis): + """ + 根据给定的轴位置推断新的形状。 + + Args: + shape (list or tuple): 原始形状,表示一个多维数组的尺寸。 + axis (int, list or tuple): 指定要插入新维度的轴位置。如果为整数,表示在指定位置插入一个维度;如果为列表或元组,则按顺序在指定位置插入多个维度。 + + Returns: + list: 插入新维度后的新形状。 + + Raises: + ValueError: 如果axis的值或类型不符合要求时抛出。 + + """ """infer shape for expand_dims""" def insert_axis(shape, axis): + """ + 在指定轴上插入一个新的维度。 + + Args: + shape (list): 原始数组的形状,类型为列表。 + axis (int): 要插入新维度的轴的位置。 + + Returns: + list: 插入新维度后的数组形状。 + + Raises: + ValueError: 如果axis的类型不是int,或者axis的值不在合法范围内,将抛出异常。 + """ if not isinstance(axis, int) or axis > len(shape) or axis < -len(shape) - 1: raise ValueError("For 'ExpandDims', value of attr 'axis' should be of type int and in the range [{}, " "{}], but got {} with type {}".format(-len(shape) - 1, len(shape), axis, type(axis)))