add comments for _extends\graph_kernel\expanders\clip_by_norm_no_div_sum.py

branch-yixin
yixin 7 months ago
parent 2c4a524a6a
commit 8388b30ee2

@ -21,13 +21,28 @@ class ClipByNormNoDivSum(Expander):
"""ClipByNormNoDivSum expander"""
def _expand(self, graph_builder):
"""
对输入的张量进行计算返回计算结果
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成计算图
Returns:
Tensor: 计算结果张量
"""
input_x0, input_x1, input_x2, input_x3 = self.inputs
# cal result
# 计算大于结果
greater_res = graph_builder.emit('Greater', [input_x0, input_x1])
# 根据大于结果选择input_x0或input_x2
select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2])
# 计算select_res0的平方根
sqrt_res = graph_builder.emit('Sqrt', [select_res0])
# 根据大于结果选择sqrt_res或input_x0
select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0])
# 计算select_res1和input_x3的最大值
result = graph_builder.emit('Maximum', [select_res1, input_x3])
return result

Loading…
Cancel
Save