add comments for _extends\graph_kernel\expanders\addn.py

branch-yixin
yixin 2 months ago
parent 1531f33582
commit e414d6025d

@ -13,20 +13,47 @@
# limitations under the License.
# ===========================================================================
"""generate json desc for addn"""
# 导入GraphKernelUnsupportedException异常类
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
# 导入Expander和ExpanderInfoValidator类
from ._utils import Expander, ExpanderInfoValidator as VLD
# 使用VLD.check_all_formats_same装饰器确保所有输入格式相同
@VLD.check_all_formats_same
class AddN(Expander):
"""Expand AddN to multiple Adds"""
# 检查输入数量是否大于1
def _check(self):
"""
检查输入的数量是否满足要求
Args:
Returns:
Raises:
GKException: 如果输入的数量小于2则抛出GKException异常
"""
if len(self.inputs) < 2:
raise GKException("For 'AddN', the inputs num should be greater than 1, but got {}"
.format(len(self.inputs)))
# 将AddN展开为多个Add操作
def _expand(self, graph_builder):
"""
对输入张量进行逐元素加法运算
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成图节点
Returns:
Tensor: 逐元素加法运算后的结果张量
"""
result = self.inputs[0]
for inp in self.inputs[1:]:
result = graph_builder.emit('Add', [result, inp])

Loading…
Cancel
Save