|
|
|
@ -13,20 +13,47 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
"""generate json desc for addn"""
|
|
|
|
|
# 导入GraphKernelUnsupportedException异常类
|
|
|
|
|
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
|
|
|
# 导入Expander和ExpanderInfoValidator类
|
|
|
|
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 使用VLD.check_all_formats_same装饰器,确保所有输入格式相同
|
|
|
|
|
@VLD.check_all_formats_same
|
|
|
|
|
class AddN(Expander):
|
|
|
|
|
"""Expand AddN to multiple Adds"""
|
|
|
|
|
|
|
|
|
|
# 检查输入数量是否大于1
|
|
|
|
|
def _check(self):
|
|
|
|
|
"""
|
|
|
|
|
检查输入的数量是否满足要求。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GKException: 如果输入的数量小于2,则抛出GKException异常。
|
|
|
|
|
"""
|
|
|
|
|
if len(self.inputs) < 2:
|
|
|
|
|
raise GKException("For 'AddN', the inputs num should be greater than 1, but got {}"
|
|
|
|
|
.format(len(self.inputs)))
|
|
|
|
|
|
|
|
|
|
# 将AddN展开为多个Add操作
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
对输入张量进行逐元素加法运算。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象,用于生成图节点。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 逐元素加法运算后的结果张量。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
result = self.inputs[0]
|
|
|
|
|
for inp in self.inputs[1:]:
|
|
|
|
|
result = graph_builder.emit('Add', [result, inp])
|
|
|
|
|