_extends\graph_kernel\expanders\fused_adam_weight_decay.py

branch-yixin
yixin 7 months ago
parent 60f231dc1c
commit 3e1df14f3b

@ -21,27 +21,54 @@ class FusedAdamWeightDecay(Expander):
"""FusedAdamWeightDecay expander"""
def _expand(self, graph_builder):
    """Expand FusedAdamWeightDecay into primitive graph ops.

    Performs one AdamW optimizer step by emitting elementwise nodes:

        m' = beta_1 * m + (1 - beta_1) * grad
        v' = beta_2 * v + (1 - beta_2) * grad^2
        p' = p - lr * (m' / (sqrt(v') + eps) + weight_decay * p)

    Args:
        graph_builder (GraphBuilder): builder used to emit nodes into the graph.

    Returns:
        The final fake-output InplaceAssign node chaining the writes to
        param, m and v.
    """
    (beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr,
     param, m, v, gradient, weight_decay) = self.inputs

    # First moment: m' = beta_1 * m + (1 - beta_1) * grad
    m_scaled = graph_builder.emit('Mul', [beta_1, m])
    grad_scaled = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
    next_m = graph_builder.emit('Add', [m_scaled, grad_scaled])

    # Second moment: v' = beta_2 * v + (1 - beta_2) * grad^2
    v_scaled = graph_builder.emit('Mul', [beta_2, v])
    grad_sq = graph_builder.emit('Mul', [gradient, gradient])
    grad_sq_scaled = graph_builder.emit('Mul', [one_sub_beta_2, grad_sq])
    next_v = graph_builder.emit('Add', [v_scaled, grad_sq_scaled])

    # Update direction: m' / (sqrt(v') + eps) + weight_decay * param
    # (decoupled weight decay is folded into the update before the lr scale)
    denom = graph_builder.emit('Sqrt', [next_v])
    denom = graph_builder.emit('Add', [denom, eps])
    update = graph_builder.emit('RealDiv', [next_m, denom])
    decay_term = graph_builder.emit('Mul', [weight_decay, param])
    update = graph_builder.emit('Add', [update, decay_term])

    # Descend: param' = param - lr * update
    update_with_lr = graph_builder.emit('Mul', [lr, update])
    next_para = graph_builder.emit('Sub', [param, update_with_lr])

    # Write the new param/m/v back in place; each assign is chained through
    # the previous result so all three writes stay in the expanded graph.
    result = graph_builder.emit(
        'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
    result = graph_builder.emit('InplaceAssign', [m, next_m, result], attrs={'fake_output': True})
    result = graph_builder.emit('InplaceAssign', [v, next_v, result], attrs={'fake_output': True})
    return result

Loading…
Cancel
Save