add comments for _extends\graph_kernel\expanders\bias_add_grad.py

branch-yixin
yixin 2 months ago
parent b0c7662155
commit 2c4a524a6a

@ -13,37 +13,64 @@
# limitations under the License.
# ===========================================================================
"""generate json desc for bias_add"""
# 导入MindSpore的DataFormat类
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
# 导入Expander和ExpanderInfoValidator类
from ._utils import Expander, ExpanderInfoValidator as VLD
# 为BiasAddGrad类添加DF.DEFAULT、DF.NHWC、DF.NCHW、DF.FRAC_NZ格式的验证
@VLD.add_format(DF.DEFAULT)
@VLD.add_format(DF.NHWC)
@VLD.add_format(DF.NCHW)
@VLD.add_format(DF.FRAC_NZ)
# 定义BiasAddGrad类继承自Expander类
class BiasAddGrad(Expander):
"""BiasAddGrad expander"""
def _expand(self, graph_builder):
"""
内部方法用于扩展输入张量的维度
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成图操作
Returns:
Tensor: 扩展维度后的张量
"""
# 获取输入张量
x = self.inputs[0]
# 定义reduce_axis用于指定求和的维度
reduce_axis = ()
# 如果输入张量的数据格式为NHWC则reduce_axis为(0, 1, 2)
if x.data_format == DF.NHWC:
reduce_axis = (0, 1, 2)
# 如果输入张量的数据格式为NCHW则reduce_axis为(0, 2, 3)
elif x.data_format == DF.NCHW:
reduce_axis = (0, 2, 3)
# 如果输入张量的数据格式为FRAC_NZ则reduce_axis为(-2, -3)
elif x.data_format == DF.FRAC_NZ:
reduce_axis = (-2, -3)
# 如果输入张量的数据格式为DefaultFormat则根据shape的长度确定reduce_axis
else:
# DefaultFormat shape's length should be from 2 to 4
# 如果x的维度为2则reduce_axis为(0,)
if len(x.shape) == 2:
reduce_axis = (0,)
# 如果x的维度为3则reduce_axis为(0, 1)
elif len(x.shape) == 3:
reduce_axis = (0, 1)
# 否则reduce_axis为(0, 2, 3)
else:
reduce_axis = (0, 2, 3)
# 发射ReduceSum操作计算x的reduce_sumreduce_axis为reduce_axiskeep_dims为False
result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
# 如果x的数据格式为DF.FRAC_NZ则将result的shape改为x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
if x.data_format == DF.FRAC_NZ:
out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
# 发射Reshape操作将result的shape改为out_shape
result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape})
# 返回result
return result

Loading…
Cancel
Save