_extends\graph_kernel\expanders\erfc.py

branch-yixin
yixin 8 months ago
parent 0015c083a7
commit 979f67f6fa

@ -16,20 +16,44 @@
from ._utils import Expander
# 定义一个Erfc类继承自Expander类
class Erfc(Expander):
"""Erfc expander"""
def _expand(self, graph_builder):
"""
对输入数据进行扩展处理
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展处理后的结果
"""
# 获取输入数据
input_x = self.inputs[0]
# 初始化结果
result = None
# 如果输入数据的类型是float16
if input_x.dtype == "float16":
# 创建一个float32类型的常量1
const_one = graph_builder.value("float32", 1)
# 将输入数据转换为float32类型
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"})
# 计算输入数据的erf值
erf_result = graph_builder.emit('Erf', [input_x])
# 计算结果
result = graph_builder.emit('Sub', [const_one, erf_result])
# 将结果转换为float16类型
result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"})
# 返回结果
return result
# 创建一个与输入数据类型相同的常量1
const_one = graph_builder.value(input_x.dtype, 1)
# 计算输入数据的erf值
erf_result = graph_builder.emit('Erf', [input_x])
# 计算结果
result = graph_builder.emit('Sub', [const_one, erf_result])
# 返回结果
return result

Loading…
Cancel
Save