|
|
|
@ -36,15 +36,19 @@ class BatchNorm(Expander):
|
|
|
|
|
|
|
|
|
|
        # Remember the input's original dtype so the result can be cast back at the end.
        input_x_ori_type = input_x.dtype
        input_x_new_type = input_x.dtype
        # If the input is float16 while scale, offset, mean and variance are all
        # float32, promote the input to float32 so the whole expansion runs in fp32.
        if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \
                input_mean.dtype == "float32" and input_variance.dtype == "float32":
            input_x_new_type = "float32"
        # Cast the input when the chosen working dtype differs from the original dtype.
        if input_x_new_type != input_x_ori_type:
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})

        # Training mode: delegate to the full training-mode expansion.
        if self.attrs['is_training']:
            # Replace the stored input with the (possibly cast) tensor before expanding.
            self.inputs[0] = input_x
            res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
            # Cast the normalized output back to the original dtype if it was promoted.
            if input_x_new_type != input_x_ori_type:
                res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
            return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
|
|
|
|
@ -70,21 +74,42 @@ class BatchNorm(Expander):
|
|
|
|
|
            # Inference-mode result: only res_y is the normalized output; var_add
            # presumably fills the remaining output slots to match the training-mode
            # signature — TODO confirm against the elided inference computation above.
            return res_y, var_add, var_add, var_add, var_add
|
|
|
|
|
|
|
|
|
|
    def _bn_train(self, graph_builder):
        """Expand BatchNorm for training mode.

        Args:
            graph_builder (GraphBuilder): graph builder instance used to emit ops.

        Returns:
            tuple: a tuple containing:
                - res_y (Tensor): the normalized output.
                - mean_res (Tensor): the updated moving mean.
                - variance_res (Tensor): the updated moving variance.
                - mean_muls (Tensor): the mean of the input data.
                - y_sqrt_rec (Tensor): 1 / sqrt(variance + epsilon), used for backpropagation.
        """
        # Unpack the five BatchNorm inputs: data, scale, offset, moving mean, moving variance.
        input_x = self.inputs[0]
        input_scale = self.inputs[1]
        input_offset = self.inputs[2]
        input_mean = self.inputs[3]
        input_variance = self.inputs[4]
        # Scalar constant for epsilon, in the scale's dtype.
        epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
        # Axes to reduce over (every axis except the channel axis).
        reduce_axis = ()
        shape_x = input_x.shape
        # Pick the reduce axes and the per-channel element count from the data format:
        # NHWC reduces (N, H, W); otherwise assume channel-second layout and reduce (N, H, W)
        # at positions (0, 2, 3).
        if input_x.data_format == DF.NHWC:
            reduce_axis = (0, 1, 2)
            num = shape_x[0] * shape_x[1] * shape_x[2]
        else:
            reduce_axis = (0, 2, 3)
            num = shape_x[0] * shape_x[2] * shape_x[3]
        # Reciprocal of the element count, as a scalar constant (used to compute means).
        num_rec = 1.0 / num
        num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
|
|
|
|
|
|
|
|
|
@ -112,41 +137,67 @@ class BatchNorm(Expander):
|
|
|
|
|
        # y_sqrt_rec = 1 / sqrt(variance + epsilon); also returned for backpropagation.
        y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])

        # compute res_y
        # Center the input: x - mean (mean already expanded to a broadcastable shape).
        tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
        # For DEFAULT/NCHW layouts, reshape the per-channel rsqrt so it broadcasts
        # over the trailing spatial axes; other layouts broadcast without reshaping.
        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
            y_sqrt_rec_expand = graph_builder.emit(
                'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
        else:
            y_sqrt_rec_expand = y_sqrt_rec
        # Normalize: (x - mean) * rsqrt(var + eps).
        y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
        # Reshape the scale the same way for DEFAULT/NCHW layouts.
        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
            input_scale_expand = graph_builder.emit(
                'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
        else:
            input_scale_expand = input_scale
        # Apply the scale to the normalized values.
        res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
        # Reshape the offset the same way for DEFAULT/NCHW layouts.
        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
            input_offset_expand = graph_builder.emit(
                'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
        else:
            input_offset_expand = input_offset
        # Final output: res_y = scale * y_norm + offset.
        res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])

        # compute mean_res
        # Moving-mean update: moving_mean = (1 - momentum) * running_mean + momentum * batch_mean.
        momentum_sub = scalar_one - self.attrs['momentum']
        momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
        new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
        momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
        current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
        updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
        # Write the updated moving mean back into the mean parameter in-graph.
        mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean])

        # variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
        # NOTE(review): num == 1 would divide by zero here — presumably more than one
        # element per channel is guaranteed upstream; confirm with callers.
        var_num = float(num) / (num - 1)
        var_num_v = graph_builder.value(input_scale.dtype, var_num)
        # Unbias the batch variance (Bessel's correction factor) before the moving update.
        var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
        # Moving-variance update: moving_var = (1 - momentum) * running_var + momentum * unbiased_var.
        new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
        current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
        updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
        # Write the updated moving variance back into the variance parameter in-graph.
        variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance])
        # Return the normalized output, updated statistics, and backprop intermediates.
        return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
|
|
|
|
|