add comments for _extends\graph_kernel\expanders\batchnorm.py

branch-yixin
yixin 2 months ago
parent eef69d070e
commit b0c7662155

@@ -36,15 +36,19 @@ class BatchNorm(Expander):
input_x_ori_type = input_x.dtype
input_x_new_type = input_x.dtype
# If the input is float16 while scale, offset, mean and variance are float32, cast the input to float32
if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \
input_mean.dtype == "float32" and input_variance.dtype == "float32":
input_x_new_type = "float32"
# Cast the input if the new type differs from the original type
if input_x_new_type != input_x_ori_type:
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})
# Training mode
if self.attrs['is_training']:
self.inputs[0] = input_x
res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
# Cast the output back to the original type if the input was cast
if input_x_new_type != input_x_ori_type:
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
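
For reference only (not part of this commit), a minimal NumPy sketch of the cast-and-restore pattern annotated above, assuming channel-last inputs and hypothetical names: compute in float32 when the input is float16 but the statistics are float32, then cast the result back to the original dtype.

import numpy as np

def bn_infer_with_cast(x, scale, offset, mean, variance, eps=1e-5):
    # Illustrative sketch: per-channel scale/offset/mean/variance broadcast over the last axis
    orig_dtype = x.dtype
    if x.dtype == np.float16 and scale.dtype == np.float32:
        x = x.astype(np.float32)  # compute in the wider type
    y = scale * (x - mean) / np.sqrt(variance + eps) + offset
    return y.astype(orig_dtype)   # restore the original input dtype
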
@@ -70,21 +74,42 @@ class BatchNorm(Expander):
return res_y, var_add, var_add, var_add, var_add
def _bn_train(self, graph_builder):
"""
在训练模式下扩展BatchNorm
Args:
graph_builder (GraphBuilder): 图构建器实例
Returns:
tuple: 包含以下内容的元组:
- res_y (Tensor): 归一化后的输出
- mean_res (Tensor): 更新后的移动均值
- variance_res (Tensor): 更新后的移动方差
- mean_muls (Tensor): 输入数据的均值
- y_sqrt_rec (Tensor): 1 / sqrt(方差 + epsilon)用于反向传播
"""
"""expand BatchNorm for training mode"""
# Get the input tensors
input_x = self.inputs[0]
input_scale = self.inputs[1]
input_offset = self.inputs[2]
input_mean = self.inputs[3]
input_variance = self.inputs[4]
# Get the epsilon value
epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
# Initialize the reduction axes
reduce_axis = ()
# Get the shape of the input
shape_x = input_x.shape
# Set the reduction axes and the element count num according to the data format
if input_x.data_format == DF.NHWC:
reduce_axis = (0, 1, 2)
num = shape_x[0] * shape_x[1] * shape_x[2]
else:
reduce_axis = (0, 2, 3)
num = shape_x[0] * shape_x[2] * shape_x[3]
# Compute the reciprocal of num
num_rec = 1.0 / num
num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
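
As a concrete illustration (not part of the commit), the axis and element-count selection above for a sample tensor:

# NHWC: statistics are reduced over batch and spatial axes
shape_nhwc = (32, 224, 224, 16)   # N, H, W, C
reduce_axis = (0, 1, 2)
num = shape_nhwc[0] * shape_nhwc[1] * shape_nhwc[2]   # 32 * 224 * 224 = 1605632

# NCHW: same element count per channel, different axes
shape_nchw = (32, 16, 224, 224)   # N, C, H, W
reduce_axis = (0, 2, 3)
num = shape_nchw[0] * shape_nchw[2] * shape_nchw[3]   # 1605632
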
@@ -112,41 +137,67 @@ class BatchNorm(Expander):
y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])
# compute res_y
# Compute the difference between input_x and mean_muls_expand
tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
# Reshape y_sqrt_rec if the data format of input_x is DF.DEFAULT or DF.NCHW
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
y_sqrt_rec_expand = graph_builder.emit(
'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
# Otherwise keep y_sqrt_rec unchanged
else:
y_sqrt_rec_expand = y_sqrt_rec
# Multiply tmp_sub by y_sqrt_rec_expand
y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
# Reshape input_scale if the data format of input_x is DF.DEFAULT or DF.NCHW
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_scale_expand = graph_builder.emit(
'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
# Otherwise keep input_scale unchanged
else:
input_scale_expand = input_scale
# Multiply input_scale_expand by y_norm
res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
# Reshape input_offset if the data format of input_x is DF.DEFAULT or DF.NCHW
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_offset_expand = graph_builder.emit(
'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
# Otherwise keep input_offset unchanged
else:
input_offset_expand = input_offset
# Add res_y_mul and input_offset_expand
res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])
# compute mean_res
# Compute 1 - momentum
momentum_sub = scalar_one - self.attrs['momentum']
# Wrap (1 - momentum) as a scalar constant with the dtype of input_scale
momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
# Compute (1 - momentum) * moving_mean
new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
# Create a scalar constant for momentum
momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
# Compute momentum * current batch mean
current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
# Compute the updated moving mean
updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
# Assign updated_moving_mean to input_mean
mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean])
# variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
# Compute the unbiased correction factor num / (num - 1)
var_num = float(num) / (num - 1)
# Wrap the correction factor as a scalar constant with the dtype of input_scale
var_num_v = graph_builder.value(input_scale.dtype, var_num)
# Apply the correction factor to the batch variance
var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
# Compute (1 - momentum) * moving_variance
new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
# Compute momentum * corrected batch variance
current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
# Compute the updated moving variance
updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
# Assign the updated moving variance to input_variance
variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance])
# Return the results
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
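
A minimal NumPy sketch (illustrative only, with assumed variable names) of the running-statistics update commented in this hunk: the batch variance is rescaled by num / (num - 1) so the unbiased estimate enters the exponential moving average.

import numpy as np

def update_running_stats(batch_mean, batch_var, moving_mean, moving_var, momentum, num):
    # Unbiased (sample) variance: multiply the population variance by num / (num - 1)
    batch_var_unbiased = batch_var * (num / (num - 1))
    # Exponential moving average: new = (1 - momentum) * old + momentum * current
    new_mean = (1.0 - momentum) * moving_mean + momentum * batch_mean
    new_var = (1.0 - momentum) * moving_var + momentum * batch_var_unbiased
    return new_mean, new_var
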
