import warnings from typing import Callable, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor from ..utils import _log_api_usage_once, _make_ntuple interpolate = torch.nn.functional.interpolate class FrozenBatchNorm2d(torch.nn.Module): """ BatchNorm2d where the batch statistics and the affine parameters are fixed Args: num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)`` eps (float): a value added to the denominator for numerical stability. Default: 1e-5 """ def __init__( self, num_features: int, eps: float = 1e-5, ): super().__init__() _log_api_usage_once(self) self.eps = eps self.register_buffer("weight", torch.ones(num_features)) self.register_buffer("bias", torch.zeros(num_features)) self.register_buffer("running_mean", torch.zeros(num_features)) self.register_buffer("running_var", torch.ones(num_features)) def _load_from_state_dict( self, state_dict: dict, prefix: str, local_metadata: dict, strict: bool, missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str], ): num_batches_tracked_key = prefix + "num_batches_tracked" if num_batches_tracked_key in state_dict: del state_dict[num_batches_tracked_key] super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) def forward(self, x: Tensor) -> Tensor: # move reshapes to the beginning # to make it fuser-friendly w = self.weight.reshape(1, -1, 1, 1) b = self.bias.reshape(1, -1, 1, 1) rv = self.running_var.reshape(1, -1, 1, 1) rm = self.running_mean.reshape(1, -1, 1, 1) scale = w * (rv + self.eps).rsqrt() bias = b - rm * scale return x * scale + bias def __repr__(self) -> str: return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})" class ConvNormActivation(torch.nn.Sequential): def __init__( self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, ...]] = 3, stride: Union[int, Tuple[int, ...]] = 1, padding: Optional[Union[int, Tuple[int, ...], str]] = None, groups: int = 1, norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, dilation: Union[int, Tuple[int, ...]] = 1, inplace: Optional[bool] = True, bias: Optional[bool] = None, conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d, ) -> None: if padding is None: if isinstance(kernel_size, int) and isinstance(dilation, int): padding = (kernel_size - 1) // 2 * dilation else: _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation) kernel_size = _make_ntuple(kernel_size, _conv_dim) dilation = _make_ntuple(dilation, _conv_dim) padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim)) if bias is None: bias = norm_layer is None layers = [ conv_layer( in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=bias, ) ] if norm_layer is not None: layers.append(norm_layer(out_channels)) if activation_layer is not None: params = {} if inplace is None else {"inplace": inplace} layers.append(activation_layer(**params)) super().__init__(*layers) _log_api_usage_once(self) self.out_channels = out_channels if self.__class__ == ConvNormActivation: warnings.warn( "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead." ) class Conv2dNormActivation(ConvNormActivation): """ Configurable block used for Convolution2d-Normalization-Activation blocks. Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block kernel_size: (int, optional): Size of the convolving kernel. Default: 3 stride (int, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation`` groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d`` activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` dilation (int): Spacing between kernel elements. Default: 1 inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. """ def __init__( self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]] = 3, stride: Union[int, Tuple[int, int]] = 1, padding: Optional[Union[int, Tuple[int, int], str]] = None, groups: int = 1, norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, dilation: Union[int, Tuple[int, int]] = 1, inplace: Optional[bool] = True, bias: Optional[bool] = None, ) -> None: super().__init__( in_channels, out_channels, kernel_size, stride, padding, groups, norm_layer, activation_layer, dilation, inplace, bias, torch.nn.Conv2d, ) class Conv3dNormActivation(ConvNormActivation): """ Configurable block used for Convolution3d-Normalization-Activation blocks. Args: in_channels (int): Number of channels in the input video. out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block kernel_size: (int, optional): Size of the convolving kernel. Default: 3 stride (int, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation`` groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm3d`` activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` dilation (int): Spacing between kernel elements. Default: 1 inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. """ def __init__( self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int, int]] = 3, stride: Union[int, Tuple[int, int, int]] = 1, padding: Optional[Union[int, Tuple[int, int, int], str]] = None, groups: int = 1, norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d, activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, dilation: Union[int, Tuple[int, int, int]] = 1, inplace: Optional[bool] = True, bias: Optional[bool] = None, ) -> None: super().__init__( in_channels, out_channels, kernel_size, stride, padding, groups, norm_layer, activation_layer, dilation, inplace, bias, torch.nn.Conv3d, ) class SqueezeExcitation(torch.nn.Module): """ This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1). Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3. Args: input_channels (int): Number of channels in the input image squeeze_channels (int): Number of squeeze channels activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU`` scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid`` """ def __init__( self, input_channels: int, squeeze_channels: int, activation: Callable[..., torch.nn.Module] = torch.nn.ReLU, scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid, ) -> None: super().__init__() _log_api_usage_once(self) self.avgpool = torch.nn.AdaptiveAvgPool2d(1) self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1) self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1) self.activation = activation() self.scale_activation = scale_activation() def _scale(self, input: Tensor) -> Tensor: scale = self.avgpool(input) scale = self.fc1(scale) scale = self.activation(scale) scale = self.fc2(scale) return self.scale_activation(scale) def forward(self, input: Tensor) -> Tensor: scale = self._scale(input) return scale * input class MLP(torch.nn.Sequential): """This block implements the multi-layer perceptron (MLP) module. Args: in_channels (int): Number of channels of the input hidden_channels (List[int]): List of the hidden channel dimensions norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None`` activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place. Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer. bias (bool): Whether to use bias in the linear layer. Default ``True`` dropout (float): The probability for the dropout layer. Default: 0.0 """ def __init__( self, in_channels: int, hidden_channels: List[int], norm_layer: Optional[Callable[..., torch.nn.Module]] = None, activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, inplace: Optional[bool] = None, bias: bool = True, dropout: float = 0.0, ): # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal: # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py params = {} if inplace is None else {"inplace": inplace} layers = [] in_dim = in_channels for hidden_dim in hidden_channels[:-1]: layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) if norm_layer is not None: layers.append(norm_layer(hidden_dim)) layers.append(activation_layer(**params)) layers.append(torch.nn.Dropout(dropout, **params)) in_dim = hidden_dim layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias)) layers.append(torch.nn.Dropout(dropout, **params)) super().__init__(*layers) _log_api_usage_once(self) class Permute(torch.nn.Module): """This module returns a view of the tensor input with its dimensions permuted. Args: dims (List[int]): The desired ordering of dimensions """ def __init__(self, dims: List[int]): super().__init__() self.dims = dims def forward(self, x: Tensor) -> Tensor: return torch.permute(x, self.dims)