@@ -11,46 +11,59 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for BatchNormGrad"""
# import the required modules and classes
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
from .expand_dims import ExpandDims


# define the BatchNormGrad class, which inherits from Expander
@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'epsilon')
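# (assumption based on the decorator names: the VLD.add_format lines above register the supported
#  data-format combinations for the six inputs, and check_attrs requires the 'is_training' and
#  'epsilon' attributes to be present before the expansion runs)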
class BatchNormGrad(Expander):
"""BatchNormGrad expander"""

    # _expand is called to build the expanded computation graph for BatchNormGrad
def _expand(self, graph_builder):
# get op info
        input_dy = self.inputs[0]  # incoming gradient w.r.t. the layer output
        input_x = self.inputs[1]  # original input tensor
        input_scale = self.inputs[2]  # scale parameter (gamma)
        input_save_mean = self.inputs[3]  # mean saved by the forward pass
        input_save_inv_variance = self.inputs[4]  # inverse variance saved by the forward pass

        # choose the reduction axes and element count for the later ReduceSum according to the data format
reduce_axis = ()
shape_x = input_x.shape
if input_x.data_format == DF.NHWC:
reduce_axis = (0, 1, 2)
num = shape_x[0] * shape_x[1] * shape_x[2]
else:
reduce_axis = (0, 2, 3)
num = shape_x[0] * shape_x[2] * shape_x[3]
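        # e.g. for a hypothetical NCHW input of shape (32, 64, 14, 14) this gives
        # reduce_axis = (0, 2, 3) and num = 32 * 14 * 14 = 6272 elements reduced per channel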
ori_type = input_x.dtype
        # cast float16 inputs to float32 for the computation to avoid precision loss
if ori_type == 'float16':
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
if input_dy.dtype == 'float16':
input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
num_rec = -1.0 / num
num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
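        # dbeta = sum(dy) over the reduce axes, i.e. the gradient w.r.t. beta;
        # multiplying such a per-channel sum by num_rec (-1 / num) turns it into a negated per-channel mean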
# in training input_save_inv_variance means 1 / sqrt(variance + epsilon), which is calculated in forward pass
        # compute inv_variance depending on whether we are in training mode
if self.attrs['is_training']:
inv_variance = input_save_inv_variance
else:
@@ -61,7 +74,7 @@ class BatchNormGrad(Expander):
scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps])
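            # inference mode: inv_variance is recomputed as 1 / sqrt_var_eps, i.e. 1 / sqrt(variance + epsilon);
            # sqrt_var_eps comes from the lines elided by the diff hunk above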

        # compute dgamma (the gradient w.r.t. gamma)
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_save_mean = graph_builder.emit(
'Reshape', [input_save_mean], attrs={'shape': ExpandDims.infer_shape(input_save_mean.shape, [-1, -1])})
@@ -69,13 +82,13 @@ class BatchNormGrad(Expander):
'Reshape', [inv_variance], attrs={'shape': ExpandDims.infer_shape(inv_variance.shape, [-1, -1])})
input_scale = graph_builder.emit(
'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean])
x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance])
dgamma_param = graph_builder.emit('Mul', [input_dy, x_div])
dgamma = graph_builder.emit(
'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
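        # x_div above is the normalized input x_hat = (x - mean) * inv_variance,
        # so dgamma = sum(dy * x_hat) over the reduce axes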

        # compute dx (the gradient w.r.t. the input x)
if self.attrs['is_training']:
tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta])
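            # tmp_b = num_rec * dbeta = -mean(dy) per channel, one term of the training-mode dx formula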
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
@@ -95,11 +108,12 @@ class BatchNormGrad(Expander):
y_scale = graph_builder.emit('Mul', [input_scale, input_dy])
dx = graph_builder.emit('Mul', [inv_variance, y_scale])
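            # inference mode: mean and variance are fixed statistics, so dx = dy * gamma * inv_variance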
if ori_type == 'float16':
dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
# set output tensors' data_format
dx.data_format = self.outputs[0]['format']
dgamma.data_format = self.outputs[1]['format']
dbeta.data_format = self.outputs[2]['format']

        # return the computed gradients
return dx, dgamma, dbeta