|
|
|
@ -18,6 +18,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
|
|
|
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
|
|
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
|
|
|
|
|
|
|
|
# 定义常量
|
|
|
|
|
M_ALIGN = 32
|
|
|
|
|
N_ALIGN = 32
|
|
|
|
|
K_ALIGN = 16
|
|
|
|
@ -29,6 +30,7 @@ C_CHANNEL_ALIGN = 16
|
|
|
|
|
OUT_NHW_ALIGN = 128
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 添加格式验证
|
|
|
|
|
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
|
|
|
|
|
@VLD.add_format(DF.NHWC, DF.NHWC)
|
|
|
|
|
@VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
|
|
|
|
@ -47,6 +49,16 @@ class Conv2D(Expander):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, expand_info):
|
|
|
|
|
"""
|
|
|
|
|
类的构造函数
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
expand_info (dict): 扩展信息字典,包含一些扩展的配置参数。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
None
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
super().__init__(expand_info)
|
|
|
|
|
self.dst_type = self.outputs[0]['data_type']
|
|
|
|
|
self.dst_format = self.outputs[0]['format']
|
|
|
|
@ -59,6 +71,19 @@ class Conv2D(Expander):
|
|
|
|
|
self.k = 0
|
|
|
|
|
|
|
|
|
|
def _optimize_to_matmul(self):
|
|
|
|
|
"""
|
|
|
|
|
检查是否可以将Conv2D优化为MatMul。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
bool: 如果可以将Conv2D优化为MatMul,则返回True;否则返回False。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""
|
|
|
|
|
Check if the Conv2D can be optimized to MatMul.
|
|
|
|
|
"""
|
|
|
|
|
stride = self.attrs['stride']
|
|
|
|
|
dilation = self.attrs['dilation']
|
|
|
|
|
_, h, w, _ = self.inputs[1]['shape']
|
|
|
|
@ -68,6 +93,18 @@ class Conv2D(Expander):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def _common_check(self):
|
|
|
|
|
"""
|
|
|
|
|
对输入和属性的通用检查
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GKException: 如果输入数据类型不是 float16,或者输入格式不是 NHWC,或者属性 groups 和 group 不是 1,或者属性 dilation 不是 [1, 1, 1, 1],抛出异常
|
|
|
|
|
"""
|
|
|
|
|
"""common check for inputs and attrs"""
|
|
|
|
|
type_0 = self.inputs[0]['data_type']
|
|
|
|
|
type_1 = self.inputs[1]['data_type']
|
|
|
|
@ -91,26 +128,52 @@ class Conv2D(Expander):
|
|
|
|
|
.format(dilation))
|
|
|
|
|
|
|
|
|
|
def _check(self):
|
|
|
|
|
"""
|
|
|
|
|
检查卷积2D操作的参数和输入是否合法。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GKException: 当输入参数或输入维度不满足要求时抛出异常。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
无
|
|
|
|
|
"""
|
|
|
|
|
# 调用_common_check()方法
|
|
|
|
|
self._common_check()
|
|
|
|
|
|
|
|
|
|
# 获取pad_list
|
|
|
|
|
pad_list = self.attrs['pad_list']
|
|
|
|
|
# 检查pad_list的维度是否为4
|
|
|
|
|
check_nd(pad_list, 4)
|
|
|
|
|
# 调用conv_had_pad()方法,判断是否有pad
|
|
|
|
|
self.has_pad = conv_had_pad(pad_list, self.attrs['pad_mode'])
|
|
|
|
|
|
|
|
|
|
# 获取输入的shape
|
|
|
|
|
shape_0 = self.inputs[0]['shape']
|
|
|
|
|
shape_1 = self.inputs[1]['shape']
|
|
|
|
|
# 获取stride
|
|
|
|
|
stride = self.attrs['stride']
|
|
|
|
|
# 检查shape_0的维度是否为4
|
|
|
|
|
check_nd(shape_0, 4)
|
|
|
|
|
# 检查shape_1的维度是否为4
|
|
|
|
|
check_nd(shape_1, 4)
|
|
|
|
|
# 检查stride的维度是否为4
|
|
|
|
|
check_nd(stride, 4)
|
|
|
|
|
# 获取shape_0的各个维度
|
|
|
|
|
n0, h0, w0, c0 = shape_0
|
|
|
|
|
# 获取shape_1的各个维度
|
|
|
|
|
n1, h1, w1, c1 = shape_1
|
|
|
|
|
# 检查n0是否为N0_CHANNEL_ALIGN的倍数
|
|
|
|
|
if (n0 % N0_CHANNEL_ALIGN) != 0:
|
|
|
|
|
raise GKException("For 'Conv2D', N channel of first input should be multiples of {}, but got {}"
|
|
|
|
|
.format(N0_CHANNEL_ALIGN, n0))
|
|
|
|
|
# 检查n1是否为N1_CHANNEL_ALIGN的倍数
|
|
|
|
|
if (n1 % N1_CHANNEL_ALIGN) != 0:
|
|
|
|
|
raise GKException("For 'Conv2D', N channel of second input should be multiples of {}, but got {}"
|
|
|
|
|
.format(N1_CHANNEL_ALIGN, n1))
|
|
|
|
|
# 检查c0和c1是否相等,并且是否为C_CHANNEL_ALIGN的倍数
|
|
|
|
|
if c0 != c1 or (c0 % C_CHANNEL_ALIGN) != 0:
|
|
|
|
|
raise GKException("For 'Conv2D', C channel of inputs should be same and also be multiples of {}, but got "
|
|
|
|
|
"{} and {}".format(C_CHANNEL_ALIGN, c0, c1))
|
|
|
|
@ -130,68 +193,106 @@ class Conv2D(Expander):
|
|
|
|
|
|
|
|
|
|
# check if can optimize to matmul
|
|
|
|
|
self.m, self.n, self.k = n0 * h0 * w0, n1, c1
|
|
|
|
|
# 调用_optimize_to_matmul()方法,判断是否可以优化为matmul
|
|
|
|
|
self.can_optimize_to_matmul = self._optimize_to_matmul()
|
|
|
|
|
|
|
|
|
|
# requirements
|
|
|
|
|
if self.can_optimize_to_matmul:
|
|
|
|
|
# 如果可以优化为matmul,检查k是否大于K_LIMIT
|
|
|
|
|
if self.k > K_LIMIT:
|
|
|
|
|
raise GKException("For 'Conv2D', if transformed to 'MatMul', C0 should not be larger than {}, but got "
|
|
|
|
|
"{}".format(K_LIMIT, self.k))
|
|
|
|
|
# 如果可以优化为matmul,检查m*n*k的总大小是否大于MNK_LIMIT
|
|
|
|
|
if self.m * self.n * self.k >= MNK_LIMIT:
|
|
|
|
|
raise GKException("For 'Conv2D', if transformed to 'MatMul', The total size should not be larger than "
|
|
|
|
|
"{}, but got {}".format(MNK_LIMIT, self.m * self.n * self.k))
|
|
|
|
|
else:
|
|
|
|
|
# 如果不能优化为matmul,计算输出的大小
|
|
|
|
|
out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
|
|
|
|
|
# 检查n0*out_h*out_w是否为OUT_NHW_ALIGN的倍数
|
|
|
|
|
if ((n0 * out_h * out_w) % OUT_NHW_ALIGN) != 0:
|
|
|
|
|
raise GKException("For 'Conv2D', N({}) * H({}) * W({}) of output should be multiplies of {}"
|
|
|
|
|
.format(n0, out_h, out_w, OUT_NHW_ALIGN))
|
|
|
|
|
# 检查stride是否为[1, 1, 2, 2]
|
|
|
|
|
if stride != [1, 1, 2, 2]:
|
|
|
|
|
raise GKException("For 'Conv2D', value of attr 'stride' should be [1, 1, 2, 2], but got {}"
|
|
|
|
|
.format(stride))
|
|
|
|
|
|
|
|
|
|
# 保存pad后的shape
|
|
|
|
|
self.shape_0_pad = [n0, h0, w0, c0]
|
|
|
|
|
self.shape_1_pad = [n1, h1, w1, c1]
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
对输入进行扩展处理。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 扩展处理后的结果。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# 获取输入0
|
|
|
|
|
input_0 = self.inputs[0]
|
|
|
|
|
# 获取输入1
|
|
|
|
|
input_1 = self.inputs[1]
|
|
|
|
|
# 获取输入0的形状
|
|
|
|
|
n0, _, _, c0 = input_0.shape
|
|
|
|
|
# 获取输入1的形状
|
|
|
|
|
n1, _, _, c1 = input_1.shape
|
|
|
|
|
# 获取输入0的填充形状
|
|
|
|
|
n0_p, h0_p, w0_p, c0_p = self.shape_0_pad
|
|
|
|
|
# 获取输入1的填充形状
|
|
|
|
|
n1_p, _, _, c1_p = self.shape_1_pad
|
|
|
|
|
|
|
|
|
|
pad_value = 0
|
|
|
|
|
# input0 pad
|
|
|
|
|
# 初始化输入0的填充前后的值
|
|
|
|
|
input_0_pad_before = [0, 0, 0, 0]
|
|
|
|
|
input_0_pad_after = [0, 0, 0, 0]
|
|
|
|
|
# 如果有填充,则获取填充列表
|
|
|
|
|
if self.has_pad:
|
|
|
|
|
pad_list = self.attrs['pad_list']
|
|
|
|
|
# 设置输入0的填充前后的值
|
|
|
|
|
input_0_pad_before = [0, pad_list[0], pad_list[2], 0]
|
|
|
|
|
input_0_pad_after = [0, pad_list[1], pad_list[3], 0]
|
|
|
|
|
# 设置输入0的填充后的值
|
|
|
|
|
input_0_pad_after[0] = n0_p - n0
|
|
|
|
|
input_0_pad_after[3] = c0_p - c0
|
|
|
|
|
# 如果输入0的填充前后的值不为默认值,则进行填充操作
|
|
|
|
|
if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]:
|
|
|
|
|
# 发射填充操作
|
|
|
|
|
input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before,
|
|
|
|
|
'tail': input_0_pad_after,
|
|
|
|
|
'pad_val': pad_value})
|
|
|
|
|
# input1 pad
|
|
|
|
|
# 计算input_1的pad值
|
|
|
|
|
input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1]
|
|
|
|
|
# 如果input_1的pad值不为0,则进行pad操作
|
|
|
|
|
if input_1_pad_after != [0, 0, 0, 0]:
|
|
|
|
|
input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
|
|
|
|
|
'tail': input_1_pad_after,
|
|
|
|
|
'pad_val': pad_value})
|
|
|
|
|
# 如果可以优化为matmul操作,则进行matmul操作
|
|
|
|
|
if self.can_optimize_to_matmul:
|
|
|
|
|
# 将input_0和input_1进行reshape操作
|
|
|
|
|
a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]})
|
|
|
|
|
b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]})
|
|
|
|
|
# 进行matmul操作
|
|
|
|
|
c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False,
|
|
|
|
|
'transpose_b': True,
|
|
|
|
|
'dst_type': self.dst_type})
|
|
|
|
|
# 将结果进行reshape操作
|
|
|
|
|
result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p],
|
|
|
|
|
'format': self.dst_format})
|
|
|
|
|
# 否则进行Conv2D操作
|
|
|
|
|
else:
|
|
|
|
|
# 设置Conv2D操作的属性
|
|
|
|
|
attrs = self.attrs
|
|
|
|
|
attrs['pad_list'] = [0, 0, 0, 0]
|
|
|
|
|
attrs['dst_type'] = self.dst_type
|
|
|
|
|
# 进行Conv2D操作
|
|
|
|
|
result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
|
|
|
|
|
# unpad
|
|
|
|
|
unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]]
|
|
|
|
|