_extends\graph_kernel\expanders\dropout_grad.py

branch-yixin
yixin 2 months ago
parent e13c287655
commit c64739f456

@ -13,18 +13,36 @@
# limitations under the License.
# ===========================================================================
"""generate json desc for DropoutGrad"""
# 导入Expander和ExpanderInfoValidator类
from ._utils import Expander, ExpanderInfoValidator as VLD
# 定义DropoutGrad类继承自Expander类
@VLD.check_all_formats_same
@VLD.check_attrs('keep_prob')
class DropoutGrad(Expander):
"""DropoutGrad expander"""
def _expand(self, graph_builder):
"""
对输入数据进行扩展操作
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展后的输入数据
"""
# 获取输入数据和掩码
input_dy, input_mask = self.inputs
# 获取保持概率
keep_prob = self.attrs['keep_prob']
# 计算保持概率的倒数
r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob)
# 计算输入数据和保持概率的乘积
result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
# 计算乘积和掩码的乘积
result = graph_builder.emit('Mul', [result, input_mask])
# 返回结果
return result

Loading…
Cancel
Save