|
|
@ -13,37 +13,64 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
# ===========================================================================
|
|
|
|
# ===========================================================================
|
|
|
|
"""generate json desc for bias_add"""
|
|
|
|
"""generate json desc for bias_add"""
|
|
|
|
|
|
|
|
# 导入MindSpore的DataFormat类
|
|
|
|
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
|
|
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
|
|
|
|
|
|
# 导入Expander和ExpanderInfoValidator类
|
|
|
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
|
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 为BiasAddGrad类添加DF.DEFAULT、DF.NHWC、DF.NCHW、DF.FRAC_NZ格式的验证
|
|
|
|
@VLD.add_format(DF.DEFAULT)
|
|
|
|
@VLD.add_format(DF.DEFAULT)
|
|
|
|
@VLD.add_format(DF.NHWC)
|
|
|
|
@VLD.add_format(DF.NHWC)
|
|
|
|
@VLD.add_format(DF.NCHW)
|
|
|
|
@VLD.add_format(DF.NCHW)
|
|
|
|
@VLD.add_format(DF.FRAC_NZ)
|
|
|
|
@VLD.add_format(DF.FRAC_NZ)
|
|
|
|
|
|
|
|
# 定义BiasAddGrad类,继承自Expander类
|
|
|
|
class BiasAddGrad(Expander):
|
|
|
|
class BiasAddGrad(Expander):
|
|
|
|
"""BiasAddGrad expander"""
|
|
|
|
"""BiasAddGrad expander"""
|
|
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
内部方法,用于扩展输入张量的维度。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象,用于生成图操作。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
Tensor: 扩展维度后的张量。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
# 获取输入张量
|
|
|
|
x = self.inputs[0]
|
|
|
|
x = self.inputs[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 定义reduce_axis,用于指定求和的维度
|
|
|
|
reduce_axis = ()
|
|
|
|
reduce_axis = ()
|
|
|
|
|
|
|
|
# 如果输入张量的数据格式为NHWC,则reduce_axis为(0, 1, 2)
|
|
|
|
if x.data_format == DF.NHWC:
|
|
|
|
if x.data_format == DF.NHWC:
|
|
|
|
reduce_axis = (0, 1, 2)
|
|
|
|
reduce_axis = (0, 1, 2)
|
|
|
|
|
|
|
|
# 如果输入张量的数据格式为NCHW,则reduce_axis为(0, 2, 3)
|
|
|
|
elif x.data_format == DF.NCHW:
|
|
|
|
elif x.data_format == DF.NCHW:
|
|
|
|
reduce_axis = (0, 2, 3)
|
|
|
|
reduce_axis = (0, 2, 3)
|
|
|
|
|
|
|
|
# 如果输入张量的数据格式为FRAC_NZ,则reduce_axis为(-2, -3)
|
|
|
|
elif x.data_format == DF.FRAC_NZ:
|
|
|
|
elif x.data_format == DF.FRAC_NZ:
|
|
|
|
reduce_axis = (-2, -3)
|
|
|
|
reduce_axis = (-2, -3)
|
|
|
|
|
|
|
|
# 如果输入张量的数据格式为DefaultFormat,则根据shape的长度确定reduce_axis
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# DefaultFormat shape's length should be from 2 to 4
|
|
|
|
# DefaultFormat shape's length should be from 2 to 4
|
|
|
|
|
|
|
|
# 如果x的维度为2,则reduce_axis为(0,)
|
|
|
|
if len(x.shape) == 2:
|
|
|
|
if len(x.shape) == 2:
|
|
|
|
reduce_axis = (0,)
|
|
|
|
reduce_axis = (0,)
|
|
|
|
|
|
|
|
# 如果x的维度为3,则reduce_axis为(0, 1)
|
|
|
|
elif len(x.shape) == 3:
|
|
|
|
elif len(x.shape) == 3:
|
|
|
|
reduce_axis = (0, 1)
|
|
|
|
reduce_axis = (0, 1)
|
|
|
|
|
|
|
|
# 否则,reduce_axis为(0, 2, 3)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
reduce_axis = (0, 2, 3)
|
|
|
|
reduce_axis = (0, 2, 3)
|
|
|
|
|
|
|
|
# 发射ReduceSum操作,计算x的reduce_sum,reduce_axis为reduce_axis,keep_dims为False
|
|
|
|
result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
|
|
|
|
result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
|
|
|
|
|
|
|
|
# 如果x的数据格式为DF.FRAC_NZ,则将result的shape改为x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
|
|
|
|
if x.data_format == DF.FRAC_NZ:
|
|
|
|
if x.data_format == DF.FRAC_NZ:
|
|
|
|
out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
|
|
|
|
out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
|
|
|
|
|
|
|
|
# 发射Reshape操作,将result的shape改为out_shape
|
|
|
|
result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape})
|
|
|
|
result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape})
|
|
|
|
|
|
|
|
# 返回result
|
|
|
|
return result
|
|
|
|
return result
|
|
|
|