|
|
|
@ -21,13 +21,28 @@ class ClipByNormNoDivSum(Expander):
|
|
|
|
|
"""ClipByNormNoDivSum expander"""
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
对输入的张量进行计算,返回计算结果。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象,用于生成计算图。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 计算结果张量。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
input_x0, input_x1, input_x2, input_x3 = self.inputs
|
|
|
|
|
|
|
|
|
|
# cal result
|
|
|
|
|
# 计算大于结果
|
|
|
|
|
greater_res = graph_builder.emit('Greater', [input_x0, input_x1])
|
|
|
|
|
# 根据大于结果选择input_x0或input_x2
|
|
|
|
|
select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2])
|
|
|
|
|
# 计算select_res0的平方根
|
|
|
|
|
sqrt_res = graph_builder.emit('Sqrt', [select_res0])
|
|
|
|
|
# 根据大于结果选择sqrt_res或input_x0
|
|
|
|
|
select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0])
|
|
|
|
|
# 计算select_res1和input_x3的最大值
|
|
|
|
|
result = graph_builder.emit('Maximum', [select_res1, input_x3])
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|