|
|
|
@ -13,18 +13,36 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
"""generate json desc for DropoutGrad"""
|
|
|
|
|
# 导入Expander和ExpanderInfoValidator类
|
|
|
|
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 定义DropoutGrad类,继承自Expander类
|
|
|
|
|
@VLD.check_all_formats_same
|
|
|
|
|
@VLD.check_attrs('keep_prob')
|
|
|
|
|
class DropoutGrad(Expander):
|
|
|
|
|
"""DropoutGrad expander"""
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
对输入数据进行扩展操作。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 扩展后的输入数据。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# 获取输入数据和掩码
|
|
|
|
|
input_dy, input_mask = self.inputs
|
|
|
|
|
# 获取保持概率
|
|
|
|
|
keep_prob = self.attrs['keep_prob']
|
|
|
|
|
# 计算保持概率的倒数
|
|
|
|
|
r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob)
|
|
|
|
|
# 计算输入数据和保持概率的乘积
|
|
|
|
|
result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
|
|
|
|
|
# 计算乘积和掩码的乘积
|
|
|
|
|
result = graph_builder.emit('Mul', [result, input_mask])
|
|
|
|
|
# 返回结果
|
|
|
|
|
return result
|
|
|
|
|