From 979f67f6fab82c0be851605189c5ad6241743398 Mon Sep 17 00:00:00 2001 From: yixin <2050485123@qq.com> Date: Wed, 25 Dec 2024 10:34:17 +0800 Subject: [PATCH] _extends\graph_kernel\expanders\erfc.py --- .../_extends/graph_kernel/expanders/erfc.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py index 7e97c455..198120d6 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py @@ -16,20 +16,44 @@ from ._utils import Expander +# 定义一个Erfc类,继承自Expander类 class Erfc(Expander): """Erfc expander""" def _expand(self, graph_builder): + """ + 对输入数据进行扩展处理。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展处理后的结果。 + + """ + # 获取输入数据 input_x = self.inputs[0] + # 初始化结果 result = None + # 如果输入数据的类型是float16 if input_x.dtype == "float16": + # 创建一个float32类型的常量1 const_one = graph_builder.value("float32", 1) + # 将输入数据转换为float32类型 input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"}) + # 计算输入数据的erf值 erf_result = graph_builder.emit('Erf', [input_x]) + # 计算结果 result = graph_builder.emit('Sub', [const_one, erf_result]) + # 将结果转换为float16类型 result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"}) + # 返回结果 return result + # 创建一个与输入数据类型相同的常量1 const_one = graph_builder.value(input_x.dtype, 1) + # 计算输入数据的erf值 erf_result = graph_builder.emit('Erf', [input_x]) + # 计算结果 result = graph_builder.emit('Sub', [const_one, erf_result]) + # 返回结果 return result