diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py index 161f33c0..8e8394dd 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py @@ -13,37 +13,64 @@ # limitations under the License. # =========================================================================== """generate json desc for bias_add""" +# 导入MindSpore的DataFormat类 from mindspore._extends.graph_kernel.model.model import DataFormat as DF +# 导入Expander和ExpanderInfoValidator类 from ._utils import Expander, ExpanderInfoValidator as VLD +# 为BiasAddGrad类添加DF.DEFAULT、DF.NHWC、DF.NCHW、DF.FRAC_NZ格式的验证 @VLD.add_format(DF.DEFAULT) @VLD.add_format(DF.NHWC) @VLD.add_format(DF.NCHW) @VLD.add_format(DF.FRAC_NZ) +# 定义BiasAddGrad类,继承自Expander类 class BiasAddGrad(Expander): """BiasAddGrad expander""" def _expand(self, graph_builder): + """ + 内部方法,用于扩展输入张量的维度。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于生成图操作。 + + Returns: + Tensor: 扩展维度后的张量。 + + """ + # 获取输入张量 x = self.inputs[0] + # 定义reduce_axis,用于指定求和的维度 reduce_axis = () + # 如果输入张量的数据格式为NHWC,则reduce_axis为(0, 1, 2) if x.data_format == DF.NHWC: reduce_axis = (0, 1, 2) + # 如果输入张量的数据格式为NCHW,则reduce_axis为(0, 2, 3) elif x.data_format == DF.NCHW: reduce_axis = (0, 2, 3) + # 如果输入张量的数据格式为FRAC_NZ,则reduce_axis为(-2, -3) elif x.data_format == DF.FRAC_NZ: reduce_axis = (-2, -3) + # 如果输入张量的数据格式为DefaultFormat,则根据shape的长度确定reduce_axis else: # DefaultFormat shape's length should be from 2 to 4 + # 如果x的维度为2,则reduce_axis为(0,) if len(x.shape) == 2: reduce_axis = (0,) + # 如果x的维度为3,则reduce_axis为(0, 1) elif len(x.shape) == 3: reduce_axis = (0, 1) + # 否则,reduce_axis为(0, 2, 3) else: reduce_axis = (0, 2, 3) + # 发射ReduceSum操作,计算x的reduce_sum,reduce_axis为reduce_axis,keep_dims为False result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) + # 如果x的数据格式为DF.FRAC_NZ,则将result的shape改为x.shape[:-4] + [x.shape[-1] * x.shape[-4]] if x.data_format == DF.FRAC_NZ: out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]] + # 发射Reshape操作,将result的shape改为out_shape result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape}) + # 返回result return result