diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu.py index 24fe81bc..24aeab68 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu.py @@ -22,6 +22,16 @@ class GeLU(Expander): CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi) def _expand(self, graph_builder): + """ + 计算输入张量的GELU激活函数值。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于生成计算图。 + + Returns: + Tensor: 输入张量的GELU激活函数值。 + + """ # cal formula are: # gelu of x is 0.5 * x * (1.0 + tanh(y)) # y is 'sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)' @@ -29,20 +39,33 @@ class GeLU(Expander): input_x = self.inputs[0] # cal y + # 计算 input_x 的平方 mul_0 = graph_builder.emit('Mul', [input_x, input_x]) + # 计算 input_x 的立方 pow_0 = graph_builder.emit('Mul', [mul_0, input_x]) + # 创建一个 CSVALUE 常量 const_csvalue = graph_builder.value(pow_0.dtype, self.CSVALUE) + # 计算 pow_0 和 CSVALUE 的乘积 mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue]) + # 计算 input_x 和 mul_1 的和 tanh_res = graph_builder.emit('Add', [input_x, mul_1]) + # 创建一个 CSVALUE_SQRT_TWO_DIV_PI 常量 const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, self.CSVALUE_SQRT_TWO_DIV_PI) + # 计算 tanh_res 和 CSVALUE_SQRT_TWO_DIV_PI 的乘积 y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi]) # cal gelu(x) + # 计算 y 的 tanh 值 tanh_y = graph_builder.emit('Tanh', [y]) + # 创建一个 1 常量 const_one = graph_builder.value(tanh_y.dtype, 1) + # 创建一个 0.5 常量 const_half = graph_builder.value(tanh_y.dtype, 0.5) + # 计算 tanh_y 和 1 的和 tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one]) + # 计算 input_x 和 tanh_y_add_one 的乘积 mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one]) + # 计算 const_half 和 mul_x 的乘积 result = graph_builder.emit('Mul', [const_half, mul_x]) return result