|
|
|
@ -22,6 +22,16 @@ class GeLU(Expander):
|
|
|
|
|
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
计算输入张量的GELU激活函数值。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象,用于生成计算图。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 输入张量的GELU激活函数值。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# cal formula are:
|
|
|
|
|
# gelu of x is 0.5 * x * (1.0 + tanh(y))
|
|
|
|
|
# y is 'sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)'
|
|
|
|
@ -29,20 +39,33 @@ class GeLU(Expander):
|
|
|
|
|
input_x = self.inputs[0]
|
|
|
|
|
|
|
|
|
|
# cal y
|
|
|
|
|
# 计算 input_x 的平方
|
|
|
|
|
mul_0 = graph_builder.emit('Mul', [input_x, input_x])
|
|
|
|
|
# 计算 input_x 的立方
|
|
|
|
|
pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
|
|
|
|
|
# 创建一个 CSVALUE 常量
|
|
|
|
|
const_csvalue = graph_builder.value(pow_0.dtype, self.CSVALUE)
|
|
|
|
|
# 计算 pow_0 和 CSVALUE 的乘积
|
|
|
|
|
mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
|
|
|
|
|
# 计算 input_x 和 mul_1 的和
|
|
|
|
|
tanh_res = graph_builder.emit('Add', [input_x, mul_1])
|
|
|
|
|
# 创建一个 CSVALUE_SQRT_TWO_DIV_PI 常量
|
|
|
|
|
const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
|
|
|
|
|
# 计算 tanh_res 和 CSVALUE_SQRT_TWO_DIV_PI 的乘积
|
|
|
|
|
y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
|
|
|
|
|
|
|
|
|
|
# cal gelu(x)
|
|
|
|
|
# 计算 y 的 tanh 值
|
|
|
|
|
tanh_y = graph_builder.emit('Tanh', [y])
|
|
|
|
|
# 创建一个 1 常量
|
|
|
|
|
const_one = graph_builder.value(tanh_y.dtype, 1)
|
|
|
|
|
# 创建一个 0.5 常量
|
|
|
|
|
const_half = graph_builder.value(tanh_y.dtype, 0.5)
|
|
|
|
|
# 计算 tanh_y 和 1 的和
|
|
|
|
|
tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
|
|
|
|
|
# 计算 input_x 和 tanh_y_add_one 的乘积
|
|
|
|
|
mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
|
|
|
|
|
# 计算 const_half 和 mul_x 的乘积
|
|
|
|
|
result = graph_builder.emit('Mul', [const_half, mul_x])
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|