From eef69d070e6af65bfe8ffb54f7d7da684d8236f0 Mon Sep 17 00:00:00 2001
From: yixin <2050485123@qq.com>
Date: Wed, 25 Dec 2024 09:58:37 +0800
Subject: [PATCH] add comments for _extends\graph_kernel\expanders\batchnorm_grad.py

---
 .../graph_kernel/expanders/batchnorm_grad.py | 74 +++++++++++--------
 1 file changed, 44 insertions(+), 30 deletions(-)

diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py
index eeb94ca1..2393ba90 100644
--- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py
+++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py
@@ -11,46 +11,59 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ===========================================================================
-"""generate json desc for BatchNormGrad"""
+# Copyright notice
+# Licensed under the Apache License 2.0
+# This file may not be used except in compliance with the License
+
+"""
+Generate json desc for BatchNormGrad, the expander that computes the gradients of a Batch Normalization layer.
+"""
+
+# Import the required modules and classes
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD
 from .expand_dims import ExpandDims
-
+# Define the BatchNormGrad class, which inherits from Expander
 @VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
 @VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
 @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
 @VLD.check_attrs('is_training', 'epsilon')
 class BatchNormGrad(Expander):
-    """BatchNormGrad expander"""
+    """BatchNormGrad expander, which computes the gradients of a Batch Normalization layer"""
 
+    # Expansion method, called to build the BatchNormGrad computation graph
     def _expand(self, graph_builder):
-        # get op info
-        input_dy = self.inputs[0]
-        input_x = self.inputs[1]
-        input_scale = self.inputs[2]
-        input_save_mean = self.inputs[3]
-        input_save_inv_variance = self.inputs[4]
+        # get op info: output gradient, input data, scale, saved mean and saved inverse standard deviation
+        input_dy = self.inputs[0]  # gradient w.r.t. the layer output
+        input_x = self.inputs[1]  # input data
+        input_scale = self.inputs[2]  # scale (gamma)
+        input_save_mean = self.inputs[3]  # mean saved in the forward pass
+        input_save_inv_variance = self.inputs[4]  # 1 / sqrt(variance + epsilon) saved in the forward pass
 
+        # choose reduce_axis according to the input data format, for the ReduceSum ops below
         reduce_axis = ()
         shape_x = input_x.shape
-        if input_x.data_format == DF.NHWC:
-            reduce_axis = (0, 1, 2)
-            num = shape_x[0] * shape_x[1] * shape_x[2]
-        else:
-            reduce_axis = (0, 2, 3)
-            num = shape_x[0] * shape_x[2] * shape_x[3]
-        ori_type = input_x.dtype
+        if input_x.data_format == DF.NHWC:  # NHWC format
+            reduce_axis = (0, 1, 2)  # axes for ReduceSum
+            num = shape_x[0] * shape_x[1] * shape_x[2]  # number of reduced elements
+        else:  # otherwise NCHW (or DEFAULT) format
+            reduce_axis = (0, 2, 3)  # axes for ReduceSum
+            num = shape_x[0] * shape_x[2] * shape_x[3]  # number of reduced elements
+        ori_type = input_x.dtype  # original data type
+
+        # if the original dtype is float16, cast to float32 for the computation to avoid precision loss
         if ori_type == 'float16':
             input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
         if input_dy.dtype == 'float16':
             input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
-        num_rec = -1.0 / num
-        num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
-        dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        num_rec = -1.0 / num  # negative reciprocal of the reduced element count
+        num_rec_v = graph_builder.value(input_scale.dtype, num_rec)  # constant node holding that value
+        # dbeta: gradient w.r.t. beta, i.e. the sum of dy over the reduce axes
+        dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
 
-        # in training input_save_inv_variance means 1 / sqrt(variance + epsilon), which is calculated in forward pass
+        # get inv_variance: in training, input_save_inv_variance already holds 1 / sqrt(variance + epsilon)
+        # from the forward pass; otherwise it is computed here
         if self.attrs['is_training']:
             inv_variance = input_save_inv_variance
         else:
@@ -61,7 +74,7 @@ class BatchNormGrad(Expander):
             scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
             inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps])
 
-        # compute dgamma
+        # compute dgamma (gradient w.r.t. gamma)
         if input_x.data_format in (DF.DEFAULT, DF.NCHW):
             input_save_mean = graph_builder.emit(
                 'Reshape', [input_save_mean], attrs={'shape': ExpandDims.infer_shape(input_save_mean.shape, [-1, -1])})
@@ -69,13 +82,13 @@ class BatchNormGrad(Expander):
                 'Reshape', [inv_variance], attrs={'shape': ExpandDims.infer_shape(inv_variance.shape, [-1, -1])})
             input_scale = graph_builder.emit(
                 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
-        x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean])
-        x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance])
-        dgamma_param = graph_builder.emit('Mul', [input_dy, x_div])
+        x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean])  # x minus the saved mean
+        x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance])  # normalized x: (x - mean) * inv_variance
+        dgamma_param = graph_builder.emit('Mul', [input_dy, x_div])  # element-wise dy * normalized x
         dgamma = graph_builder.emit(
-            'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+            'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})  # dgamma
 
-        # compute dx
+        # compute dx (gradient w.r.t. x)
         if self.attrs['is_training']:
             tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta])
             if input_x.data_format in (DF.DEFAULT, DF.NCHW):
@@ -95,11 +108,12 @@ class BatchNormGrad(Expander):
             y_scale = graph_builder.emit('Mul', [input_scale, input_dy])
             dx = graph_builder.emit('Mul', [inv_variance, y_scale])
         if ori_type == 'float16':
-            dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
+            dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})  # cast dx back to float16
 
-        # set output tensors' data_format
+        # set the output tensors' data_format
         dx.data_format = self.outputs[0]['format']
         dgamma.data_format = self.outputs[1]['format']
         dbeta.data_format = self.outputs[2]['format']
-        return dx, dgamma, dbeta
+        # return the computed gradients
+        return dx, dgamma, dbeta
\ No newline at end of file
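
Reviewer note (not part of the patch): as a cross-check on the comments added above, the graph built
in the training branch corresponds to the standard batch-norm backward formulas. The sketch below is
an illustrative NumPy reference under an assumed NCHW layout; the function and argument names
(batchnorm_grad_reference, dy, x, gamma, mean, inv_std) are hypothetical, with inv_std standing for
the saved 1 / sqrt(variance + epsilon).

import numpy as np

def batchnorm_grad_reference(dy, x, gamma, mean, inv_std):
    """Reference BatchNorm gradients (training mode, assumed NCHW layout)."""
    n, _, h, w = x.shape
    num = n * h * w                                   # elements reduced per channel
    axes = (0, 2, 3)                                  # reduce over N, H and W
    mean = mean.reshape(1, -1, 1, 1)                  # broadcast per-channel statistics
    inv_std = inv_std.reshape(1, -1, 1, 1)
    gamma = gamma.reshape(1, -1, 1, 1)
    x_hat = (x - mean) * inv_std                      # normalized input
    dbeta = dy.sum(axis=axes)                         # gradient w.r.t. beta
    dgamma = (dy * x_hat).sum(axis=axes)              # gradient w.r.t. gamma
    dx = gamma * inv_std * (dy
                            - dbeta.reshape(1, -1, 1, 1) / num
                            - x_hat * dgamma.reshape(1, -1, 1, 1) / num)
    return dx, dgamma, dbeta

Comparing dx, dgamma and dbeta from this sketch against the expanded graph's outputs on random inputs
is a quick way to validate the per-step comments above.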