|
|
|
@ -25,24 +25,32 @@ def infer(op_name, inputs, attrs):
|
|
|
|
|
"""infer shape dtype and format"""
|
|
|
|
|
|
|
|
|
|
def _create_opinfer():
|
|
|
|
|
# 获取当前模块
|
|
|
|
|
self_module = sys.modules.get(__name__, None)
|
|
|
|
|
# 如果当前模块为空,则抛出异常
|
|
|
|
|
if self_module is None:
|
|
|
|
|
raise GKException("OpInfo does not support op {}".format(op_name))
|
|
|
|
|
|
|
|
|
|
# 如果当前模块有op_name属性,则获取该属性
|
|
|
|
|
if hasattr(self_module, op_name):
|
|
|
|
|
op_cls = getattr(self_module, op_name)
|
|
|
|
|
return op_cls(op_name, inputs, attrs)
|
|
|
|
|
# common infer
|
|
|
|
|
# 定义一个字典,将PrimLib中的iter_type映射到对应的类名
|
|
|
|
|
class_name_map = {
|
|
|
|
|
PrimLib.ELEMWISE: "_Elemwise",
|
|
|
|
|
PrimLib.REDUCE: "_Reduce",
|
|
|
|
|
}
|
|
|
|
|
# 获取op_name对应的iter_type
|
|
|
|
|
cls_name = class_name_map.get(PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type, None)
|
|
|
|
|
# 如果没有对应的iter_type,则抛出异常
|
|
|
|
|
if not cls_name:
|
|
|
|
|
raise GKException("OpInfo does not support op {}".format(op_name))
|
|
|
|
|
# 获取对应的类
|
|
|
|
|
op_cls = getattr(self_module, cls_name)
|
|
|
|
|
return op_cls(op_name, inputs, attrs)
|
|
|
|
|
|
|
|
|
|
# 返回infer方法
|
|
|
|
|
return _create_opinfer().infer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -55,49 +63,65 @@ class OpInfer:
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, name, inputs, attrs):
|
|
|
|
|
# 初始化函数,传入参数name、inputs、attrs
|
|
|
|
|
self.name = name
|
|
|
|
|
self.inputs = inputs
|
|
|
|
|
self.attrs = attrs
|
|
|
|
|
|
|
|
|
|
def infer(self):
|
|
|
|
|
"""Infer shape, type and format by op inputs"""
|
|
|
|
|
# 根据op的输入推断shape、type和format
|
|
|
|
|
self._check()
|
|
|
|
|
return self._infer_shape(), self._infer_type(), self._infer_format()
|
|
|
|
|
|
|
|
|
|
def _infer_shape(self):
|
|
|
|
|
# 根据op的输入推断shape
|
|
|
|
|
return self.inputs[0].shape
|
|
|
|
|
|
|
|
|
|
def _infer_type(self):
|
|
|
|
|
# 根据op的输入推断type
|
|
|
|
|
return self.inputs[0].dtype
|
|
|
|
|
|
|
|
|
|
def _infer_format(self):
|
|
|
|
|
# 根据op的输入推断format
|
|
|
|
|
return self.inputs[0].data_format
|
|
|
|
|
|
|
|
|
|
def _check(self):
|
|
|
|
|
# 检查shape、type和format
|
|
|
|
|
self._check_shape()
|
|
|
|
|
self._check_type()
|
|
|
|
|
self._check_format()
|
|
|
|
|
|
|
|
|
|
def _check_shape(self):
|
|
|
|
|
# 检查shape
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def _check_type(self):
|
|
|
|
|
"""check all dtypes are same"""
|
|
|
|
|
# 获取第一个输入的dtype
|
|
|
|
|
dtype = self.inputs[0].dtype
|
|
|
|
|
# 遍历剩下的输入
|
|
|
|
|
for i, t in enumerate(self.inputs[1:]):
|
|
|
|
|
# 如果当前输入的dtype与第一个输入的dtype不同,则抛出异常
|
|
|
|
|
if t.dtype != dtype:
|
|
|
|
|
raise GKException(
|
|
|
|
|
"Incompatible data type between input {}({}) and {}({})".format(0, dtype, i + 1, t.dtype))
|
|
|
|
|
|
|
|
|
|
def _check_format(self):
|
|
|
|
|
"""check formats are compatible. only DefaultFormat is compatible with others"""
|
|
|
|
|
# 获取第一个输入的data_format
|
|
|
|
|
result = self.inputs[0].data_format
|
|
|
|
|
# 初始化i为0
|
|
|
|
|
i = 0
|
|
|
|
|
# 遍历剩下的输入
|
|
|
|
|
for j, t in enumerate(self.inputs[1:]):
|
|
|
|
|
# 如果当前输入的data_format与第一个输入的data_format不同,则进行判断
|
|
|
|
|
if t.data_format != result:
|
|
|
|
|
# 如果第一个输入的data_format和当前输入的data_format都不是DefaultFormat,则抛出异常
|
|
|
|
|
if DF.DEFAULT not in (result, t.data_format):
|
|
|
|
|
raise GKException("Incompatible format between input {}({}) and {}({})".format(
|
|
|
|
|
i, result, j + 1, t.data_format))
|
|
|
|
|
# 如果第一个输入的data_format是DefaultFormat,则将result设置为当前输入的data_format,并将i设置为j+1
|
|
|
|
|
if result == DF.DEFAULT:
|
|
|
|
|
result = t.data_format
|
|
|
|
|
i = j + 1
|
|
|
|
@ -109,17 +133,26 @@ class _Elemwise(OpInfer):
|
|
|
|
|
@staticmethod
|
|
|
|
|
def broadcast_shape(shapes):
|
|
|
|
|
"""deduce broadcast shape using same rules as numpy"""
|
|
|
|
|
# 计算所有shape的最大维度
|
|
|
|
|
dim_size = max(len(shape) for shape in shapes)
|
|
|
|
|
# 将所有shape扩展到最大维度,不足的部分用1填充
|
|
|
|
|
align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes]
|
|
|
|
|
# 初始化输出shape为全1
|
|
|
|
|
out_shape = [1] * dim_size
|
|
|
|
|
# 遍历每个维度
|
|
|
|
|
for i in range(dim_size):
|
|
|
|
|
# 遍历每个shape
|
|
|
|
|
for align_shape in align_shapes:
|
|
|
|
|
# 如果当前维度为1,则跳过
|
|
|
|
|
if align_shape[i] == 1:
|
|
|
|
|
continue
|
|
|
|
|
# 如果输出shape当前维度为1,则将输出shape当前维度设置为当前shape当前维度的值
|
|
|
|
|
if out_shape[i] == 1:
|
|
|
|
|
out_shape[i] = align_shape[i]
|
|
|
|
|
# 如果输出shape当前维度和当前shape当前维度不相等,则抛出异常
|
|
|
|
|
elif out_shape[i] != align_shape[i]:
|
|
|
|
|
raise GKException("Input shapes {} can not broadcast.".format(shapes))
|
|
|
|
|
# 返回输出shape
|
|
|
|
|
return out_shape
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@ -174,9 +207,12 @@ class _Elemwise(OpInfer):
|
|
|
|
|
.format(inputs_format))
|
|
|
|
|
|
|
|
|
|
def _infer_format(self):
|
|
|
|
|
# 遍历输入张量
|
|
|
|
|
for tensor in self.inputs:
|
|
|
|
|
# 如果张量的数据格式不是默认格式,则返回该数据格式
|
|
|
|
|
if tensor.data_format != DF.DEFAULT:
|
|
|
|
|
return tensor.data_format
|
|
|
|
|
# 如果所有输入张量的数据格式都是默认格式,则返回默认格式
|
|
|
|
|
return DF.DEFAULT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -184,6 +220,7 @@ class _Reduce(OpInfer):
|
|
|
|
|
"""Common infer for reduction operators"""
|
|
|
|
|
|
|
|
|
|
def _check(self):
|
|
|
|
|
# 调用父类的方法
|
|
|
|
|
super(_Reduce, self)._check()
|
|
|
|
|
# check reduce axis in the range [-len, len)
|
|
|
|
|
shape_len = len(self.inputs[0].shape)
|
|
|
|
@ -195,21 +232,29 @@ class _Reduce(OpInfer):
|
|
|
|
|
"Reduce axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis))
|
|
|
|
|
|
|
|
|
|
def _infer_shape(self):
|
|
|
|
|
# 深度拷贝输入的形状
|
|
|
|
|
shape = copy.deepcopy(self.inputs[0].shape)
|
|
|
|
|
# 获取reduce_axis属性
|
|
|
|
|
axis = self.attrs['reduce_axis']
|
|
|
|
|
|
|
|
|
|
# 如果axis是整数,则将其转换为列表
|
|
|
|
|
if isinstance(axis, int):
|
|
|
|
|
axis = [axis]
|
|
|
|
|
# 如果axis中的元素小于0,则将其转换为非负数
|
|
|
|
|
if any(i < 0 for i in axis):
|
|
|
|
|
# change the axis to non-negative number.
|
|
|
|
|
# 将axis中的负数转换为正数
|
|
|
|
|
axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis))
|
|
|
|
|
# 将axis排序
|
|
|
|
|
self.attrs['reduce_axis'] = sorted(axis)
|
|
|
|
|
|
|
|
|
|
# 如果keep_dims为True,则将axis中的维度设置为1
|
|
|
|
|
if self.attrs['keep_dims']:
|
|
|
|
|
for i in axis:
|
|
|
|
|
shape[i] = 1
|
|
|
|
|
return shape
|
|
|
|
|
|
|
|
|
|
# 如果keep_dims为False,则将axis中的维度从shape中移除
|
|
|
|
|
real_shape = []
|
|
|
|
|
for i, s in enumerate(shape):
|
|
|
|
|
if i not in axis:
|
|
|
|
@ -223,10 +268,14 @@ class _Reduce(OpInfer):
|
|
|
|
|
class _Reshape(OpInfer):
|
|
|
|
|
"""Common infer for reshape operators, should not be instantiated"""
|
|
|
|
|
|
|
|
|
|
# 定义一个函数,用于推断形状
|
|
|
|
|
def _infer_shape(self):
|
|
|
|
|
# 抛出一个异常,提示子类需要实现这个函数
|
|
|
|
|
raise GKException("_infer_shape should be implemented by subclass")
|
|
|
|
|
|
|
|
|
|
# 定义一个函数,用于推断格式
|
|
|
|
|
def _infer_format(self):
|
|
|
|
|
# 如果attrs中不存在"format"这个属性,则返回DF.DEFAULT,否则返回attrs中"format"的值
|
|
|
|
|
return DF.DEFAULT if "format" not in self.attrs else self.attrs["format"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -234,14 +283,20 @@ class Reshape(_Reshape):
|
|
|
|
|
"""Reshape op infer"""
|
|
|
|
|
|
|
|
|
|
def _check_shape(self):
|
|
|
|
|
# 获取输入形状
|
|
|
|
|
input_shape = self.inputs[0].shape
|
|
|
|
|
# 获取输出形状
|
|
|
|
|
output_shape = self.attrs["shape"]
|
|
|
|
|
# 计算输入形状的乘积
|
|
|
|
|
size_before_reshape = prod_reduce(lambda x, y: x * y, input_shape)
|
|
|
|
|
# 计算输出形状的乘积
|
|
|
|
|
size_after_reshape = prod_reduce(lambda x, y: x * y, output_shape)
|
|
|
|
|
# 如果输入形状的乘积不等于输出形状的乘积,则抛出异常
|
|
|
|
|
if size_before_reshape != size_after_reshape:
|
|
|
|
|
raise GKException("For 'Reshape', can not reshape {} to {}".format(input_shape, output_shape))
|
|
|
|
|
|
|
|
|
|
def _infer_shape(self):
|
|
|
|
|
# 返回输出形状
|
|
|
|
|
return self.attrs["shape"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -249,6 +304,7 @@ class Cast(_Elemwise):
|
|
|
|
|
"""Cast op infer"""
|
|
|
|
|
|
|
|
|
|
def _infer_type(self):
|
|
|
|
|
# 返回dst_type属性
|
|
|
|
|
return self.attrs["dst_type"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -256,29 +312,38 @@ class InplaceAssign(_Elemwise):
|
|
|
|
|
"""InplaceAssign op infer"""
|
|
|
|
|
|
|
|
|
|
def _infer_shape(self):
|
|
|
|
|
# 返回第3个输入的shape属性
|
|
|
|
|
return self.inputs[2].shape
|
|
|
|
|
|
|
|
|
|
def _infer_type(self):
|
|
|
|
|
# 返回第3个输入的dtype属性
|
|
|
|
|
return self.inputs[2].dtype
|
|
|
|
|
|
|
|
|
|
def _infer_format(self):
|
|
|
|
|
# 返回第3个输入的data_format属性
|
|
|
|
|
return self.inputs[2].data_format
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BroadcastTo(OpInfer):
|
|
|
|
|
"""BroadcastTo op infer"""
|
|
|
|
|
|
|
|
|
|
# 定义一个函数,用于推断形状
|
|
|
|
|
def _infer_shape(self):
|
|
|
|
|
# 返回self.attrs字典中的"shape"键对应的值
|
|
|
|
|
return self.attrs["shape"]
|
|
|
|
|
|
|
|
|
|
# 定义一个函数,用于推断格式
|
|
|
|
|
def _infer_format(self):
|
|
|
|
|
# 返回self.inputs列表中第一个元素的data_format属性
|
|
|
|
|
return self.inputs[0].data_format
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _CompareOp(_Elemwise):
|
|
|
|
|
"""Compare operators"""
|
|
|
|
|
|
|
|
|
|
# 定义一个函数,用于推断类型
|
|
|
|
|
def _infer_type(self):
|
|
|
|
|
# 返回类型为bool
|
|
|
|
|
return "bool"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -286,11 +351,14 @@ class CImag(OpInfer):
|
|
|
|
|
"""CImag op infer"""
|
|
|
|
|
|
|
|
|
|
def _check_type(self):
|
|
|
|
|
# 检查输入数据的类型是否为complex64
|
|
|
|
|
if self.inputs[0].dtype != "complex64":
|
|
|
|
|
# 如果不是,则抛出异常
|
|
|
|
|
raise GKException(
|
|
|
|
|
"For 'CImag', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype))
|
|
|
|
|
|
|
|
|
|
def _infer_type(self):
|
|
|
|
|
# 返回数据类型为float32
|
|
|
|
|
return "float32"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -298,11 +366,14 @@ class CReal(OpInfer):
|
|
|
|
|
"""CReal op infer"""
|
|
|
|
|
|
|
|
|
|
def _check_type(self):
|
|
|
|
|
# 检查输入数据的类型是否为complex64
|
|
|
|
|
if self.inputs[0].dtype != "complex64":
|
|
|
|
|
# 如果不是,则抛出异常
|
|
|
|
|
raise GKException(
|
|
|
|
|
"For 'CReal', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype))
|
|
|
|
|
|
|
|
|
|
def _infer_type(self):
|
|
|
|
|
# 返回数据类型为float32
|
|
|
|
|
return "float32"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -310,14 +381,17 @@ class Complex(OpInfer):
|
|
|
|
|
"""Complex op infer"""
|
|
|
|
|
|
|
|
|
|
def _check_type(self):
|
|
|
|
|
# 检查输入数据的类型是否为float32
|
|
|
|
|
if self.inputs[0].dtype != "float32":
|
|
|
|
|
raise GKException(
|
|
|
|
|
"For 'Complex', input[0] should be of type float32, but got {}".format(self.inputs[0].dtype))
|
|
|
|
|
# 检查输入数据的类型是否一致
|
|
|
|
|
if self.inputs[0].dtype != self.inputs[1].dtype:
|
|
|
|
|
raise GKException("For 'Complex', inputs data type mismatch ({} vs {})"
|
|
|
|
|
.format(self.inputs[0].dtype, self.inputs[1].dtype))
|
|
|
|
|
|
|
|
|
|
def _infer_type(self):
|
|
|
|
|
# 返回复数类型
|
|
|
|
|
return "complex64"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -345,40 +419,53 @@ class Select(_Elemwise):
|
|
|
|
|
"""Select op infer"""
|
|
|
|
|
|
|
|
|
|
def _check_type(self):
|
|
|
|
|
# 检查输入数据的类型
|
|
|
|
|
if self.inputs[0].dtype != "bool":
|
|
|
|
|
# 如果输入数据的类型不是bool,则抛出异常
|
|
|
|
|
raise GKException("For 'Select', input[0] should be of type bool, but got {}".format(self.inputs[0].dtype))
|
|
|
|
|
if self.inputs[1].dtype != self.inputs[2].dtype:
|
|
|
|
|
# 如果输入数据的类型不一致,则抛出异常
|
|
|
|
|
raise GKException("For 'Select', input[1] and input[2] data type mismatch ({} vs {})"
|
|
|
|
|
.format(self.inputs[1].dtype, self.inputs[2].dtype))
|
|
|
|
|
|
|
|
|
|
def _infer_type(self):
|
|
|
|
|
# 推断输入数据的类型
|
|
|
|
|
return self.inputs[1].dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_format_any(formats, checked_format):
|
|
|
|
|
"""Check whether input format in formats list"""
|
|
|
|
|
# 检查输入格式是否在formats列表中
|
|
|
|
|
if not isinstance(formats, (list, tuple)):
|
|
|
|
|
# 如果formats不是list或tuple类型,则抛出异常
|
|
|
|
|
raise GKException("formats {} should be of type list or tuple, but got {}.".format(formats, type(formats)))
|
|
|
|
|
if checked_format not in formats:
|
|
|
|
|
# 如果checked_format不在formats列表中,则抛出异常
|
|
|
|
|
raise GKException("Check {} failed: can not find it in {}".format(checked_format, formats))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_nd(data, nd):
|
|
|
|
|
"""Check whether data are nd format"""
|
|
|
|
|
# 检查数据是否为nd格式
|
|
|
|
|
if not isinstance(data, (list, tuple)) or len(data) != nd:
|
|
|
|
|
# 如果数据不是list或tuple类型,或者数据的维度不等于nd,则抛出异常
|
|
|
|
|
raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def conv_had_pad(pad_list, pad_mode):
|
|
|
|
|
"""Check whether conv need to add pad"""
|
|
|
|
|
# 检查pad_list是否为4D list或tuple,如果不是则抛出异常
|
|
|
|
|
if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
|
|
|
|
|
raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list))
|
|
|
|
|
# 如果pad_list的前两个元素不相等或后两个元素不相等,则返回True
|
|
|
|
|
if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
|
|
|
|
|
return True
|
|
|
|
|
# 如果pad_mode不是"VALID"或"valid",则遍历pad_list,如果有元素不为0,则返回True
|
|
|
|
|
if pad_mode not in ["VALID", "valid"]:
|
|
|
|
|
for _, pad in enumerate(pad_list):
|
|
|
|
|
if pad != 0:
|
|
|
|
|
return True
|
|
|
|
|
# 否则返回False
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -386,38 +473,50 @@ class Conv2D(OpInfer):
|
|
|
|
|
"""Conv2D infer"""
|
|
|
|
|
|
|
|
|
|
def _infer_type(self):
|
|
|
|
|
# 如果attrs是dict类型且包含"dst_type"键,则返回"dst_type"的值,否则返回输入的第一个元素的dtype
|
|
|
|
|
if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
|
|
|
|
|
return self.attrs["dst_type"]
|
|
|
|
|
return self.inputs[0].dtype
|
|
|
|
|
|
|
|
|
|
def _infer_shape(self):
|
|
|
|
|
# 将输入的第一个和第二个元素的shape转换为list
|
|
|
|
|
shape_0 = list(self.inputs[0].shape)
|
|
|
|
|
shape_1 = list(self.inputs[1].shape)
|
|
|
|
|
# 检查shape_0和shape_1的维度是否为4
|
|
|
|
|
check_nd(shape_0, 4)
|
|
|
|
|
check_nd(shape_1, 4)
|
|
|
|
|
|
|
|
|
|
# 检查输入的data_format是否为NHWC
|
|
|
|
|
formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]]
|
|
|
|
|
check_format_any(formats, DF.NHWC)
|
|
|
|
|
|
|
|
|
|
# 获取输入的n、h、w和out_channel
|
|
|
|
|
n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0]
|
|
|
|
|
# 获取pad_list和pad_mode
|
|
|
|
|
pad_list = self.attrs["pad_list"]
|
|
|
|
|
pad_mode = self.attrs["pad_mode"]
|
|
|
|
|
# 获取kernel_size、stride和dilation
|
|
|
|
|
kernel_size = self.attrs["kernel_size"]
|
|
|
|
|
stride = self.attrs["stride"]
|
|
|
|
|
dilation = self.attrs["dilation"]
|
|
|
|
|
# 检查pad_list、kernel_size、stride和dilation的维度是否为4、2、4和4
|
|
|
|
|
check_nd(pad_list, 4)
|
|
|
|
|
check_nd(kernel_size, 2)
|
|
|
|
|
check_nd(stride, 4)
|
|
|
|
|
check_nd(dilation, 4)
|
|
|
|
|
|
|
|
|
|
# 调用conv_had_pad函数,判断是否需要pad
|
|
|
|
|
has_pad = conv_had_pad(pad_list, pad_mode)
|
|
|
|
|
# 如果不需要pad,则将pad_list设置为[0, 0, 0, 0]
|
|
|
|
|
if not has_pad:
|
|
|
|
|
pad_list = [0, 0, 0, 0]
|
|
|
|
|
|
|
|
|
|
# 计算输出的h和w
|
|
|
|
|
k_h = (kernel_size[0] - 1) * dilation[-2] + 1
|
|
|
|
|
k_w = (kernel_size[1] - 1) * dilation[-1] + 1
|
|
|
|
|
out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1
|
|
|
|
|
out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1
|
|
|
|
|
# 返回输出的shape
|
|
|
|
|
return [n, out_h, out_w, out_channel]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -425,23 +524,31 @@ class MatMul(OpInfer):
|
|
|
|
|
"""MatMul infer"""
|
|
|
|
|
|
|
|
|
|
def _infer_type(self):
|
|
|
|
|
# 如果attrs是dict类型且包含"dst_type"键,则返回"dst_type"的值,否则返回输入的第一个元素的dtype
|
|
|
|
|
if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
|
|
|
|
|
return self.attrs["dst_type"]
|
|
|
|
|
return self.inputs[0].dtype
|
|
|
|
|
|
|
|
|
|
def _infer_shape(self):
|
|
|
|
|
# 将输入的第一个和第二个元素的shape转换为list
|
|
|
|
|
shape_0 = list(self.inputs[0].shape)
|
|
|
|
|
shape_1 = list(self.inputs[1].shape)
|
|
|
|
|
# 检查shape_0和shape_1的维度是否为2
|
|
|
|
|
if len(shape_0) != 2 or len(shape_1) != 2:
|
|
|
|
|
raise GKException("For 'MatMul', inputs shape must be 2D, but got {}, {}"
|
|
|
|
|
.format(shape_0, shape_1))
|
|
|
|
|
# 获取transpose_a和transpose_b
|
|
|
|
|
transpose_a = self.attrs["transpose_a"]
|
|
|
|
|
transpose_b = self.attrs["transpose_b"]
|
|
|
|
|
# 根据transpose_a和transpose_b获取m、k1、k2和n
|
|
|
|
|
m, k1 = (shape_0[-1], shape_0[-2]) if transpose_a else (shape_0[-2], shape_0[-1])
|
|
|
|
|
k2, n = (shape_1[-1], shape_1[-2]) if transpose_b else (shape_1[-2], shape_1[-1])
|
|
|
|
|
# 如果k1和k2不相等,则抛出异常
|
|
|
|
|
if k1 != k2:
|
|
|
|
|
raise GKException("For 'MatMul', inputs have different k value: {} vs {}".format(k1, k2))
|
|
|
|
|
# 计算输出的shape
|
|
|
|
|
output_shape = [m, n]
|
|
|
|
|
# 返回输出的shape
|
|
|
|
|
return output_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -449,14 +556,20 @@ class PadAkg(OpInfer):
|
|
|
|
|
"""PadAkg infer"""
|
|
|
|
|
|
|
|
|
|
def _infer_shape(self):
|
|
|
|
|
# 将输入的第一个元素的shape转换为list
|
|
|
|
|
shape = list(self.inputs[0].shape)
|
|
|
|
|
# 获取输入的维度
|
|
|
|
|
n = len(shape)
|
|
|
|
|
# 获取pad_before和pad_after
|
|
|
|
|
pad_before = list(self.attrs["head"])
|
|
|
|
|
pad_after = list(self.attrs["tail"])
|
|
|
|
|
# 检查pad_before和pad_after的维度是否与输入的维度相等
|
|
|
|
|
if len(pad_before) != n or len(pad_after) != n:
|
|
|
|
|
raise GKException("For 'PadAkg', input dimension and pad mismatch: {}d vs {}d vs {}d"
|
|
|
|
|
.format(n, len(pad_before), len(pad_after)))
|
|
|
|
|
# 计算输出的shape
|
|
|
|
|
out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)]
|
|
|
|
|
# 返回输出的shape
|
|
|
|
|
return out_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -464,13 +577,19 @@ class UnPadAkg(OpInfer):
|
|
|
|
|
"""UnPadAkg infer"""
|
|
|
|
|
|
|
|
|
|
def _infer_shape(self):
|
|
|
|
|
# 将输入的第一个元素的shape转换为list
|
|
|
|
|
shape = list(self.inputs[0].shape)
|
|
|
|
|
# 获取输入的维度
|
|
|
|
|
n = len(shape)
|
|
|
|
|
# 获取unpad_after
|
|
|
|
|
unpad_after = list(self.attrs["tail"])
|
|
|
|
|
# 检查unpad_after的维度是否与输入的维度相等
|
|
|
|
|
if len(unpad_after) != n:
|
|
|
|
|
raise GKException("For 'UnPadAkg', input dimension and pad mismatch: {}d vs {}d"
|
|
|
|
|
.format(n, len(unpad_after)))
|
|
|
|
|
# 计算输出的shape
|
|
|
|
|
out_shape = [shape[i] - unpad_after[i] for i in range(n)]
|
|
|
|
|
# 返回输出的shape
|
|
|
|
|
return out_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -478,23 +597,32 @@ class Gather(OpInfer):
|
|
|
|
|
"""Gather infer"""
|
|
|
|
|
|
|
|
|
|
def _infer_shape(self):
|
|
|
|
|
# 获取输入的第一个和第二个元素的shape
|
|
|
|
|
input_shape = self.inputs[0].shape
|
|
|
|
|
indices_shape = self.inputs[1].shape
|
|
|
|
|
# 获取axis
|
|
|
|
|
axis = self.attrs['axis']
|
|
|
|
|
# 将输出的shape设置为输入的第一个元素的shape
|
|
|
|
|
output_shape = input_shape
|
|
|
|
|
# 计算indices_shape的维度
|
|
|
|
|
indices_shape_one_dim = 1
|
|
|
|
|
for dim in indices_shape:
|
|
|
|
|
indices_shape_one_dim *= dim
|
|
|
|
|
# 将输出的shape的axis维度设置为indices_shape的维度
|
|
|
|
|
output_shape[axis] = indices_shape_one_dim
|
|
|
|
|
# 返回输出的shape
|
|
|
|
|
return output_shape
|
|
|
|
|
|
|
|
|
|
def _infer_type(self):
|
|
|
|
|
# 返回输入的第一个元素的dtype
|
|
|
|
|
return self.inputs[0].dtype
|
|
|
|
|
|
|
|
|
|
def _infer_format(self):
|
|
|
|
|
# 返回输入的第一个元素的data_format
|
|
|
|
|
return self.inputs[0].data_format
|
|
|
|
|
|
|
|
|
|
def _check_type(self):
|
|
|
|
|
# 检查输入的第二个元素的dtype是否为int32,如果不是则抛出异常
|
|
|
|
|
if self.inputs[1].dtype != "int32":
|
|
|
|
|
raise GKException("For 'Gather', inputs[1] should be of type int32, but got {}"
|
|
|
|
|
.format(self.inputs[1].dtype))
|
|
|
|
|