_extends\graph_kernel\model

branch-yixin
yixin 7 months ago
parent f8389e877f
commit ffdf6162c7

@ -14,6 +14,9 @@
# ===========================================================================
"""GraphKernel cost model init"""
# import the graph split function
from .graph_split import split
# import the graph builder and the composite-kernel loader
from .model_builder import GraphBuilder, load_composite
# import the parallel estimation entry
from .graph_parallel import parallel_estimate

@ -20,42 +20,103 @@ class ParalGain:
"""Paral Gain"""
def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info):
"""
类的构造函数
Args:
fusion_type (str): 融合类型
bottleneck (int): 瓶颈层的大小
gain (float): 增益值
block_assign (list): 块分配列表
type_info (dict): 类型信息字典
Returns:
None
"""
# 初始化融合类型
self.fusion_type = fusion_type
# 初始化瓶颈层
self.bottleneck = bottleneck
# 初始化增益
self.gain = gain
# 初始化块分配
self.block_assign = block_assign
# 初始化类型信息
self.type_info = type_info
class ScheduleAnalyzer:
"""schedule analyzer"""
# warp size (threads per warp)
WRAP_SIZE = 32
# maximum number of streaming multiprocessors
MAX_SM = 80  # Volta
# maximum number of threads per block
MAX_NUM_THREADS = 1024
# maximum number of blocks
MAX_BLOCK = 256
# threshold on the op count for pipeline fusion
PIPELINE_OP_THREADHOLD = 5
def __init__(self, graph):
"""
初始化图处理类
Args:
graph (Graph): 图对象用于存储图的结构和参数
Attributes:
graph (Graph): 图对象存储图的结构和参数
block_num (int): 块的数量初始值为0
block_weight (float): 块的权重初始值为0
ops (List[Operation]): 图的操作列表
dom_op (List[Operation]): 输出的每个操作对应的操作列表
"""
# 将传入的图对象赋值给实例变量graph
self.graph = graph
# 初始化block数量为0
self.block_num = 0
# 初始化block权重为0
self.block_weight = 0
# 通过图对象的deduce_parameters方法获取参数并赋值给outputs变量
_, outputs = graph.deduce_parameters()
# 将图对象的操作列表赋值给实例变量ops
self.ops = graph.ops
# 将outputs中的每个输出对应的操作收集到一个列表中并赋值给实例变量dom_op
self.dom_op = list(out.op for out in outputs)
@staticmethod
def prod(shape):
"""
计算形状乘积
Args:
shape (list): 一个包含整数的列表表示形状
Returns:
int: 形状乘积的结果
"""
"""Compute shape product"""
# 初始化结果变量为shape的第一个元素
res = shape[0]
# 遍历shape列表从第二个元素开始
for i in range(1, len(shape)):
# 将当前结果与shape的下一个元素相乘
res = res * shape[i]
# 返回计算后的结果
return res
def _cal_weight(self, ops):
"""Calculate the total output weight (in bytes) of the given ops."""
weight = 0
for op in ops:
# weight of one op = number of output elements * bytes per element
weight += self.prod(op.output.shape) * \
PrimLib.dtype_bytes(op.output.dtype)
return weight
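As a worked illustration (values invented): an op whose output is float32 with shape [32, 64] contributes prod([32, 64]) * PrimLib.dtype_bytes("float32") = 2048 * 4 = 8192 bytes to the weight.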
def injective_analyze(self):

@ -24,39 +24,61 @@ class Utils:
@staticmethod
def get_attr_type(attr):
"""Get attr type"""
# bool must be checked before int, since bool is a subclass of int
if isinstance(attr, bool):
return 'bool'
if isinstance(attr, str):
return 'str'
if isinstance(attr, int):
return 'int'
if isinstance(attr, float):
return 'float'
if isinstance(attr, (list, tuple)):
# an empty list or tuple gives no element to deduce the type from
if not attr:
raise ValueError("attr is invalid: the length of attr is 0")
# deduce the list type from the first element
if isinstance(attr[0], int):
return 'listInt'
if isinstance(attr[0], str):
return 'listStr'
# any other type is unsupported
raise ValueError("attr {} type {} is not in supported list ['bool', 'str', 'int', 'float', 'int' list, "
"'str' list]".format(attr, type(attr)))
class DataFormat:
"""DataFormat"""
# supported data format identifiers
DEFAULT = "DefaultFormat"
NC1KHKWHWC0 = "NC1KHKWHWC0"
ND = "ND"
NCHW = "NCHW"
NHWC = "NHWC"
HWCN = "HWCN"
NC1HWC0 = "NC1HWC0"
FRAC_Z = "FracZ"
FRAC_NZ = "FRACTAL_NZ"
C1HWNCOC0 = "C1HWNCoC0"
NC1HWC0_C04 = "NC1HWC0_C04"
FRACTAL_Z_C04 = "FRACTAL_Z_C04"
NDHWC = "NDHWC"
def __init__(self):
@ -65,29 +87,47 @@ class DataFormat:
class DataType:
"""Data Type"""
# supported data type identifiers
FLOAT = "float"
FLOAT16 = "float16"
FLOAT32 = "float32"
FLOAT64 = "float64"
INT = "int"
INT8 = "int8"
INT16 = "int16"
INT32 = "int32"
INT64 = "int64"
UINT = "uint"
UINT8 = "uint8"
UINT16 = "uint16"
UINT32 = "uint32"
UINT64 = "uint64"
BOOL = "bool"
def __init__(self):
pass
class PrimLib:
"""Prim lib"""
# iteration type constants
UNKNOWN = 0
RESHAPE = 1
ELEMWISE = 2
@ -102,53 +142,73 @@ class PrimLib:
"""Prim"""
def __init__(self, iter_type, calibrate=1, relation_func=None):
self.iter_type = iter_type
self.calibrate = calibrate
self.relation_func = relation_func
if relation_func is None:
# fall back to the default relation function for this iter type
self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)
def default_reshape_relation(self, op, input_idx):
"""Process reshape relation"""
axis_relation, elem_relation = self.unknown_relation(op, input_idx)
# a reshape turns every element relation into RESHAPE
elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
return axis_relation, elem_relation
def default_elemwise_broadcast_relation(self, op, input_idx):
"""Process elemwise and broadcast relation"""
out_shape = op.output.shape
in_shape = op.inputs[input_idx].shape
# the output rank can never be smaller than the input rank
if len(out_shape) < len(in_shape):
raise ValueError("For '{}', the input/output size is abnormal, as the length of output shape {} is less "
"than the length of input shape {}".format(op.prim, out_shape, in_shape))
axis_relation, elem_relation = [], []
# rank difference between output and input
delta = len(out_shape) - len(in_shape)
if delta > 0:
# leading output axes have no corresponding input axis
for i in range(0, delta):
axis_relation.append(None)
elem_relation.append(None)
for i, _ in enumerate(in_shape):
axis_relation.append(i)
# matching dims are ELEMWISE; mismatched dims are BROADCAST
elem_relation.append(
PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
return axis_relation, elem_relation
def default_reduce_relation(self, op, input_idx):
"""Process reduce relation"""
axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
# axes being reduced get the REDUCE relation
for i in op.attrs['reduce_axis']:
elem_relation[i] = PrimLib.REDUCE
return axis_relation, elem_relation
def unknown_relation(self, op, input_idx):
"""Process unknown relation"""
out_shape = op.output.shape
in_shape = op.inputs[input_idx].shape
# every output axis may depend on every input axis
all_relation = list(range(len(in_shape)))
axis_relation = [all_relation for i in range(0, len(out_shape))]
elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
return axis_relation, elem_relation
# default relation functions, indexed by iter type
default_relation_func = [
unknown_relation,
default_reshape_relation,
@ -158,6 +218,7 @@ class PrimLib:
unknown_relation,
]
# registry of known primitives
primtives = {
'Add': Prim(ELEMWISE),
'Abs': Prim(ELEMWISE),
@ -239,7 +300,9 @@ class PrimLib:
@classmethod
def get_prim(cls, op):
"""Get op primtive"""
prim = cls.primtives.get(op.prim, None)
# fall back to the default primitive for unregistered ops
if prim is None:
if prim is None:
print('[WARN] primtive is not registered: ' + op.prim)
prim = cls.default_primtive
@ -248,50 +311,65 @@ class PrimLib:
@classmethod
def input_relation(cls, op, input_idx):
"""Get op's input_relation according to input_idx"""
return cls.get_prim(op).relation_func(op, input_idx)
@classmethod
def iter_type(cls, op):
"""Get op's iter type"""
return cls.get_prim(op).iter_type
@classmethod
def is_reduce(cls, op):
"""Check whether op's iter type is reduce"""
return cls.get_prim(op).iter_type == cls.REDUCE
@classmethod
def calibrate_iter_size(cls, op, iter_size):
"""Get calibrate_iter_size"""
# scale the iteration size by the primitive's calibrate factor
return cls.get_prim(op).calibrate * iter_size
@classmethod
def dtype_bytes(cls, dtype):
"""Get dtype bytes"""
bits, unit = 1, 1
# parse the trailing digits of the dtype name (e.g. the "32" in "float32")
# from right to left to recover the bit width
for i in range(len(dtype) - 1, 0, -1):
if dtype[i].isdecimal():
bits += int(dtype[i]) * unit
unit *= 10
else:
break
# convert bits to bytes
return bits // 8
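A worked example of the digit parsing: for "float32" the loop reads '2' then '3', giving bits = 1 + 2*1 + 3*10 = 33 and 33 // 8 = 4 bytes; "float16" gives 17 // 8 = 2; "bool" has no trailing digits, so bits stays 1 and the result is 0.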
@classmethod
def inplace_reuse(cls, op, input_idx, start_axis=0):
"""Check whether op is inplace reuse"""
# the output cannot reuse a smaller input buffer
if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
return False
_, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
# reuse is only safe if every axis from start_axis on is ELEMWISE
for i in range(start_axis, len(elem_relation)):
if elem_relation[i] != cls.ELEMWISE:
return False
return True
class Tensor:
"""Tensor"""
# parameter type constants
PARA_NONE = 0
PARA_INPUT = 1
PARA_OUTPUT = 2
@ -303,6 +381,7 @@ class Tensor:
self.members = [leader]
def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
self.name = name
self.shape = shape
self.dtype = dtype
@ -313,13 +392,16 @@ class Tensor:
self.buddy = None
def __str__(self):
return self.name + str(list(self.shape))
def __repr__(self):
return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
def get_size(self):
"""Get size"""
# size in bytes = dtype bytes * number of elements
size = PrimLib.dtype_bytes(self.dtype)
for i in self.shape:
size *= i
@ -327,6 +409,7 @@ class Tensor:
def add_buddy(self, tensor):
"""Add buddy"""
# the first buddy creates the group; later tensors join it
if self.buddy is None:
self.buddy = self.Buddy(self)
self.buddy.members.append(tensor)
@ -337,6 +420,7 @@ class Value:
"""Value"""
def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
self.name = name
self.shape = [1]
self.dtype = dtype
@ -344,14 +428,17 @@ class Value:
self.data_format = data_format
def __str__(self):
return self.name + str(list(self.shape))
def __repr__(self):
return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
@staticmethod
def get_size():
"""Get size"""
# a Value is scalar; its size is 1
return 1
@ -359,23 +446,31 @@ class Operator:
"""Operator"""
def __init__(self, primtive, inputs, output, attrs):
self.prim = primtive
self.inputs = inputs
self.output = output
self.attrs = attrs
# register this op as a consumer of each input tensor
for t in inputs:
t.to_ops.append(self)
# claim the output tensor if it has no producer yet
if output.op is None:
output.op = self
self.all_inputs = []  # include Tensor inputs and Value inputs.
def __str__(self):
args = ', '.join((str(t) for t in self.all_inputs))
expr = "%s = %s.%s(%s) id:%s" % (
str(self.output), self.prim, self.output.dtype, args, id(self))
# append the attrs when present
return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
def __repr__(self):
return str(self)
@ -383,6 +478,7 @@ class Graph:
"""Graph"""
def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
self.name = name
self.ops = ops  # in topo order, can not use set
self.inputs = []
@ -393,10 +489,12 @@ class Graph:
def set_processor(self, processor):
"""Set processor"""
self.processor = processor
def add(self, ops):
"""Add ops"""
# accept a single Operator or a list of Operators
if isinstance(ops, Operator):
self.ops.append(ops)
else:
@ -404,101 +502,148 @@ class Graph:
def extract_subgraph(self, graph_name, tensor_names, difference=False):
"""Extract subgraph from this graph"""
graph = Graph(graph_name, [])
outputs = set(tensor_names)
if difference:
# keep the ops whose outputs are NOT in tensor_names
for op in self.ops:
if op.output.name not in outputs:
graph.add(op)
else:
# keep the ops whose outputs are in tensor_names
for op in self.ops:
if op.output.name in outputs:
graph.add(op)
outputs.remove(op.output.name)
# any remaining name was never produced by this graph
for name in outputs:
raise ValueError("Invalid input tensor : {}, can not find it in graph".format(name))
return graph
def deduce_parameters(self):
"""Deduce parameters"""
inputs, outputs = [], []
for op in self.ops:
# a tensor produced outside this graph is a graph input
for t in op.inputs:
if t not in inputs and t.op not in self.ops:
inputs.append(t)
if op.output in outputs:
continue
# explicit output parameters and dead-end tensors are graph outputs
if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
outputs.append(op.output)
continue
# a tensor consumed outside this graph is also a graph output
if any((succ not in self.ops for succ in op.output.to_ops)):
outputs.append(op.output)
# explicitly specified inputs/outputs override the deduced ones
if self.inputs:
inputs = self.inputs
if self.outputs:
outputs = self.outputs
return inputs, outputs
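For example, in a two-op graph x -> Abs -> t0 -> Exp -> t1 (names invented), deduce_parameters returns inputs [x], since x is produced outside the graph, and outputs [t1], since t1 is consumed by no op inside it.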
def __str__(self):
inputs, outputs = self.deduce_parameters()
para_str = ', '.join((repr(t) for t in inputs))
out_str = ', '.join((repr(t) for t in outputs))
lines = []
lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
# dump stitch info when present
if self.stitch_info:
if self.stitch_info.stitch_ops:
lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops))
if self.stitch_info.stitch_atomic_ops:
lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))
for op in self.ops:
lines.append(' ' + str(op))
lines.append('}')
return '\n'.join(lines)
def __repr__(self):
return str(self)
def dump(self):
"""Dump Graph to json"""
# some attribute names are renamed on dump
attr_name = {'reduce_axis': 'axis'}
inputs, outputs = self.deduce_parameters()
input_desc, output_desc, op_desc = [], [], []
for t in inputs:
input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}])
for t in outputs:
output_desc.append({'data_type': t.dtype, 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format})
for op in self.ops:
attrs, in_desc = [], []
for a in op.attrs:
name = attr_name.get(a, a)
attrs.append(
{'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
# Tensor inputs carry a tensor name; Value inputs also carry their value
for t in op.all_inputs:
if isinstance(t, Tensor):
in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}])
else:
in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}])
out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
'tensor_name': op.output.name, 'format': op.output.data_format}]
op_desc.append({'attr': attrs, 'impl_path': '',
'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
'platform': 'AKG', 'process': self.processor}
# attach buffer stitch info when present
if self.stitch_info and self.stitch_info.stitch_ops:
buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
if self.stitch_info.stitch_atomic_ops:
buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
graph_desc['buffer_stitch'] = buffer_stitch
return graph_desc
@ -506,13 +651,16 @@ class GraphVisitor:
"""Graph visitor"""
def __init__(self, forward=True):
self.forward = forward
def visit_graph(self, graph):
"""Visit graph"""
# visit ops in topological order, or in reverse when forward is False
if self.forward:
for op in graph.ops:
self.visit(op)
else:
for i in range(len(graph.ops)-1, -1, -1):
self.visit(graph.ops[i])
@ -528,12 +676,18 @@ class AlignShape(GraphVisitor):
def visit(op):
"""Visit op node"""
prim = PrimLib.get_prim(op)
# only elemwise, broadcast and reduce ops need shape alignment
if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
out_dim = len(op.output.shape)
# the aligned rank is the maximum rank over the output and all inputs
align_dim = out_dim
for t in op.inputs:
if len(t.shape) > align_dim:
align_dim = len(t.shape)
# left-pad the output shape with 1s up to the aligned rank
if align_dim > out_dim:
op.output.shape = [1] * (align_dim - out_dim) + op.output.shape

@ -25,90 +25,140 @@ class GraphBuilder:
"""Graph wrapper"""
def __init__(self, name):
"""
初始化类实例
Args:
name (str): 图的名称
Attributes:
self.graph (Graph): 图的实例使用传入的名称初始化
"""
self.graph = Graph(name, [])
def set_input(self, *para):
"""set input to graph inputs"""
# mark each tensor as an input parameter and record it on the graph
for t in para:
t.para_type = Tensor.PARA_INPUT
self.graph.inputs.append(t)
def set_output(self, *para):
"""set output to graph outputs"""
# mark each tensor as an output parameter and record it on the graph
for t in para:
t.para_type = Tensor.PARA_OUTPUT
self.graph.outputs.append(t)
def __init__(self):
self.graphs = []
self.current = None
self.name_id = 0
def _alloc_tensor_name(self):
# allocate the next unique tensor name: "t0", "t1", ...
tid = self.name_id
self.name_id += 1
return "t%d" % (tid)
def graph_scope(self, name):
"""The graph scope to be processed"""
class GraphScope:
"""Graph Scope"""
def __init__(self, gb):
self.gb = gb
def __enter__(self):
return self.gb.current
def __exit__(self, ptype, value, trace):
# on scope exit, archive the finished graph and clear current
self.gb.graphs.append(self.gb.current.graph)
self.gb.current = None
# nested graph scopes are not allowed
if self.current is not None:
raise ValueError("self.current is not None!")
self.current = self.GraphWrapper(name)
return GraphScope(self)
def tensor(self, shape, dtype, data_format="DefaultFormat", name=None, para_type=Tensor.PARA_NONE):
"""Create a new Tensor"""
"""创建一个新的张量"""
# 如果名称为空或None则分配一个新的张量名称
if name in (None, ''):
# 分配一个新的张量名称
name = self._alloc_tensor_name()
# 如果shape为空则默认设置为[1]
if not shape:
shape = [1]
# 返回创建好的张量对象
return Tensor(name, shape, dtype, data_format, para_type=para_type)
def value(self, dtype, value, name=None):
"""Create a new Value"""
# allocate a name when none is given
if name in (None, ''):
name = self._alloc_tensor_name()
v = Value(name, dtype, value)
return v
def op(self, prim, output, inputs, attrs=None):
"""Insert an operator into graph"""
if attrs is None:
attrs = {}
# accept a single Tensor as well as a list
if isinstance(inputs, Tensor):
inputs = [inputs]
# only Tensor inputs participate in the graph topology
tensor_inputs = [t for t in inputs if isinstance(t, Tensor)]
node = Operator(prim, tensor_inputs, output, attrs)
# all_inputs additionally keeps the Value inputs
node.all_inputs = inputs
self.current.graph.add(node)
def emit(self, prim, inputs, name=None, attrs=None):
"""Emit a new operation"""
if attrs is None:
attrs = {}
if isinstance(inputs, (Tensor, Value)):
inputs = [inputs]
tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
# infer the output shape, dtype and format from the primitive and its inputs
out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
output = self.tensor(out_shape, out_dtype, out_format, name)
self.op(prim, output, inputs, attrs)
return output
def get(self):
"""Get graphs"""
return self.graphs
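A minimal usage sketch (hypothetical shapes; 'Abs' and 'Add' are primitives registered above):

gb = GraphBuilder()
with gb.graph_scope("example") as g:
    a = gb.tensor([16, 16], "float32", name="a")
    b = gb.emit("Abs", a)
    c = gb.emit("Add", [a, b])
    g.set_input(a)
    g.set_output(c)
graph = gb.get()[0]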
@ -116,16 +166,21 @@ class CompositeGraph:
"""Composite Graph"""
def __init__(self):
self.graph = None
self.desc = None
self.tensors = {}  # name : Tensor
def refine(self):
"""Refine Graph"""
# align the shape ranks of all ops in the graph
AlignShape().visit_graph(self.graph)
def load(self, desc):
"""Load Graph from json"""
# convert an op's attr list from the json desc into a dict
def _attr_of(op):
if not op['attr']:
return dict()
@ -134,21 +189,29 @@ class CompositeGraph:
if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin', 'Argmax', 'Argmin'):
attr['reduce_axis'] = a['value']
else:
# other attributes are copied as-is
attr[a['name']] = a['value']
return attr
builder = GraphBuilder()
# rebuild the graph inside a scope named after the composite op
with builder.graph_scope(desc['op']):
# build the input tensors from the input descriptions
for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []:
name, shape, dtype, data_format = in_desc[0]['tensor_name'], in_desc[
0]['shape'], in_desc[0]['data_type'], in_desc[0]['format']
self.tensors[name] = builder.tensor(
shape, dtype, data_format, name=name, para_type=Tensor.PARA_INPUT)
# build the output tensors from the output descriptions
for out_desc in desc['output_desc']:
name, shape, dtype, data_format = out_desc['tensor_name'], out_desc[
'shape'], out_desc['data_type'], out_desc['format']
self.tensors[name] = builder.tensor(
shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
# rebuild each op from its description
for op in desc['op_desc']:
inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
out_desc = op['output_desc']
@ -164,52 +227,77 @@ class CompositeGraph:
if not output:
output = builder.tensor(shape, dtype, data_format, name=name)
self.tensors[name] = output
builder.op(op['name'], output, inputs, attrs=_attr_of(op))
# the builder produced exactly one graph for this desc
self.graph = builder.get()[0]
self.desc = desc
def add_stitch_info(self, subgraph, desc):
"""add stitch info to desc"""
# record buffer stitch info on the desc when present
if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
if subgraph.stitch_info.stitch_atomic_ops:
buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops)
desc['buffer_stitch'] = buffer_stitch
return desc
def add_recompute_ops(self, subgraph, desc):
"""add recompute ops to desc"""
# record the output names of the ops that must be recomputed
if subgraph.recompute_ops:
desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops]
return desc
def _pre_dump(self, outputs):
"""restore name to before load"""
inplace_assign = {}  # y_name : output_name
inplace_assign_z = None
# map each InplaceAssign's second input name to its output name
for op in self.desc['op_desc']:
if op['name'] == 'InplaceAssign':
inplace_assign[op['input_desc'][1][0]['tensor_name']] = op['output_desc'][0]['tensor_name']
if inplace_assign:
# find the output tensor that is not written by an InplaceAssign
for t in outputs:
if t.name not in inplace_assign:
inplace_assign_z = t
return inplace_assign, inplace_assign_z
def dump(self, subgraph):
"""Dump Graph to json"""
desc = {}
inputs, outputs = subgraph.deduce_parameters()
graph_ops = set(subgraph.ops)
inplace_assign, inplace_assign_z = self._pre_dump(outputs)
def dump_output(t):
# outputs written through InplaceAssign are dumped under their original name
if t.name in inplace_assign:
z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign.get(t.name)}
return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}
def dump_op_desc(d):
# InplaceAssign descriptions need rewriting before dump
if d['name'] == 'InplaceAssign':
y = d['input_desc'][1][0]['tensor_name']
if self.tensors[y].op in graph_ops:
@ -222,33 +310,50 @@ class CompositeGraph:
z_desc['tensor_name'] = z.name
out_desc['shape'] = z.shape
out_desc['data_type'] = z.dtype
return inplace_desc
op = self.tensors[d['output_desc'][0]['tensor_name']].op
# keep only ops that belong to this subgraph or its recompute set
if op in graph_ops or op in subgraph.recompute_ops:
return d
return None
# rebuild the desc key by key, renaming tensors where needed
for key in self.desc.keys():
if key == 'input_desc':
desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
elif key == 'output_desc':
desc[key] = list(map(dump_output, outputs))
elif key == 'op_desc':
op_desc = map(dump_op_desc, self.desc[key])
desc[key] = [d for d in op_desc if d is not None]
elif key == 'op':
desc[key] = subgraph.name
else:
desc[key] = self.desc[key]
# attach stitch and recompute info
desc = self.add_stitch_info(subgraph, desc)
desc = self.add_recompute_ops(subgraph, desc)
return desc
def load_composite(desc):
"""Load composite kernel"""
composite = CompositeGraph()
# load the graph from its json desc, then align shape ranks
composite.load(desc)
composite.refine()
return composite
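A hedged usage sketch (the json layout mirrors what Graph.dump produces above; all field values are invented):

desc = {'composite': True, 'composite_graph': '', 'id': 0, 'op': 'Fused_Abs',
        'platform': 'AKG', 'process': 'cuda',
        'input_desc': [[{'data_type': 'float32', 'shape': [16], 'tensor_name': 'x', 'format': 'DefaultFormat'}]],
        'output_desc': [{'data_type': 'float32', 'shape': [16], 'tensor_name': 'y', 'format': 'DefaultFormat'}],
        'op_desc': [{'attr': [], 'impl_path': '', 'name': 'Abs',
                     'input_desc': [[{'data_type': 'float32', 'name': '', 'shape': [16], 'tensor_name': 'x', 'format': 'DefaultFormat'}]],
                     'output_desc': [{'data_type': 'float32', 'name': '', 'shape': [16], 'tensor_name': 'y', 'format': 'DefaultFormat'}]}]}
composite = load_composite(desc)
print(composite.graph)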

@ -25,24 +25,32 @@ def infer(op_name, inputs, attrs):
"""infer shape dtype and format"""
def _create_opinfer():
self_module = sys.modules.get(__name__, None)
if self_module is None:
raise GKException("OpInfo does not support op {}".format(op_name))
# prefer an op-specific infer class when one is defined in this module
if hasattr(self_module, op_name):
op_cls = getattr(self_module, op_name)
return op_cls(op_name, inputs, attrs)
# common infer: map the op's iter type to a generic infer class
class_name_map = {
PrimLib.ELEMWISE: "_Elemwise",
PrimLib.REDUCE: "_Reduce",
}
cls_name = class_name_map.get(PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type, None)
if not cls_name:
raise GKException("OpInfo does not support op {}".format(op_name))
op_cls = getattr(self_module, cls_name)
return op_cls(op_name, inputs, attrs)
return _create_opinfer().infer()
@ -55,49 +63,65 @@ class OpInfer:
"""
def __init__(self, name, inputs, attrs):
self.name = name
self.inputs = inputs
self.attrs = attrs
def infer(self):
"""Infer shape, type and format by op inputs"""
self._check()
return self._infer_shape(), self._infer_type(), self._infer_format()
def _infer_shape(self):
# by default, shape, type and format follow the first input
return self.inputs[0].shape
def _infer_type(self):
return self.inputs[0].dtype
def _infer_format(self):
return self.inputs[0].data_format
def _check(self):
self._check_shape()
self._check_type()
self._check_format()
def _check_shape(self):
pass
def _check_type(self):
"""check all dtypes are same"""
dtype = self.inputs[0].dtype
# every other input must match the first input's dtype
for i, t in enumerate(self.inputs[1:]):
if t.dtype != dtype:
raise GKException(
"Incompatible data type between input {}({}) and {}({})".format(0, dtype, i + 1, t.dtype))
def _check_format(self):
"""check formats are compatible. only DefaultFormat is compatible with others"""
result = self.inputs[0].data_format
i = 0
for j, t in enumerate(self.inputs[1:]):
if t.data_format != result:
# only DefaultFormat is compatible with any other format
if DF.DEFAULT not in (result, t.data_format):
raise GKException("Incompatible format between input {}({}) and {}({})".format(
i, result, j + 1, t.data_format))
# a concrete format replaces DefaultFormat; remember which input it came from
if result == DF.DEFAULT:
result = t.data_format
i = j + 1
@ -109,17 +133,26 @@ class _Elemwise(OpInfer):
@staticmethod
def broadcast_shape(shapes):
"""deduce broadcast shape using same rules as numpy"""
# align all shapes to the largest rank by left-padding with 1s
dim_size = max(len(shape) for shape in shapes)
align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes]
out_shape = [1] * dim_size
for i in range(dim_size):
for align_shape in align_shapes:
# a size-1 dim broadcasts to anything
if align_shape[i] == 1:
continue
if out_shape[i] == 1:
out_shape[i] = align_shape[i]
# two non-1 sizes must agree, otherwise broadcasting fails
elif out_shape[i] != align_shape[i]:
raise GKException("Input shapes {} can not broadcast.".format(shapes))
return out_shape
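For instance, broadcast_shape([[2, 1, 4], [3, 1]]) aligns the shapes to [[2, 1, 4], [1, 3, 1]] and resolves each dimension to [2, 3, 4], while broadcast_shape([[2], [3]]) raises, since 2 and 3 conflict.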
@staticmethod
@ -174,9 +207,12 @@ class _Elemwise(OpInfer):
.format(inputs_format))
def _infer_format(self):
# the first non-default input format wins
for tensor in self.inputs:
if tensor.data_format != DF.DEFAULT:
return tensor.data_format
# all inputs use the default format
return DF.DEFAULT
@ -184,6 +220,7 @@ class _Reduce(OpInfer):
"""Common infer for reduction operators"""
def _check(self):
super(_Reduce, self)._check()
# check reduce axis in the range [-len, len)
shape_len = len(self.inputs[0].shape)
@ -195,21 +232,29 @@ class _Reduce(OpInfer):
"Reduce axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis))
def _infer_shape(self):
# copy the shape so the input tensor is not mutated
shape = copy.deepcopy(self.inputs[0].shape)
axis = self.attrs['reduce_axis']
if isinstance(axis, int):
axis = [axis]
if any(i < 0 for i in axis):
# change negative axes to their non-negative equivalents
axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis))
self.attrs['reduce_axis'] = sorted(axis)
# with keep_dims, reduced dims become 1; otherwise they are dropped
if self.attrs['keep_dims']:
for i in axis:
shape[i] = 1
return shape
real_shape = []
for i, s in enumerate(shape):
if i not in axis:
@ -223,10 +268,14 @@ class _Reduce(OpInfer):
class _Reshape(OpInfer):
"""Common infer for reshape operators, should not be instantiated"""
def _infer_shape(self):
# subclasses must deduce the output shape themselves
raise GKException("_infer_shape should be implemented by subclass")
def _infer_format(self):
# use the "format" attr when given, otherwise the default format
return DF.DEFAULT if "format" not in self.attrs else self.attrs["format"]
@ -234,14 +283,20 @@ class Reshape(_Reshape):
"""Reshape op infer"""
def _check_shape(self):
input_shape = self.inputs[0].shape
output_shape = self.attrs["shape"]
# a reshape must preserve the total element count
size_before_reshape = prod_reduce(lambda x, y: x * y, input_shape)
size_after_reshape = prod_reduce(lambda x, y: x * y, output_shape)
if size_before_reshape != size_after_reshape:
raise GKException("For 'Reshape', can not reshape {} to {}".format(input_shape, output_shape))
def _infer_shape(self):
return self.attrs["shape"]
@ -249,6 +304,7 @@ class Cast(_Elemwise):
"""Cast op infer"""
def _infer_type(self):
# the output type is the cast target type
return self.attrs["dst_type"]
@ -256,29 +312,38 @@ class InplaceAssign(_Elemwise):
"""InplaceAssign op infer"""
def _infer_shape(self):
# InplaceAssign takes shape, type and format from its third input
return self.inputs[2].shape
def _infer_type(self):
return self.inputs[2].dtype
def _infer_format(self):
return self.inputs[2].data_format
class BroadcastTo(OpInfer):
"""BroadcastTo op infer"""
def _infer_shape(self):
# the target shape comes from the "shape" attr
return self.attrs["shape"]
def _infer_format(self):
return self.inputs[0].data_format
class _CompareOp(_Elemwise):
"""Compare operators"""
def _infer_type(self):
# comparison ops always produce bool
return "bool"
@ -286,11 +351,14 @@ class CImag(OpInfer):
"""CImag op infer"""
def _check_type(self):
# the input must be complex64
if self.inputs[0].dtype != "complex64":
raise GKException(
"For 'CImag', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype))
def _infer_type(self):
# the imaginary part is float32
return "float32"
@ -298,11 +366,14 @@ class CReal(OpInfer):
"""CReal op infer"""
def _check_type(self):
# the input must be complex64
if self.inputs[0].dtype != "complex64":
raise GKException(
"For 'CReal', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype))
def _infer_type(self):
# the real part is float32
return "float32"
@ -310,14 +381,17 @@ class Complex(OpInfer):
"""Complex op infer"""
def _check_type(self):
# both inputs must be float32 and must match each other
if self.inputs[0].dtype != "float32":
raise GKException(
"For 'Complex', input[0] should be of type float32, but got {}".format(self.inputs[0].dtype))
if self.inputs[0].dtype != self.inputs[1].dtype:
raise GKException("For 'Complex', inputs data type mismatch ({} vs {})"
.format(self.inputs[0].dtype, self.inputs[1].dtype))
def _infer_type(self):
# a pair of float32 values forms a complex64
return "complex64"
@ -345,40 +419,53 @@ class Select(_Elemwise):
"""Select op infer"""
def _check_type(self):
# the condition must be bool, and the two value branches must share a type
if self.inputs[0].dtype != "bool":
raise GKException("For 'Select', input[0] should be of type bool, but got {}".format(self.inputs[0].dtype))
if self.inputs[1].dtype != self.inputs[2].dtype:
raise GKException("For 'Select', input[1] and input[2] data type mismatch ({} vs {})"
.format(self.inputs[1].dtype, self.inputs[2].dtype))
def _infer_type(self):
# the result takes the type of the value branches
return self.inputs[1].dtype
def check_format_any(formats, checked_format):
"""Check whether input format in formats list"""
if not isinstance(formats, (list, tuple)):
raise GKException("formats {} should be of type list or tuple, but got {}.".format(formats, type(formats)))
if checked_format not in formats:
raise GKException("Check {} failed: can not find it in {}".format(checked_format, formats))
def check_nd(data, nd):
"""Check whether data are nd format"""
if not isinstance(data, (list, tuple)) or len(data) != nd:
raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))
def conv_had_pad(pad_list, pad_mode):
"""Check whether conv need to add pad"""
if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list))
# asymmetric padding always requires an explicit pad
if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
return True
# outside VALID mode, any nonzero pad requires an explicit pad
if pad_mode not in ["VALID", "valid"]:
for _, pad in enumerate(pad_list):
if pad != 0:
return True
# no explicit pad is needed
return False
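For example, conv_had_pad([1, 1, 2, 2], "same") is True (nonzero symmetric padding outside VALID mode), conv_had_pad([0, 1, 0, 0], "VALID") is True (asymmetric padding), and conv_had_pad([0, 0, 0, 0], "VALID") is False.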
@ -386,38 +473,50 @@ class Conv2D(OpInfer):
"""Conv2D infer"""
def _infer_type(self):
# honor an explicit dst_type attr; otherwise keep the input type
if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
return self.attrs["dst_type"]
return self.inputs[0].dtype
def _infer_shape(self):
shape_0 = list(self.inputs[0].shape)
shape_1 = list(self.inputs[1].shape)
# both data and weight must be 4D, in NHWC format
check_nd(shape_0, 4)
check_nd(shape_1, 4)
formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]]
check_format_any(formats, DF.NHWC)
n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0]
pad_list = self.attrs["pad_list"]
pad_mode = self.attrs["pad_mode"]
kernel_size = self.attrs["kernel_size"]
stride = self.attrs["stride"]
dilation = self.attrs["dilation"]
check_nd(pad_list, 4)
check_nd(kernel_size, 2)
check_nd(stride, 4)
check_nd(dilation, 4)
# zero the padding when no pad is actually applied
has_pad = conv_had_pad(pad_list, pad_mode)
if not has_pad:
pad_list = [0, 0, 0, 0]
# effective (dilated) kernel size, then the standard conv output size
k_h = (kernel_size[0] - 1) * dilation[-2] + 1
k_w = (kernel_size[1] - 1) * dilation[-1] + 1
out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1
out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1
return [n, out_h, out_w, out_channel]
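A worked example with invented values: input [1, 32, 32, 16] (NHWC), weight [8, 3, 3, 16], kernel_size [3, 3], pad_list [1, 1, 1, 1], unit stride and dilation give k_h = k_w = 3 and out_h = (32 + 1 + 1 - 3) // 1 + 1 = 32, out_w = 32, so the inferred output shape is [1, 32, 32, 8].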
@ -425,23 +524,31 @@ class MatMul(OpInfer):
"""MatMul infer"""
def _infer_type(self):
# honor an explicit dst_type attr; otherwise keep the input type
if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
return self.attrs["dst_type"]
return self.inputs[0].dtype
def _infer_shape(self):
shape_0 = list(self.inputs[0].shape)
shape_1 = list(self.inputs[1].shape)
if len(shape_0) != 2 or len(shape_1) != 2:
raise GKException("For 'MatMul', inputs shape must be 2D, but got {}, {}"
.format(shape_0, shape_1))
transpose_a = self.attrs["transpose_a"]
transpose_b = self.attrs["transpose_b"]
# read (m, k) from input 0 and (k, n) from input 1, honoring the transpose flags
m, k1 = (shape_0[-1], shape_0[-2]) if transpose_a else (shape_0[-2], shape_0[-1])
k2, n = (shape_1[-1], shape_1[-2]) if transpose_b else (shape_1[-2], shape_1[-1])
# the contracted dimensions must agree
if k1 != k2:
raise GKException("For 'MatMul', inputs have different k value: {} vs {}".format(k1, k2))
output_shape = [m, n]
return output_shape
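For instance (invented shapes): shape_0 = [4, 8] and shape_1 = [16, 8] with transpose_a = False and transpose_b = True give m = 4, k1 = k2 = 8 and n = 16, hence an output shape of [4, 16].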
@ -449,14 +556,20 @@ class PadAkg(OpInfer):
"""PadAkg infer"""
def _infer_shape(self):
shape = list(self.inputs[0].shape)
n = len(shape)
# per-dimension padding before ("head") and after ("tail")
pad_before = list(self.attrs["head"])
pad_after = list(self.attrs["tail"])
# both pad lists must match the input rank
if len(pad_before) != n or len(pad_after) != n:
raise GKException("For 'PadAkg', input dimension and pad mismatch: {}d vs {}d vs {}d"
.format(n, len(pad_before), len(pad_after)))
out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)]
return out_shape
@ -464,13 +577,19 @@ class UnPadAkg(OpInfer):
"""UnPadAkg infer"""
def _infer_shape(self):
shape = list(self.inputs[0].shape)
n = len(shape)
# per-dimension amount to strip from the tail
unpad_after = list(self.attrs["tail"])
# the unpad list must match the input rank
if len(unpad_after) != n:
raise GKException("For 'UnPadAkg', input dimension and pad mismatch: {}d vs {}d"
.format(n, len(unpad_after)))
out_shape = [shape[i] - unpad_after[i] for i in range(n)]
return out_shape
@ -478,23 +597,32 @@ class Gather(OpInfer):
"""Gather infer"""
def _infer_shape(self):
input_shape = self.inputs[0].shape
indices_shape = self.inputs[1].shape
axis = self.attrs['axis']
output_shape = input_shape
# the gathered axis is replaced by the flattened size of the indices
indices_shape_one_dim = 1
for dim in indices_shape:
indices_shape_one_dim *= dim
output_shape[axis] = indices_shape_one_dim
return output_shape
def _infer_type(self):
return self.inputs[0].dtype
def _infer_format(self):
return self.inputs[0].data_format
def _check_type(self):
# the indices must be int32
if self.inputs[1].dtype != "int32":
raise GKException("For 'Gather', inputs[1] should be of type int32, but got {}"
.format(self.inputs[1].dtype))
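For example (invented shapes): input_shape [4, 5, 6] with indices_shape [2, 3] and axis = 1 flattens the indices to 6 elements and yields the output shape [4, 6, 6].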
