add comments for _extends\graph_kernel\expanders\batchnorm_grad.py

branch-yixin
yixin 7 months ago
parent e414d6025d
commit eef69d070e

@@ -11,46 +11,59 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
# Copyright notice
# Licensed under the Apache License 2.0
# You may not use this file except in compliance with the License
"""
Generate json desc for BatchNormGrad. BatchNormGrad computes the gradients of a Batch Normalization layer.
"""
# Import the required modules and classes
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
from .expand_dims import ExpandDims
# Define the BatchNormGrad class, which inherits from Expander
@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'epsilon')
class BatchNormGrad(Expander):
    """BatchNormGrad expander: computes the gradients of a Batch Normalization layer"""
    # Expansion method, called to build the expanded BatchNormGrad computation
    def _expand(self, graph_builder):
        # get op info: output gradient, input data, scale, saved mean and saved inverse variance
        input_dy = self.inputs[0]  # gradient of the layer output
        input_x = self.inputs[1]  # input data
        input_scale = self.inputs[2]  # scale (gamma)
        input_save_mean = self.inputs[3]  # saved mean
        input_save_inv_variance = self.inputs[4]  # saved inverse variance
        # Derive reduce_axis from the input data format; it is used by the ReduceSum ops below
        reduce_axis = ()
        shape_x = input_x.shape
        if input_x.data_format == DF.NHWC:  # NHWC layout
            reduce_axis = (0, 1, 2)  # reduce over N, H, W
            num = shape_x[0] * shape_x[1] * shape_x[2]  # number of elements reduced per channel
        else:  # otherwise assume NCHW (or DEFAULT) layout
            reduce_axis = (0, 2, 3)  # reduce over N, H, W
            num = shape_x[0] * shape_x[2] * shape_x[3]  # number of elements reduced per channel
        ori_type = input_x.dtype  # original data type
        # If the original type is float16, cast to float32 for the computation to avoid precision loss
        if ori_type == 'float16':
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
        if input_dy.dtype == 'float16':
            input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
        num_rec = -1.0 / num  # negative reciprocal of the element count
        num_rec_v = graph_builder.value(input_scale.dtype, num_rec)  # constant node for the reciprocal
        dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})  # dbeta, the gradient of beta
        # in training, input_save_inv_variance means 1 / sqrt(variance + epsilon), computed in the forward pass;
        # otherwise the inverse variance is recomputed below
        if self.attrs['is_training']:
            inv_variance = input_save_inv_variance
        else:
@@ -61,7 +74,7 @@ class BatchNormGrad(Expander):
            scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
            inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps])
        # compute dgamma (the gradient of gamma)
        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
            input_save_mean = graph_builder.emit(
                'Reshape', [input_save_mean], attrs={'shape': ExpandDims.infer_shape(input_save_mean.shape, [-1, -1])})
@@ -69,13 +82,13 @@ class BatchNormGrad(Expander):
                'Reshape', [inv_variance], attrs={'shape': ExpandDims.infer_shape(inv_variance.shape, [-1, -1])})
            input_scale = graph_builder.emit(
                'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
        x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean])  # x minus the saved mean
        x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance])  # multiply by the inverse variance: the normalized x
        dgamma_param = graph_builder.emit('Mul', [input_dy, x_div])  # elementwise term that is reduced to get dgamma
        dgamma = graph_builder.emit(
            'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})  # dgamma, the gradient of gamma
        # compute dx (the gradient of x)
        if self.attrs['is_training']:
            tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta])
            if input_x.data_format in (DF.DEFAULT, DF.NCHW):
@@ -95,11 +108,12 @@ class BatchNormGrad(Expander):
            y_scale = graph_builder.emit('Mul', [input_scale, input_dy])
            dx = graph_builder.emit('Mul', [inv_variance, y_scale])
        if ori_type == 'float16':
            dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})  # cast back to float16 if that was the original type
        # set output tensors' data_format
        dx.data_format = self.outputs[0]['format']
        dgamma.data_format = self.outputs[1]['format']
        dbeta.data_format = self.outputs[2]['format']
        # Return the computed gradients
        return dx, dgamma, dbeta
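
For reference, the gradients this expander builds follow the standard BatchNorm backward formulas: dbeta = sum(dy), dgamma = sum(dy * x_hat), and in training dx = (gamma * inv_std / m) * (m * dy - dbeta - x_hat * dgamma), reduced per channel. A minimal NumPy sketch of those formulas for NCHW layout is below; the function name and NumPy usage are illustrative only and are not part of this commit or of the graph_builder API.

```python
import numpy as np

def batchnorm_grad_reference(dy, x, gamma, mean, inv_std):
    """Illustrative BatchNorm gradients, training mode, NCHW layout (assumed example).

    dy, x: (N, C, H, W); gamma, mean, inv_std: (C,).
    inv_std is 1 / sqrt(variance + epsilon), as saved by the forward pass.
    """
    axis = (0, 2, 3)                              # reduce over N, H, W; keep the channel axis
    m = dy.shape[0] * dy.shape[2] * dy.shape[3]   # elements reduced per channel
    shape = (1, -1, 1, 1)                         # broadcast (C,) parameters over NCHW

    x_hat = (x - mean.reshape(shape)) * inv_std.reshape(shape)  # normalized input
    dbeta = dy.sum(axis=axis)                                   # gradient of beta
    dgamma = (dy * x_hat).sum(axis=axis)                        # gradient of gamma
    # gradient of x: remove the per-channel mean of dy and of dy * x_hat, then rescale
    dx = (gamma.reshape(shape) * inv_std.reshape(shape) / m) * (
        m * dy - dbeta.reshape(shape) - x_hat * dgamma.reshape(shape))
    return dx, dgamma, dbeta
```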