|
|
|
@ -20,9 +20,23 @@ class FusedMulAdd(Expander):
|
|
|
|
|
"""FusedMulAdd expander"""
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
执行扩展操作。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 执行加法操作后的结果。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# 获取输入
|
|
|
|
|
input_x, input_y, input_z = self.inputs
|
|
|
|
|
|
|
|
|
|
# 发射乘法操作
|
|
|
|
|
mul_res = graph_builder.emit('Mul', [input_x, input_y])
|
|
|
|
|
# 发射加法操作
|
|
|
|
|
result = graph_builder.emit('Add', [mul_res, input_z])
|
|
|
|
|
|
|
|
|
|
# 返回结果
|
|
|
|
|
return result
|
|
|
|
|