|
|
|
@ -17,34 +17,84 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx
|
|
|
|
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @VLD.check_all_formats_same:检查所有格式的相同性
|
|
|
|
|
@VLD.check_all_formats_same
|
|
|
|
|
class EqualCount(Expander):
|
|
|
|
|
"""EqualCount expander"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, expand_info):
|
|
|
|
|
"""
|
|
|
|
|
初始化方法。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
expand_info (dict): 扩展信息字典。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
None
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# 调用父类的初始化方法
|
|
|
|
|
super().__init__(expand_info)
|
|
|
|
|
# 获取输入x的形状
|
|
|
|
|
self.shape_x = self.inputs[0]['shape']
|
|
|
|
|
# 获取输入y的形状
|
|
|
|
|
self.shape_y = self.inputs[1]['shape']
|
|
|
|
|
# 获取输入x的数据类型
|
|
|
|
|
self.dtype_x = self.inputs[0]['data_type']
|
|
|
|
|
# 获取输入y的数据类型
|
|
|
|
|
self.dtype_y = self.inputs[1]['data_type']
|
|
|
|
|
|
|
|
|
|
def _check(self):
|
|
|
|
|
"""
|
|
|
|
|
检查输入的两个张量是否具有相同的形状和数据类型。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GKException: 如果两个张量的形状不同,则引发异常,异常信息中包含两个张量的形状。
|
|
|
|
|
GKException: 如果两个张量的数据类型不同,则引发异常,异常信息中包含两个张量的数据类型。
|
|
|
|
|
"""
|
|
|
|
|
# 判断输入的形状是否相同
|
|
|
|
|
if self.shape_x != self.shape_y:
|
|
|
|
|
# 如果不相同,抛出异常
|
|
|
|
|
raise GKException("For 'EqualCount', the inputs shape should be same, but got {} and {}"
|
|
|
|
|
.format(self.shape_x, self.shape_y))
|
|
|
|
|
# 判断输入的数据类型是否相同
|
|
|
|
|
if self.dtype_x != self.dtype_y:
|
|
|
|
|
# 如果不相同,抛出异常
|
|
|
|
|
raise GKException("For 'EqualCount', the inputs data type should be same, but got {} and {}"
|
|
|
|
|
.format(self.dtype_x, self.dtype_y))
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
扩展输入维度的方法。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder: 图构建器对象,用于生成计算图。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
扩展后的张量。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# 获取输入张量
|
|
|
|
|
input_x = self.inputs[0]
|
|
|
|
|
input_y = self.inputs[1]
|
|
|
|
|
|
|
|
|
|
# 比较输入张量是否相等
|
|
|
|
|
eql_val = graph_builder.emit('Equal', [input_x, input_y])
|
|
|
|
|
# 将比较结果转换为float32类型
|
|
|
|
|
cast_val = graph_builder.emit('Cast', [eql_val], attrs={'dst_type': 'float32'})
|
|
|
|
|
# 获取输入张量的维度
|
|
|
|
|
axis = list(range(len(input_x.shape)))
|
|
|
|
|
# 对比较结果进行求和
|
|
|
|
|
result = graph_builder.emit('ReduceSum', [cast_val], attrs={'reduce_axis': axis, 'keep_dims': True})
|
|
|
|
|
|
|
|
|
|
# 如果求和结果的数据类型与输入张量的数据类型不同,则将求和结果转换为输入张量的数据类型
|
|
|
|
|
if result.dtype != input_x.dtype:
|
|
|
|
|
result = graph_builder.emit('Cast', [result], attrs={'dst_type': input_x.dtype})
|
|
|
|
|
# 返回求和结果
|
|
|
|
|
return result
|
|
|
|
|