From 8388b30ee2f3521acc041dcb14c8b15a815583d3 Mon Sep 17 00:00:00 2001 From: yixin <2050485123@qq.com> Date: Wed, 25 Dec 2024 10:25:24 +0800 Subject: [PATCH] add comments for _extends\graph_kernel\expanders\clip_by_norm_no_div_sum.py --- .../expanders/clip_by_norm_no_div_sum.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py index e6c345f4..81fe4cbb 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py @@ -21,13 +21,28 @@ class ClipByNormNoDivSum(Expander): """ClipByNormNoDivSum expander""" def _expand(self, graph_builder): + """ + 对输入的张量进行计算,返回计算结果。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于生成计算图。 + + Returns: + Tensor: 计算结果张量。 + + """ input_x0, input_x1, input_x2, input_x3 = self.inputs # cal result + # 计算大于结果 greater_res = graph_builder.emit('Greater', [input_x0, input_x1]) + # 根据大于结果选择input_x0或input_x2 select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2]) + # 计算select_res0的平方根 sqrt_res = graph_builder.emit('Sqrt', [select_res0]) + # 根据大于结果选择sqrt_res或input_x0 select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0]) + # 计算select_res1和input_x3的最大值 result = graph_builder.emit('Maximum', [select_res1, input_x3]) return result