From b0c766215598fd0ec9324a022777cc5314a09f7e Mon Sep 17 00:00:00 2001 From: yixin <2050485123@qq.com> Date: Wed, 25 Dec 2024 10:23:02 +0800 Subject: [PATCH] add comments for _extends\graph_kernel\expanders\batchnorm.py --- .../graph_kernel/expanders/batchnorm.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm.py index 799dc3f5..bd29f9a3 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm.py @@ -36,15 +36,19 @@ class BatchNorm(Expander): input_x_ori_type = input_x.dtype input_x_new_type = input_x.dtype + # If the input is float16 while scale, offset, mean and variance are all float32, use float32 as the input's working type if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \ input_mean.dtype == "float32" and input_variance.dtype == "float32": input_x_new_type = "float32" + # Cast the input when the working type differs from its original type if input_x_new_type != input_x_ori_type: input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type}) + # Training mode if self.attrs['is_training']: self.inputs[0] = input_x res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder) + # If the input was cast, cast the output back to the original type if input_x_new_type != input_x_ori_type: res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type}) return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec @@ -70,21 +74,42 @@ class BatchNorm(Expander): return res_y, var_add, var_add, var_add, var_add def _bn_train(self, graph_builder): + """ + Expand BatchNorm in training mode. + + Args: + graph_builder (GraphBuilder): The graph builder instance. + + Returns: + tuple: A tuple containing: + - res_y (Tensor): The normalized output. + - mean_res (Tensor): The updated moving mean. + - variance_res (Tensor): 
The updated moving variance. + - mean_muls (Tensor): The mean of the input data. + - y_sqrt_rec (Tensor): 1 / sqrt(variance + epsilon), used for backpropagation. + + """ """expand BatchNorm for training mode""" + # Fetch the inputs input_x = self.inputs[0] input_scale = self.inputs[1] input_offset = self.inputs[2] input_mean = self.inputs[3] input_variance = self.inputs[4] + # Build the epsilon attribute as a value of input_scale's dtype epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon']) + # Initialize the reduce axes; they are set per data format below reduce_axis = () + # Shape of the input shape_x = input_x.shape + # Pick the reduce axes and the element count num according to the input data format if input_x.data_format == DF.NHWC: reduce_axis = (0, 1, 2) num = shape_x[0] * shape_x[1] * shape_x[2] else: reduce_axis = (0, 2, 3) num = shape_x[0] * shape_x[2] * shape_x[3] + # Reciprocal of num num_rec = 1.0 / num num_rec_v = graph_builder.value(input_scale.dtype, num_rec) @@ -112,41 +137,67 @@ class BatchNorm(Expander): y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt]) # compute res_y + # Subtract mean_muls_expand from the input x tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand]) + # For DF.DEFAULT or DF.NCHW input, reshape y_sqrt_rec if input_x.data_format in (DF.DEFAULT, DF.NCHW): y_sqrt_rec_expand = graph_builder.emit( 'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])}) + # Otherwise use y_sqrt_rec unchanged else: y_sqrt_rec_expand = y_sqrt_rec + # Multiply tmp_sub by y_sqrt_rec_expand y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand]) + # For DF.DEFAULT or DF.NCHW input, reshape input_scale if input_x.data_format in (DF.DEFAULT, DF.NCHW): input_scale_expand = graph_builder.emit( 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])}) + # Otherwise use input_scale unchanged else: input_scale_expand = input_scale + # Multiply y_norm by input_scale_expand res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm]) + # For DF.DEFAULT or DF.NCHW input, reshape input_offset if input_x.data_format in (DF.DEFAULT, DF.NCHW): input_offset_expand = graph_builder.emit( 'Reshape', [input_offset], attrs={'shape': 
ExpandDims.infer_shape(input_offset.shape, [-1, -1])}) + # Otherwise use input_offset unchanged else: input_offset_expand = input_offset + # Add input_offset_expand to res_y_mul res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand]) # compute mean_res + # Subtract momentum from scalar_one momentum_sub = scalar_one - self.attrs['momentum'] + # Build momentum_sub as a value of input_scale's dtype momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub) + # Contribution of the previous moving mean: + # momentum_v_sub times input_mean new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean]) + # Build momentum as a value of input_scale's dtype momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum']) + # Contribution of mean_muls: momentum_v times mean_muls current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls]) + # Sum the two contributions into updated_moving_mean updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp]) + # Assign updated_moving_mean to input_mean mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean]) # variance_res is calculated by sample variance, and need to multiply by num / (num - 1) + # Correction factor applied to the variance (not the variance itself) var_num = float(num) / (num - 1) + # Build the correction factor as a value of input_scale's dtype var_num_v = graph_builder.value(input_scale.dtype, var_num) + # Scale var_mul by the correction factor var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul]) + # Contribution of the previous moving variance: momentum_v_sub times input_variance new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance]) + # Contribution of the corrected variance: momentum_v times var_mul_update current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update]) + # Sum the two contributions into updated_moving_variance updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp]) + # Assign updated_moving_variance to input_variance variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance]) + # Return the results return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec