@@ -23,13 +23,33 @@ class LayerNormGrad(Expander):
    """LayerNormGrad expander"""

    def _expand(self, graph_builder):
        """
        Expand LayerNormGrad.

        Args:
            graph_builder (GraphBuilder): the graph builder object.

        Returns:
            tuple: a tuple (dx, dg, db).
                dx (Tensor): gradient of the loss with respect to the input x.
                dg (Tensor): gradient of the loss with respect to gamma.
                db (Tensor): gradient of the loss with respect to beta.
        """
        # Unpack the inputs
        x, dy, variance, mean, gamma = self.inputs
        # Target processor type
        processor = self.processor
        # First axis of the normalized dimensions
        begin_norm_axis = self.attrs['begin_norm_axis']
        # First axis of the affine parameters
        begin_params_axis = self.attrs['begin_params_axis']
        # Epsilon, defaulting to 1e-12 when not provided
        epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12

        # Record the original dtype of the input
        ori_dtype = x.dtype
        # On aicore, compute in float32 when the inputs are float16
        if processor == 'aicore' and ori_dtype == 'float16':
            x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
            dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
@@ -37,77 +57,121 @@ class LayerNormGrad(Expander):
            mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'})
            gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'})

        # Convert a negative begin_norm_axis to its non-negative equivalent
        if begin_norm_axis < 0:
            begin_norm_axis += len(x.shape)
        # Convert a negative begin_params_axis to its non-negative equivalent
        if begin_params_axis < 0:
            begin_params_axis += len(x.shape)

        # Axes reduced for normalization and for the affine parameters
        norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
        param_axis = tuple(range(0, begin_params_axis))

        # Number of elements reduced over the normalization axes
        reduce_size = 1.0
        for i in norm_axis:
            reduce_size *= x.shape[i]

        # set some constant val.
        eps = graph_builder.value(x.dtype, epsilon)
        const_neg_half = graph_builder.value(x.dtype, -0.5)
        const_neg_two = graph_builder.value(x.dtype, -2.0)
        const_two = graph_builder.value(x.dtype, 2.0)
        const_neg_one = graph_builder.value(x.dtype, -1.0)
        # 1 / reduce_size, the averaging factor over the normalized axes
        mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size))

        # cal dg db
        # variance + eps
        var_eps = graph_builder.emit('Add', [variance, eps])
        # log(variance + eps)
        var_eps_log = graph_builder.emit('Log', [var_eps])
        # -0.5 * log(variance + eps)
        var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half])
        # exp(-0.5 * log(variance + eps))
        rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul])
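        # Editor's note (not part of the original expander): exp(-0.5 * log(v)) == v**(-0.5),
        # so rsqrt_var_eps is 1 / sqrt(variance + eps), the reciprocal standard deviation,
        # built here from Add/Log/Mul/Exp primitives.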

        # x minus mean
        x_sub_mean = graph_builder.emit('Sub', [x, mean])

        # (x - mean) multiplied by rsqrt_var_eps
        x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
        # dy multiplied by the normalized input
        dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
        # dg: ReduceSum of dg_mul with reduce_axis=param_axis, keep_dims=False
        dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
        # db: ReduceSum of dy with reduce_axis=param_axis, keep_dims=False
        db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})
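        # Editor's note (not part of the original expander): x_sub_mean_mul_rsqrt_var_eps
        # is the normalized input x_hat, so dg = reduce_sum(dy * x_hat, param_axis) and
        # db = reduce_sum(dy, param_axis), the usual gamma/beta gradients.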

        # pd_var
        # square of rsqrt_var_eps
        tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps])
        # cube of rsqrt_var_eps
        r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps])

        # dy * gamma
        dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
        # dy * gamma * (x - mean)
        tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean])
        # ReduceSum over norm_axis, keep_dims=True
        padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
        # multiply by the cube of rsqrt_var_eps
        padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps])
        # multiply by -0.5 to get pd_var
        pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half])
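        # Editor's note (not part of the original expander): the block above evaluates
        #   pd_var = -0.5 * reduce_sum(dy * gamma * (x - mean), norm_axis) * (variance + eps)**(-1.5),
        # the gradient of the loss with respect to the per-group variance.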

        # pd_mean
        # ReduceSum of dy * gamma over norm_axis, keep_dims=True
        pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma],
                                         attrs={'reduce_axis': norm_axis, 'keep_dims': True})
        # negate rsqrt_var_eps
        neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one])
        # first term of pd_mean
        pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum])

        # -2 * (x - mean)
        pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
        # ReduceSum of -2 * (x - mean) over norm_axis, keep_dims=True
        pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1],
                                         attrs={'reduce_axis': norm_axis, 'keep_dims': True})
        # multiply by mean_cof (1 / reduce_size)
        pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof])
        # second term of pd_mean
        pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var])

        # pd_mean = pd_mean_1 + pd_mean_2
        pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2])
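        # Editor's note (not part of the original expander): pd_mean is the gradient with
        # respect to the per-group mean. It adds the direct term
        #   -reduce_sum(dy * gamma, norm_axis) / sqrt(variance + eps)
        # to the contribution routed through the variance,
        #   pd_var * reduce_sum(-2 * (x - mean), norm_axis) / reduce_size.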

        # cal dx
        # dy * gamma * rsqrt_var_eps
        pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])

        # pd_var * (x - mean)
        pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean])
        # multiply by 2
        pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two])
        # multiply by mean_cof (1 / reduce_size)
        pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof])

        # pd_mean / reduce_size
        pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof])

        # dx = pd_x_1 + pd_x_2 + pd_x_3
        dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2])
        dx = graph_builder.emit('Add', [dx_tmp, pd_x_3])
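        # Editor's note (not part of the original expander): combining the three terms gives
        #   dx = dy * gamma / sqrt(variance + eps)
        #        + pd_var * 2 * (x - mean) / reduce_size
        #        + pd_mean / reduce_size,
        # matching the summary after the docstring.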

        # Cast the results back to float16 when the inputs were cast up
        if processor == 'aicore' and ori_dtype == 'float16':
            dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
            dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'})
            db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'})
        # return dx, dg and db
        return dx, dg, db