_extends\graph_kernel\expanders\gelu.py

branch-yixin
yixin 2 months ago
parent ace1dccc7a
commit 95f50332ac

@ -22,6 +22,16 @@ class GeLU(Expander):
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
def _expand(self, graph_builder):
"""
计算输入张量的GELU激活函数值
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成计算图
Returns:
Tensor: 输入张量的GELU激活函数值
"""
# cal formula are:
# gelu of x is 0.5 * x * (1.0 + tanh(y))
# y is 'sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)'
@ -29,20 +39,33 @@ class GeLU(Expander):
input_x = self.inputs[0]
# cal y
# 计算 input_x 的平方
mul_0 = graph_builder.emit('Mul', [input_x, input_x])
# 计算 input_x 的立方
pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
# 创建一个 CSVALUE 常量
const_csvalue = graph_builder.value(pow_0.dtype, self.CSVALUE)
# 计算 pow_0 和 CSVALUE 的乘积
mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
# 计算 input_x 和 mul_1 的和
tanh_res = graph_builder.emit('Add', [input_x, mul_1])
# 创建一个 CSVALUE_SQRT_TWO_DIV_PI 常量
const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
# 计算 tanh_res 和 CSVALUE_SQRT_TWO_DIV_PI 的乘积
y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
# cal gelu(x)
# 计算 y 的 tanh 值
tanh_y = graph_builder.emit('Tanh', [y])
# 创建一个 1 常量
const_one = graph_builder.value(tanh_y.dtype, 1)
# 创建一个 0.5 常量
const_half = graph_builder.value(tanh_y.dtype, 0.5)
# 计算 tanh_y 和 1 的和
tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
# 计算 input_x 和 tanh_y_add_one 的乘积
mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
# 计算 const_half 和 mul_x 的乘积
result = graph_builder.emit('Mul', [const_half, mul_x])
return result

Loading…
Cancel
Save