_extends\graph_kernel\expanders\gelu_grad.py

branch-yixin
yixin 2 months ago
parent 07b3545276
commit ace1dccc7a

@ -19,11 +19,29 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_all_formats_same
class GeLUGrad(Expander):
"""GeLUGrad expander"""
# Coefficient of the cubic term in the tanh-based GELU approximation:
# tanh_para = sqrt(2 / pi) * (x + CSVALUE * x^3).
CSVALUE = 0.044715
# sqrt(2 / pi), i.e. np.sqrt(2 / np.pi), to double precision.
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028654
# 3 * CSVALUE: coefficient of x^2 in d(x + CSVALUE * x^3)/dx, used by mul_right.
# The previous literal 0.134141 did not equal 3 * 0.044715 (= 0.134145).
CSVALUE_TRI = 0.134145
def _expand(self, graph_builder):
"""
Compute the gradient of the GELU function.
Args:
graph_builder (GraphBuilder): graph builder object used to emit the computation graph.
Returns:
Tensor: the gradient of the GELU function.
The formulas are as follows:
gelu_grad of dy and x is dy * y'
y' = 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
tanh_para = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
mul_right = sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
"""
# cal formula are:
# gelu_grad of dy and x is dy * y'
# y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
@ -33,21 +51,33 @@ class GeLUGrad(Expander):
input_dy, input_x, _ = self.inputs
# create some const var
# 创建一个常量值为self.CSVALUE数据类型为input_dy.dtype
const_csvalue = graph_builder.value(input_dy.dtype, self.CSVALUE)
# 创建一个常量值为self.CSVALUE_SQRT_TWO_DIV_PI数据类型为input_dy.dtype
const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
# 创建一个常量值为self.CSVALUE_TRI数据类型为input_dy.dtype
const_csvalue_tri = graph_builder.value(input_dy.dtype, self.CSVALUE_TRI)
# 创建一个常量值为1数据类型为input_dy.dtype
const_one = graph_builder.value(input_dy.dtype, 1)
# 创建一个常量值为0.5数据类型为input_dy.dtype
const_half = graph_builder.value(input_dy.dtype, 0.5)
# cal mul_right
# 计算input_x的平方
mul_double = graph_builder.emit('Mul', [input_x, input_x])
# 将const_csvalue_tri与mul_double相乘
mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
# 将const_one与mul_double_mul_tri相加
mul_add_one = graph_builder.emit('Add', [const_one, mul_double_mul_tri])
# 将const_csvalue_sqrt_two_div_pi与mul_add_one相乘
mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])
# cal tanh_para
# 计算input_x和mul_double的乘积
mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
# 计算const_csvalue和mul_triple的乘积
mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
# 计算input_x和mul_triple_mul_csvalue的和
mul_add_x = graph_builder.emit('Add', [input_x, mul_triple_mul_csvalue])
tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])

Loading…
Cancel
Save