|
|
|
@ -21,24 +21,51 @@ class FusedAdam(Expander):
|
|
|
|
|
"""FusedAdam expander"""
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
使用图构建器对模型参数进行更新。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器实例,用于生成计算图。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 更新后的参数结果。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# 获取输入参数
|
|
|
|
|
beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient = self.inputs
|
|
|
|
|
|
|
|
|
|
# 计算beta_1乘以m
|
|
|
|
|
beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
|
|
|
|
|
# 计算one_sub_beta_1乘以gradient
|
|
|
|
|
one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
|
|
|
|
|
# 计算next_m
|
|
|
|
|
next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
|
|
|
|
|
# 计算beta_2乘以v
|
|
|
|
|
beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
|
|
|
|
|
# 计算gradient的平方
|
|
|
|
|
grad_square = graph_builder.emit('Mul', [gradient, gradient])
|
|
|
|
|
# 计算one_sub_beta_2乘以gradient的平方
|
|
|
|
|
one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
|
|
|
|
|
# 计算next_v
|
|
|
|
|
next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
|
|
|
|
|
# 计算next_v的平方根
|
|
|
|
|
sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
|
|
|
|
|
# 计算sqrt_next_v加上eps
|
|
|
|
|
sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps])
|
|
|
|
|
# 计算更新值
|
|
|
|
|
update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
|
|
|
|
|
# 计算更新值乘以lr
|
|
|
|
|
update_with_lr = graph_builder.emit('Mul', [lr, update])
|
|
|
|
|
# 计算next_para
|
|
|
|
|
next_para = graph_builder.emit('Sub', [param, update_with_lr])
|
|
|
|
|
|
|
|
|
|
# 将next_para赋值给param
|
|
|
|
|
param_result = graph_builder.emit(
|
|
|
|
|
'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
|
|
|
|
|
# 将next_m赋值给m
|
|
|
|
|
param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True})
|
|
|
|
|
# 将next_v赋值给v
|
|
|
|
|
param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True})
|
|
|
|
|
|
|
|
|
|
# 返回param_result
|
|
|
|
|
return param_result
|
|
|
|
|