|
|
@ -24,39 +24,61 @@ class Utils:
|
|
|
|
@staticmethod
def get_attr_type(attr):
    """Return the type tag for an attribute value.

    Scalars map to 'bool' / 'str' / 'int' / 'float'; non-empty lists or
    tuples map to 'listInt' / 'listStr' based on their first element.
    Raises ValueError for an empty sequence or any unsupported type.
    """
    # Order matters: bool is a subclass of int, so it must be tested first.
    for scalar_type, tag in ((bool, 'bool'), (str, 'str'), (int, 'int'), (float, 'float')):
        if isinstance(attr, scalar_type):
            return tag
    if isinstance(attr, (list, tuple)):
        if not attr:
            raise ValueError("attr is invalid: the length of attr is 0")
        # The element type is judged from the first element only.
        if isinstance(attr[0], int):
            return 'listInt'
        if isinstance(attr[0], str):
            return 'listStr'
    raise ValueError("attr {} type {} is not in supported list ['bool', 'str', 'int', 'float', 'int' list, "
                     "'str' list]".format(attr, type(attr)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataFormat:
|
|
|
|
class DataFormat:
|
|
|
|
"""DataFormat"""
|
|
|
|
"""DataFormat"""
|
|
|
|
|
|
|
|
# 默认格式
|
|
|
|
DEFAULT = "DefaultFormat"
|
|
|
|
DEFAULT = "DefaultFormat"
|
|
|
|
|
|
|
|
# NC1KHKWHWC0格式
|
|
|
|
NC1KHKWHWC0 = "NC1KHKWHWC0"
|
|
|
|
NC1KHKWHWC0 = "NC1KHKWHWC0"
|
|
|
|
|
|
|
|
# ND格式
|
|
|
|
ND = "ND"
|
|
|
|
ND = "ND"
|
|
|
|
|
|
|
|
# NCHW格式
|
|
|
|
NCHW = "NCHW"
|
|
|
|
NCHW = "NCHW"
|
|
|
|
|
|
|
|
# NHWC格式
|
|
|
|
NHWC = "NHWC"
|
|
|
|
NHWC = "NHWC"
|
|
|
|
|
|
|
|
# HWCN格式
|
|
|
|
HWCN = "HWCN"
|
|
|
|
HWCN = "HWCN"
|
|
|
|
|
|
|
|
# NC1HWC0格式
|
|
|
|
NC1HWC0 = "NC1HWC0"
|
|
|
|
NC1HWC0 = "NC1HWC0"
|
|
|
|
|
|
|
|
# FRAC_Z格式
|
|
|
|
FRAC_Z = "FracZ"
|
|
|
|
FRAC_Z = "FracZ"
|
|
|
|
|
|
|
|
# FRAC_NZ格式
|
|
|
|
FRAC_NZ = "FRACTAL_NZ"
|
|
|
|
FRAC_NZ = "FRACTAL_NZ"
|
|
|
|
|
|
|
|
# C1HWNCOC0格式
|
|
|
|
C1HWNCOC0 = "C1HWNCoC0"
|
|
|
|
C1HWNCOC0 = "C1HWNCoC0"
|
|
|
|
|
|
|
|
# NC1HWC0_C04格式
|
|
|
|
NC1HWC0_C04 = "NC1HWC0_C04"
|
|
|
|
NC1HWC0_C04 = "NC1HWC0_C04"
|
|
|
|
|
|
|
|
# FRACTAL_Z_C04格式
|
|
|
|
FRACTAL_Z_C04 = "FRACTAL_Z_C04"
|
|
|
|
FRACTAL_Z_C04 = "FRACTAL_Z_C04"
|
|
|
|
|
|
|
|
# NDHWC格式
|
|
|
|
NDHWC = "NDHWC"
|
|
|
|
NDHWC = "NDHWC"
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(self):
|
|
|
@ -65,29 +87,47 @@ class DataFormat:
|
|
|
|
|
|
|
|
|
|
|
|
class DataType:
    """Data Type: string constants naming the supported element types."""

    # Floating-point types.
    FLOAT = "float"
    FLOAT16 = "float16"
    FLOAT32 = "float32"
    FLOAT64 = "float64"

    # Signed integer types.
    INT = "int"
    INT8 = "int8"
    INT16 = "int16"
    INT32 = "int32"
    INT64 = "int64"

    # Unsigned integer types.
    UINT = "uint"
    UINT8 = "uint8"
    UINT16 = "uint16"
    UINT32 = "uint32"
    UINT64 = "uint64"

    # Boolean type.
    BOOL = "bool"

    def __init__(self):
        # Pure constant namespace; nothing to initialize.
        pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PrimLib:
|
|
|
|
class PrimLib:
|
|
|
|
"""Prim lib"""
|
|
|
|
"""Prim lib"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 定义PrimLib类中的常量
|
|
|
|
UNKNOWN = 0
|
|
|
|
UNKNOWN = 0
|
|
|
|
RESHAPE = 1
|
|
|
|
RESHAPE = 1
|
|
|
|
ELEMWISE = 2
|
|
|
|
ELEMWISE = 2
|
|
|
@ -102,53 +142,73 @@ class PrimLib:
|
|
|
|
"""Prim"""
|
|
|
|
"""Prim"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, iter_type, calibrate=1, relation_func=None):
    """Build a Prim entry.

    iter_type: iteration category of the primitive.
    calibrate: weight applied when estimating iteration size.
    relation_func: optional custom relation function; when omitted the
        default handler matching *iter_type* is late-bound to this instance.
    """
    self.iter_type = iter_type
    self.calibrate = calibrate
    if relation_func is not None:
        self.relation_func = relation_func
    else:
        # Dispatch through default_relation_func so subclass overrides
        # and the instance itself are picked up at call time.
        self.relation_func = lambda *args: self.default_relation_func[iter_type](self, *args)
|
|
|
|
|
|
|
|
|
|
|
|
def default_reshape_relation(self, op, input_idx):
    """Process reshape relation: every output axis relates as RESHAPE."""
    axis_relation, elem_relation = self.unknown_relation(op, input_idx)
    # Replace the unknown element relations with RESHAPE markers.
    return axis_relation, [PrimLib.RESHAPE for _ in elem_relation]
|
|
|
|
|
|
|
|
|
|
|
|
def default_elemwise_broadcast_relation(self, op, input_idx):
    """Process elemwise and broadcast relation.

    Aligns the input shape to the output shape from the right; leading
    output axes with no input counterpart get None, matching axes get
    ELEMWISE and size-mismatched (broadcast) axes get BROADCAST.
    """
    out_shape = op.output.shape
    in_shape = op.inputs[input_idx].shape
    delta = len(out_shape) - len(in_shape)
    if delta < 0:
        # An input with higher rank than the output cannot be broadcast.
        raise ValueError("For '{}', the input/output size is abnormal, as the length of output shape{} is less "
                         "than the length of input shape{}".format(op.prim, out_shape, in_shape))
    # Leading output axes have no corresponding input axis.
    axis_relation = [None] * delta
    elem_relation = [None] * delta
    for idx, dim in enumerate(in_shape):
        axis_relation.append(idx)
        elem_relation.append(
            PrimLib.ELEMWISE if out_shape[idx + delta] == dim else PrimLib.BROADCAST)
    return axis_relation, elem_relation
|
|
|
|
|
|
|
|
|
|
|
|
def default_reduce_relation(self, op, input_idx):
    """Process reduce relation: elemwise/broadcast plus REDUCE markers."""
    axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
    # Axes listed in 'reduce_axis' override their element relation.
    for axis in op.attrs['reduce_axis']:
        elem_relation[axis] = PrimLib.REDUCE
    return axis_relation, elem_relation
|
|
|
|
|
|
|
|
|
|
|
|
def unknown_relation(self, op, input_idx):
    """Process unknown relation: any output axis may depend on any input axis."""
    out_rank = len(op.output.shape)
    in_rank = len(op.inputs[input_idx].shape)
    # Every output axis is conservatively related to all input axes.
    all_axes = list(range(in_rank))
    axis_relation = [all_axes] * out_rank
    elem_relation = [PrimLib.UNKNOWN] * out_rank
    return axis_relation, elem_relation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 默认的关系函数列表
|
|
|
|
default_relation_func = [
|
|
|
|
default_relation_func = [
|
|
|
|
unknown_relation,
|
|
|
|
unknown_relation,
|
|
|
|
default_reshape_relation,
|
|
|
|
default_reshape_relation,
|
|
|
@ -158,6 +218,7 @@ class PrimLib:
|
|
|
|
unknown_relation,
|
|
|
|
unknown_relation,
|
|
|
|
]
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 定义基本操作
|
|
|
|
primtives = {
|
|
|
|
primtives = {
|
|
|
|
'Add': Prim(ELEMWISE),
|
|
|
|
'Add': Prim(ELEMWISE),
|
|
|
|
'Abs': Prim(ELEMWISE),
|
|
|
|
'Abs': Prim(ELEMWISE),
|
|
|
@ -239,7 +300,9 @@ class PrimLib:
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def get_prim(cls, op):
|
|
|
|
def get_prim(cls, op):
|
|
|
|
"""Get op primtive"""
|
|
|
|
"""Get op primtive"""
|
|
|
|
|
|
|
|
# 从cls.primtives中获取op.prim对应的prim
|
|
|
|
prim = cls.primtives.get(op.prim, None)
|
|
|
|
prim = cls.primtives.get(op.prim, None)
|
|
|
|
|
|
|
|
# 如果prim为None,则打印警告信息,并返回cls.default_primtive
|
|
|
|
if prim is None:
|
|
|
|
if prim is None:
|
|
|
|
print('[WARN] primtive is not registered: ' + op.prim)
|
|
|
|
print('[WARN] primtive is not registered: ' + op.prim)
|
|
|
|
prim = cls.default_primtive
|
|
|
|
prim = cls.default_primtive
|
|
|
@ -248,50 +311,65 @@ class PrimLib:
|
|
|
|
@classmethod
def input_relation(cls, op, input_idx):
    """Get op's input_relation according to input_idx."""
    prim = cls.get_prim(op)
    return prim.relation_func(op, input_idx)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
def iter_type(cls, op):
    """Get op's iter type."""
    prim = cls.get_prim(op)
    return prim.iter_type
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
def is_reduce(cls, op):
    """Check whether op's iter type is reduce."""
    prim = cls.get_prim(op)
    return prim.iter_type == cls.REDUCE
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
def calibrate_iter_size(cls, op, iter_size):
    """Get calibrate_iter_size: the op's calibrate weight times iter_size."""
    prim = cls.get_prim(op)
    return prim.calibrate * iter_size
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
def dtype_bytes(cls, dtype):
    """Return the byte width encoded in a dtype name (e.g. 'float32' -> 4).

    The trailing decimal digits of the name give the bit width; bits
    starts at 1 so widths round up to whole bytes (33 // 8 == 4).
    Names with no trailing digits (e.g. 'bool') yield 0.
    """
    bits, unit = 1, 1
    # Walk the name backwards; the first character (index 0) is never
    # treated as part of the width, mirroring the original scan bounds.
    for ch in reversed(dtype[1:]):
        if not ch.isdecimal():
            break
        bits += int(ch) * unit
        unit *= 10
    return bits // 8
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
def inplace_reuse(cls, op, input_idx, start_axis=0):
    """Check whether op can reuse the given input's buffer in place.

    Reuse requires the output to be no wider than the input and every
    element relation from start_axis on to be ELEMWISE.
    """
    in_tensor = op.inputs[input_idx]
    # A wider output cannot fit into the input's storage.
    if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(in_tensor.dtype):
        return False
    _, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
    return all(rel == cls.ELEMWISE for rel in elem_relation[start_axis:])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Tensor:
|
|
|
|
class Tensor:
|
|
|
|
"""Tensor"""
|
|
|
|
"""Tensor"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 参数类型常量
|
|
|
|
PARA_NONE = 0
|
|
|
|
PARA_NONE = 0
|
|
|
|
PARA_INPUT = 1
|
|
|
|
PARA_INPUT = 1
|
|
|
|
PARA_OUTPUT = 2
|
|
|
|
PARA_OUTPUT = 2
|
|
|
@ -303,6 +381,7 @@ class Tensor:
|
|
|
|
self.members = [leader]
|
|
|
|
self.members = [leader]
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
|
|
|
|
def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
|
|
|
|
|
|
|
|
# 初始化Tensor对象
|
|
|
|
self.name = name
|
|
|
|
self.name = name
|
|
|
|
self.shape = shape
|
|
|
|
self.shape = shape
|
|
|
|
self.dtype = dtype
|
|
|
|
self.dtype = dtype
|
|
|
@ -313,13 +392,16 @@ class Tensor:
|
|
|
|
self.buddy = None
|
|
|
|
self.buddy = None
|
|
|
|
|
|
|
|
|
|
|
|
def __str__(self):
    """Readable form: tensor name followed by its shape list."""
    return '%s%s' % (self.name, list(self.shape))
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
    """Debug form: name.dtype followed by the shape list."""
    return '{}.{}{}'.format(self.name, self.dtype, list(self.shape))
|
|
|
|
|
|
|
|
|
|
|
|
def get_size(self):
|
|
|
|
def get_size(self):
|
|
|
|
"""Get size"""
|
|
|
|
"""Get size"""
|
|
|
|
|
|
|
|
# 获取Tensor对象的大小
|
|
|
|
size = PrimLib.dtype_bytes(self.dtype)
|
|
|
|
size = PrimLib.dtype_bytes(self.dtype)
|
|
|
|
for i in self.shape:
|
|
|
|
for i in self.shape:
|
|
|
|
size *= i
|
|
|
|
size *= i
|
|
|
@ -327,6 +409,7 @@ class Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
def add_buddy(self, tensor):
|
|
|
|
def add_buddy(self, tensor):
|
|
|
|
"""Add buddy"""
|
|
|
|
"""Add buddy"""
|
|
|
|
|
|
|
|
# 添加buddy
|
|
|
|
if self.buddy is None:
|
|
|
|
if self.buddy is None:
|
|
|
|
self.buddy = self.Buddy(self)
|
|
|
|
self.buddy = self.Buddy(self)
|
|
|
|
self.buddy.members.append(tensor)
|
|
|
|
self.buddy.members.append(tensor)
|
|
|
@ -337,6 +420,7 @@ class Value:
|
|
|
|
"""Value"""
|
|
|
|
"""Value"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
|
|
|
|
def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
|
|
|
|
|
|
|
|
# 初始化Value对象
|
|
|
|
self.name = name
|
|
|
|
self.name = name
|
|
|
|
self.shape = [1]
|
|
|
|
self.shape = [1]
|
|
|
|
self.dtype = dtype
|
|
|
|
self.dtype = dtype
|
|
|
@ -344,14 +428,17 @@ class Value:
|
|
|
|
self.data_format = data_format
|
|
|
|
self.data_format = data_format
|
|
|
|
|
|
|
|
|
|
|
|
def __str__(self):
    """Readable form: value name followed by its shape list."""
    return '{}{}'.format(self.name, list(self.shape))
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
    """Debug form: name.dtype followed by the shape list."""
    return '%s.%s%s' % (self.name, self.dtype, list(self.shape))
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
def get_size():
    """Get size: a Value is a scalar constant, so its size is always 1."""
    return 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -359,23 +446,31 @@ class Operator:
|
|
|
|
"""Operator"""
|
|
|
|
"""Operator"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, primtive, inputs, output, attrs):
    """Create an operator node and wire it into the tensor graph.

    primtive: primitive name; inputs: Tensor inputs; output: the produced
    Tensor; attrs: attribute dict.
    """
    self.prim = primtive
    self.inputs = inputs
    self.output = output
    self.attrs = attrs
    # Register this op as a consumer of every input tensor.
    for tensor in inputs:
        tensor.to_ops.append(self)
    # The first op producing a tensor becomes its defining op.
    if output.op is None:
        output.op = self
    self.all_inputs = []  # include Tensor inputs and Value inputs.
|
|
|
|
|
|
|
|
|
|
|
|
def __str__(self):
    """Render as "output = prim.dtype(args) id:<id>", plus attrs if any."""
    arg_list = ', '.join(str(v) for v in self.all_inputs)
    expr = "%s = %s.%s(%s) id:%s" % (
        str(self.output), self.prim, self.output.dtype, arg_list, id(self))
    if self.attrs:
        return '%s // %s' % (expr, str(self.attrs))
    return expr
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
    """Same rendering as __str__."""
    return self.__str__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -383,6 +478,7 @@ class Graph:
|
|
|
|
"""Graph"""
|
|
|
|
"""Graph"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
|
|
|
|
def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
|
|
|
|
|
|
|
|
# 初始化Graph对象
|
|
|
|
self.name = name
|
|
|
|
self.name = name
|
|
|
|
self.ops = ops # in topo order, can not use set
|
|
|
|
self.ops = ops # in topo order, can not use set
|
|
|
|
self.inputs = []
|
|
|
|
self.inputs = []
|
|
|
@ -393,10 +489,12 @@ class Graph:
|
|
|
|
|
|
|
|
|
|
|
|
def set_processor(self, processor):
    """Set processor: record the target backend for this graph."""
    self.processor = processor
|
|
|
|
|
|
|
|
|
|
|
|
def add(self, ops):
|
|
|
|
def add(self, ops):
|
|
|
|
"""Add ops"""
|
|
|
|
"""Add ops"""
|
|
|
|
|
|
|
|
# 添加操作
|
|
|
|
if isinstance(ops, Operator):
|
|
|
|
if isinstance(ops, Operator):
|
|
|
|
self.ops.append(ops)
|
|
|
|
self.ops.append(ops)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -404,101 +502,148 @@ class Graph:
|
|
|
|
|
|
|
|
|
|
|
|
def extract_subgraph(self, graph_name, tensor_names, difference=False):
    """Extract subgraph from this graph"""
    # Collect the chosen ops into a fresh Graph, preserving topo order.
    graph = Graph(graph_name, [])
    outputs = set(tensor_names)
    if difference:
        # Keep every op whose output name is NOT in tensor_names.
        for op in self.ops:
            if op.output.name not in outputs:
                graph.add(op)
    else:
        # Keep exactly the ops whose output name IS in tensor_names.
        for op in self.ops:
            if op.output.name in outputs:
                graph.add(op)
                outputs.remove(op.output.name)
        # Any name left over was never produced by this graph.
        for name in outputs:
            raise ValueError("Invalid input tensor : {}, can not find it in graph".format(name))
    return graph
|
|
|
|
|
|
|
|
|
|
|
|
def deduce_parameters(self):
    """Deduce parameters"""
    inputs, outputs = [], []
    for op in self.ops:
        # A tensor produced by an op outside this graph is a graph input.
        for t in op.inputs:
            if t not in inputs and t.op not in self.ops:
                inputs.append(t)
        # Skip outputs that were already collected.
        if op.output in outputs:
            continue
        # A tensor is a graph output when it is explicitly marked as an
        # output parameter or has no consumers at all.
        if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
            outputs.append(op.output)
            continue
        # ...or when at least one consumer lies outside this graph.
        if any((succ not in self.ops for succ in op.output.to_ops)):
            outputs.append(op.output)
    # Explicitly-set parameter lists override the deduced ones.
    if self.inputs:
        inputs = self.inputs
    if self.outputs:
        outputs = self.outputs
    return inputs, outputs
|
|
|
|
|
|
|
|
|
|
|
|
def __str__(self):
    # Render the graph as "name(inputs) -> outputs { ... }".
    inputs, outputs = self.deduce_parameters()
    para_str = ', '.join((repr(t) for t in inputs))
    out_str = ', '.join((repr(t) for t in outputs))
    lines = []
    lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
    # Buffer-stitch annotations, when present, precede the op listing.
    if self.stitch_info:
        if self.stitch_info.stitch_ops:
            lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops))
        if self.stitch_info.stitch_atomic_ops:
            lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))
    # One indented line per op, in topological order.
    for op in self.ops:
        lines.append(' ' + str(op))
    lines.append('}')
    return '\n'.join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
    """Same rendering as __str__."""
    return self.__str__()
|
|
|
|
|
|
|
|
|
|
|
|
def dump(self):
    """Dump Graph to json"""
    # Rename internal attribute keys to the names used in the JSON form.
    attr_name = {'reduce_axis': 'axis'}
    inputs, outputs = self.deduce_parameters()
    input_desc, output_desc, op_desc = [], [], []
    # Describe each graph input (each wrapped in its own one-element list).
    for t in inputs:
        input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
                            'tensor_name': t.name, 'format': t.data_format}])
    # Describe each graph output.
    for t in outputs:
        output_desc.append({'data_type': t.dtype, 'shape': t.shape,
                            'tensor_name': t.name, 'format': t.data_format})
    # Describe each op: attributes, inputs and output.
    for op in self.ops:
        attrs, in_desc = [], []
        for a in op.attrs:
            name = attr_name.get(a, a)
            attrs.append(
                {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
        for t in op.all_inputs:
            if isinstance(t, Tensor):
                in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
                                 'tensor_name': t.name, 'format': t.data_format}])
            else:
                # Value inputs additionally carry their constant value.
                in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
                                 'tensor_name': t.name, 'format': t.data_format}])
        out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
                     'tensor_name': op.output.name, 'format': op.output.data_format}]
        op_desc.append({'attr': attrs, 'impl_path': '',
                        'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
    # Assemble the composite-graph descriptor.
    graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
                  'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
                  'platform': 'AKG', 'process': self.processor}
    # Attach buffer-stitch information when present.
    if self.stitch_info and self.stitch_info.stitch_ops:
        buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
        if self.stitch_info.stitch_atomic_ops:
            buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
        graph_desc['buffer_stitch'] = buffer_stitch
    return graph_desc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -506,13 +651,16 @@ class GraphVisitor:
|
|
|
|
"""Graph visitor"""
|
|
|
|
"""Graph visitor"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, forward=True):
    """Create a visitor; forward=True visits ops in topo order, False in reverse."""
    self.forward = forward
|
|
|
|
|
|
|
|
|
|
|
|
def visit_graph(self, graph):
    """Visit graph: call self.visit on each op, direction set by self.forward."""
    ops = graph.ops
    if self.forward:
        for node in ops:
            self.visit(node)
    else:
        # Reverse topological order.
        for node in reversed(ops):
            self.visit(node)
|
|
|
@ -528,12 +676,18 @@ class AlignShape(GraphVisitor):
|
|
|
|
def visit(op):
    """Visit op node: left-pad the output shape with 1s so its rank matches
    the widest input, for elemwise/broadcast/reduce ops."""
    prim = PrimLib.get_prim(op)
    # Only these iteration types require rank alignment.
    if prim.iter_type not in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
        return
    out_dim = len(op.output.shape)
    align_dim = max([out_dim] + [len(t.shape) for t in op.inputs])
    if align_dim > out_dim:
        op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
|
|
|
|
|
|
|
|
|
|
|
|