You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
67 lines
2.2 KiB
67 lines
2.2 KiB
5 months ago
|
import torch
|
||
|
import torch.fx
|
||
|
from torch import nn, Tensor
|
||
|
|
||
|
from ..utils import _log_api_usage_once
|
||
|
|
||
|
|
||
|
def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) -> Tensor:
    """
    Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
    <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
    branches of residual architectures.

    Args:
        input (Tensor[N, ...]): The input tensor of arbitrary dimensions whose first
                    dimension is the batch, i.e. a batch with ``N`` rows.
        p (float): probability of the input to be zeroed. Must lie in ``[0, 1]``.
        mode (str): ``"batch"`` or ``"row"``.
                    ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
                    randomly selected rows from the batch.
        training (bool): apply stochastic depth only if ``True``; otherwise the input
                    is returned unchanged. Default: ``True``

    Returns:
        Tensor[N, ...]: The randomly zeroed tensor. Surviving entries are scaled by
        ``1 / (1 - p)``.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]`` or ``mode`` is not ``"batch"`` / ``"row"``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(stochastic_depth)
    # Arguments are validated on every call, before the training / p == 0 shortcut.
    if p > 1.0 or p < 0.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
    if mode != "batch" and mode != "row":
        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
    if p == 0.0 or not training:
        return input

    keep_prob = 1.0 - p
    # Broadcastable mask shape: one draw per row for "row" ([N, 1, ..., 1]),
    # a single draw for the whole tensor for "batch" ([1, ..., 1]).
    mask_shape = [1] * input.ndim
    if mode == "row":
        mask_shape[0] = input.shape[0]
    mask = torch.empty(mask_shape, dtype=input.dtype, device=input.device).bernoulli_(keep_prob)
    # Rescale survivors so the mask has expectation 1; skipped when p == 1
    # (keep_prob == 0) to avoid dividing by zero, leaving an all-zero mask.
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return input * mask
|
||
|
|
||
|
|
||
|
# Register the module-level function ``stochastic_depth`` with torch.fx: during
# symbolic tracing, calls to it are recorded as a single ``call_function`` node
# instead of being traced through its body. torch.fx requires this call to be made
# at the top level of the module that defines the function.
torch.fx.wrap("stochastic_depth")
|
||
|
|
||
|
|
||
|
class StochasticDepth(nn.Module):
    """
    Module form of :func:`stochastic_depth`.

    Args:
        p (float): probability of the input to be zeroed.
        mode (str): ``"batch"`` or ``"row"``, passed through to :func:`stochastic_depth`.

    The module's ``training`` flag is forwarded as the ``training`` argument, so the
    input passes through unchanged in eval mode. ``p`` and ``mode`` are not checked
    at construction time; :func:`stochastic_depth` validates them on each call.
    """

    def __init__(self, p: float, mode: str) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.p, self.mode = p, mode

    def forward(self, input: Tensor) -> Tensor:
        # Delegate to the functional form, letting train/eval mode decide whether to drop.
        return stochastic_depth(input, self.p, self.mode, self.training)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(p={self.p}, mode={self.mode})"
|