|
|
|
@ -16,20 +16,44 @@
|
|
|
|
|
from ._utils import Expander
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 定义一个Erfc类,继承自Expander类
|
|
|
|
|
class Erfc(Expander):
|
|
|
|
|
"""Erfc expander"""
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
对输入数据进行扩展处理。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 扩展处理后的结果。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# 获取输入数据
|
|
|
|
|
input_x = self.inputs[0]
|
|
|
|
|
# 初始化结果
|
|
|
|
|
result = None
|
|
|
|
|
# 如果输入数据的类型是float16
|
|
|
|
|
if input_x.dtype == "float16":
|
|
|
|
|
# 创建一个float32类型的常量1
|
|
|
|
|
const_one = graph_builder.value("float32", 1)
|
|
|
|
|
# 将输入数据转换为float32类型
|
|
|
|
|
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"})
|
|
|
|
|
# 计算输入数据的erf值
|
|
|
|
|
erf_result = graph_builder.emit('Erf', [input_x])
|
|
|
|
|
# 计算结果
|
|
|
|
|
result = graph_builder.emit('Sub', [const_one, erf_result])
|
|
|
|
|
# 将结果转换为float16类型
|
|
|
|
|
result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"})
|
|
|
|
|
# 返回结果
|
|
|
|
|
return result
|
|
|
|
|
# 创建一个与输入数据类型相同的常量1
|
|
|
|
|
const_one = graph_builder.value(input_x.dtype, 1)
|
|
|
|
|
# 计算输入数据的erf值
|
|
|
|
|
erf_result = graph_builder.emit('Erf', [input_x])
|
|
|
|
|
# 计算结果
|
|
|
|
|
result = graph_builder.emit('Sub', [const_one, erf_result])
|
|
|
|
|
# 返回结果
|
|
|
|
|
return result
|
|
|
|
|