_extends\graph_kernel\expanders\fused_adam.py

branch-yixin
yixin 2 months ago
parent 3e1df14f3b
commit 17d5bbbe72

@@ -21,24 +21,51 @@ class FusedAdam(Expander):
"""FusedAdam expander"""
def _expand(self, graph_builder):
    """
    Expand the fused Adam update into a sequence of primitive graph ops.

    The ops are emitted in this order:
        next_m     = beta_1 * m + one_sub_beta_1 * gradient
        next_v     = beta_2 * v + one_sub_beta_2 * (gradient * gradient)
        next_param = param - lr * (next_m / (sqrt(next_v) + eps))
    followed by three chained 'InplaceAssign' ops targeting param, m and v.

    Args:
        graph_builder (GraphBuilder): builder whose ``emit`` adds each op to the graph.

    Returns:
        The result of the last emitted 'InplaceAssign' (the one targeting ``v``).
    """
    (beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
     eps, lr, param, m, v, gradient) = self.inputs
    emit = graph_builder.emit

    # First-moment estimate: beta_1 * m + one_sub_beta_1 * gradient.
    m_decayed = emit('Mul', [beta_1, m])
    grad_scaled = emit('Mul', [one_sub_beta_1, gradient])
    new_m = emit('Add', [m_decayed, grad_scaled])

    # Second-moment estimate: beta_2 * v + one_sub_beta_2 * gradient^2.
    v_decayed = emit('Mul', [beta_2, v])
    grad_sq = emit('Mul', [gradient, gradient])
    grad_sq_scaled = emit('Mul', [one_sub_beta_2, grad_sq])
    new_v = emit('Add', [v_decayed, grad_sq_scaled])

    # Step: lr * new_m / (sqrt(new_v) + eps); eps is added after the square root.
    denom = emit('Add', [emit('Sqrt', [new_v]), eps])
    step = emit('RealDiv', [new_m, denom])
    scaled_step = emit('Mul', [lr, step])
    new_param = emit('Sub', [param, scaled_step])

    # Write-backs. Each assign after the first takes the previous assign's
    # result as its third input, so the three ops form a single chain.
    result = emit('InplaceAssign', [param, new_param, new_param], attrs={'fake_output': True})
    for target, value in ((m, new_m), (v, new_v)):
        result = emit('InplaceAssign', [target, value, result], attrs={'fake_output': True})
    return result

Loading…
Cancel
Save