diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py index 53598068..47dcbce1 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py @@ -21,27 +21,54 @@ class FusedAdamWeightDecay(Expander): """FusedAdamWeightDecay expander""" def _expand(self, graph_builder): + """ + 对输入参数进行梯度下降更新。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于在图中添加节点。 + + Returns: + ParaResult: 更新后的参数结果节点。 + + """ beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient, weight_decay = self.inputs # compute result + # 计算beta_1和m的乘积 beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) + # 计算one_sub_beta_1和gradient的乘积 one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient]) + # 计算next_m next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad]) + # 计算beta_2和v的乘积 beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v]) + # 计算gradient的平方 grad_square = graph_builder.emit('Mul', [gradient, gradient]) + # 计算one_sub_beta_2和grad_square的乘积 one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square]) + # 计算next_v next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square]) + # 计算sqrt_next_v sqrt_next_v = graph_builder.emit('Sqrt', [next_v]) + # 计算sqrt_next_v和eps的和 sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps]) + # 计算update update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps]) + # 计算param_with_weight_decay param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param]) + # 计算update和param_with_weight_decay的和 update = graph_builder.emit('Add', [update, param_with_weight_decay]) + # 计算update_with_lr update_with_lr = graph_builder.emit('Mul', [lr, update]) + # 计算next_para next_para = graph_builder.emit('Sub', [param, update_with_lr]) + # 将next_para赋值给param para_result = graph_builder.emit( 'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) + # 将next_m赋值给m para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True}) + # 将next_v赋值给v para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True}) return para_result