|
|
|
@ -19,11 +19,29 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
|
|
|
@VLD.check_all_formats_same
|
|
|
|
|
class GeLUGrad(Expander):
|
|
|
|
|
"""GeLUGrad expander"""
|
|
|
|
|
CSVALUE = 0.044715
|
|
|
|
|
|
|
|
|
|
# CSVALUE = 0.044715
|
|
|
|
|
CSVALUE = 0.044715 # CSVALUE的值为0.044715
|
|
|
|
|
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
|
|
|
|
|
CSVALUE_TRI = 0.134141 # CSVALUE * 3
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
计算GELU函数的梯度。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象,用于生成计算图。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: GELU函数的梯度。
|
|
|
|
|
|
|
|
|
|
计算公式如下:
|
|
|
|
|
GELU的梯度dy和x是dy * y'
|
|
|
|
|
y' = 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
|
|
|
|
|
tanh_para = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
|
|
|
|
|
mul_right = sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# cal formula are:
|
|
|
|
|
# gelu_grad of dy and x is dy * y'
|
|
|
|
|
# y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
|
|
|
|
@ -33,21 +51,33 @@ class GeLUGrad(Expander):
|
|
|
|
|
input_dy, input_x, _ = self.inputs
|
|
|
|
|
|
|
|
|
|
# create some const var
|
|
|
|
|
# 创建一个常量,值为self.CSVALUE,数据类型为input_dy.dtype
|
|
|
|
|
const_csvalue = graph_builder.value(input_dy.dtype, self.CSVALUE)
|
|
|
|
|
# 创建一个常量,值为self.CSVALUE_SQRT_TWO_DIV_PI,数据类型为input_dy.dtype
|
|
|
|
|
const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
|
|
|
|
|
# 创建一个常量,值为self.CSVALUE_TRI,数据类型为input_dy.dtype
|
|
|
|
|
const_csvalue_tri = graph_builder.value(input_dy.dtype, self.CSVALUE_TRI)
|
|
|
|
|
# 创建一个常量,值为1,数据类型为input_dy.dtype
|
|
|
|
|
const_one = graph_builder.value(input_dy.dtype, 1)
|
|
|
|
|
# 创建一个常量,值为0.5,数据类型为input_dy.dtype
|
|
|
|
|
const_half = graph_builder.value(input_dy.dtype, 0.5)
|
|
|
|
|
|
|
|
|
|
# cal mul_right
|
|
|
|
|
# 计算input_x的平方
|
|
|
|
|
mul_double = graph_builder.emit('Mul', [input_x, input_x])
|
|
|
|
|
# 将const_csvalue_tri与mul_double相乘
|
|
|
|
|
mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
|
|
|
|
|
# 将const_one与mul_double_mul_tri相加
|
|
|
|
|
mul_add_one = graph_builder.emit('Add', [const_one, mul_double_mul_tri])
|
|
|
|
|
# 将const_csvalue_sqrt_two_div_pi与mul_add_one相乘
|
|
|
|
|
mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])
|
|
|
|
|
|
|
|
|
|
# cal tanh_para
|
|
|
|
|
# 计算input_x和mul_double的乘积
|
|
|
|
|
mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
|
|
|
|
|
# 计算const_csvalue和mul_triple的乘积
|
|
|
|
|
mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
|
|
|
|
|
# 计算input_x和mul_triple_mul_csvalue的和
|
|
|
|
|
mul_add_x = graph_builder.emit('Add', [input_x, mul_triple_mul_csvalue])
|
|
|
|
|
tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])
|
|
|
|
|
|
|
|
|
|