_extends\graph_kernel\expanders\equal_count.py

branch-yixin
yixin 2 months ago
parent c64739f456
commit 0015c083a7

@ -17,34 +17,84 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx
from ._utils import Expander, ExpanderInfoValidator as VLD
# @VLD.check_all_formats_same检查所有格式的相同性
@VLD.check_all_formats_same
class EqualCount(Expander):
"""EqualCount expander"""
def __init__(self, expand_info):
"""
初始化方法
Args:
expand_info (dict): 扩展信息字典
Returns:
None
"""
# 调用父类的初始化方法
super().__init__(expand_info)
# 获取输入x的形状
self.shape_x = self.inputs[0]['shape']
# 获取输入y的形状
self.shape_y = self.inputs[1]['shape']
# 获取输入x的数据类型
self.dtype_x = self.inputs[0]['data_type']
# 获取输入y的数据类型
self.dtype_y = self.inputs[1]['data_type']
def _check(self):
"""
检查输入的两个张量是否具有相同的形状和数据类型
Args:
Returns:
Raises:
GKException: 如果两个张量的形状不同则引发异常异常信息中包含两个张量的形状
GKException: 如果两个张量的数据类型不同则引发异常异常信息中包含两个张量的数据类型
"""
# 判断输入的形状是否相同
if self.shape_x != self.shape_y:
# 如果不相同,抛出异常
raise GKException("For 'EqualCount', the inputs shape should be same, but got {} and {}"
.format(self.shape_x, self.shape_y))
# 判断输入的数据类型是否相同
if self.dtype_x != self.dtype_y:
# 如果不相同,抛出异常
raise GKException("For 'EqualCount', the inputs data type should be same, but got {} and {}"
.format(self.dtype_x, self.dtype_y))
def _expand(self, graph_builder):
"""
扩展输入维度的方法
Args:
graph_builder: 图构建器对象用于生成计算图
Returns:
扩展后的张量
"""
# 获取输入张量
input_x = self.inputs[0]
input_y = self.inputs[1]
# 比较输入张量是否相等
eql_val = graph_builder.emit('Equal', [input_x, input_y])
# 将比较结果转换为float32类型
cast_val = graph_builder.emit('Cast', [eql_val], attrs={'dst_type': 'float32'})
# 获取输入张量的维度
axis = list(range(len(input_x.shape)))
# 对比较结果进行求和
result = graph_builder.emit('ReduceSum', [cast_val], attrs={'reduce_axis': axis, 'keep_dims': True})
# 如果求和结果的数据类型与输入张量的数据类型不同,则将求和结果转换为输入张量的数据类型
if result.dtype != input_x.dtype:
result = graph_builder.emit('Cast', [result], attrs={'dst_type': input_x.dtype})
# 返回求和结果
return result

Loading…
Cancel
Save