diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py index 7e97c455..198120d6 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py @@ -16,20 +16,44 @@ from ._utils import Expander +# 定义一个Erfc类,继承自Expander类 class Erfc(Expander): """Erfc expander""" def _expand(self, graph_builder): + """ + 对输入数据进行扩展处理。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展处理后的结果。 + + """ + # 获取输入数据 input_x = self.inputs[0] + # 初始化结果 result = None + # 如果输入数据的类型是float16 if input_x.dtype == "float16": + # 创建一个float32类型的常量1 const_one = graph_builder.value("float32", 1) + # 将输入数据转换为float32类型 input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"}) + # 计算输入数据的erf值 erf_result = graph_builder.emit('Erf', [input_x]) + # 计算结果 result = graph_builder.emit('Sub', [const_one, erf_result]) + # 将结果转换为float16类型 result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"}) + # 返回结果 return result + # 创建一个与输入数据类型相同的常量1 const_one = graph_builder.value(input_x.dtype, 1) + # 计算输入数据的erf值 erf_result = graph_builder.emit('Erf', [input_x]) + # 计算结果 result = graph_builder.emit('Sub', [const_one, erf_result]) + # 返回结果 return result