diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/dropout_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/dropout_grad.py index ac7a011f..6293e4de 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/dropout_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/dropout_grad.py @@ -13,18 +13,36 @@ # limitations under the License. # =========================================================================== """generate json desc for DropoutGrad""" +# 导入Expander和ExpanderInfoValidator类 from ._utils import Expander, ExpanderInfoValidator as VLD +# 定义DropoutGrad类,继承自Expander类 @VLD.check_all_formats_same @VLD.check_attrs('keep_prob') class DropoutGrad(Expander): """DropoutGrad expander""" def _expand(self, graph_builder): + """ + 对输入数据进行扩展操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展后的输入数据。 + + """ + # 获取输入数据和掩码 input_dy, input_mask = self.inputs + # 获取保持概率 keep_prob = self.attrs['keep_prob'] + # 计算保持概率的倒数 r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob) + # 计算输入数据和保持概率的乘积 result = graph_builder.emit('Mul', [input_dy, r_keep_prob]) + # 计算乘积和掩码的乘积 result = graph_builder.emit('Mul', [result, input_mask]) + # 返回结果 return result