from functools import partial
from typing import Any, Callable, List, Optional, Sequence
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from ..ops.misc import Conv2dNormActivation, Permute
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface

__all__ = [
    "ConvNeXt",
    "ConvNeXt_Tiny_Weights",
    "ConvNeXt_Small_Weights",
    "ConvNeXt_Base_Weights",
    "ConvNeXt_Large_Weights",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "convnext_large",
]

class LayerNorm2d(nn.LayerNorm):
    """LayerNorm for inputs in channels-first (N, C, H, W) format.

    The input is permuted to channels-last, normalized over the channel
    dimension, and permuted back.
    """

    def forward(self, x: Tensor) -> Tensor:
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x

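# Illustrative sketch (not part of the original file): LayerNorm2d applied to an
# NCHW tensor matches a plain nn.LayerNorm applied over the channel dimension in
# NHWC layout, which is exactly what the permutations above implement.
def _layernorm2d_equivalence_sketch() -> None:
    x = torch.randn(2, 64, 8, 8)
    ln2d = LayerNorm2d(64, eps=1e-6)
    ln = nn.LayerNorm(64, eps=1e-6)
    expected = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(ln2d(x), expected)
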
class CNBlock(nn.Module):
    def __init__(
        self,
        dim,
        layer_scale: float,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            Permute([0, 2, 3, 1]),
            norm_layer(dim),
            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
            nn.GELU(),
            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
            Permute([0, 3, 1, 2]),
        )
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        result = self.layer_scale * self.block(input)
        result = self.stochastic_depth(result)
        result += input
        return result

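# Illustrative sketch (not part of the original file): a CNBlock preserves the
# spatial size and channel count of its input, since the 7x7 depthwise conv uses
# padding=3 and the pointwise MLP expands 4x and then projects back to ``dim``.
def _cnblock_shape_sketch() -> None:
    block = CNBlock(dim=96, layer_scale=1e-6, stochastic_depth_prob=0.0)
    x = torch.randn(2, 96, 56, 56)
    assert block(x).shape == x.shape
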
class CNBlockConfig:
    # Stores information listed at Section 3 of the ConvNeXt paper
    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "input_channels={input_channels}"
        s += ", out_channels={out_channels}"
        s += ", num_layers={num_layers}"
        s += ")"
        return s.format(**self.__dict__)

class ConvNeXt(nn.Module):
    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling
                layers.append(
                    nn.Sequential(
                        norm_layer(cnf.input_channels),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
        )

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

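# Illustrative sketch (not part of the original file): building a ConvNeXt-Tiny
# sized model directly from CNBlockConfig entries, bypassing the convnext_tiny()
# factory defined below. A 224x224 input reaches a 7x7 feature map before pooling.
def _convnext_manual_build_sketch() -> ConvNeXt:
    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    model = ConvNeXt(block_setting, stochastic_depth_prob=0.1)
    logits = model(torch.randn(1, 3, 224, 224))  # shape: (1, 1000)
    return model
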
def _convnext(
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model

_COMMON_META = {
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
    "_docs": """
        These weights improve upon the results of the original paper by using a modified version of TorchVision's
        `new training recipe
        <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
    """,
}

class ConvNeXt_Tiny_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=236),
        meta={
            **_COMMON_META,
            "num_params": 28589128,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.520,
                    "acc@5": 96.146,
                }
            },
            "_ops": 4.456,
            "_file_size": 109.119,
        },
    )
    DEFAULT = IMAGENET1K_V1

class ConvNeXt_Small_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=230),
        meta={
            **_COMMON_META,
            "num_params": 50223688,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.616,
                    "acc@5": 96.650,
                }
            },
            "_ops": 8.684,
            "_file_size": 191.703,
        },
    )
    DEFAULT = IMAGENET1K_V1

class ConvNeXt_Base_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88591464,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.062,
                    "acc@5": 96.870,
                }
            },
            "_ops": 15.355,
            "_file_size": 338.064,
        },
    )
    DEFAULT = IMAGENET1K_V1

class ConvNeXt_Large_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 197767336,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.414,
                    "acc@5": 96.976,
                }
            },
            "_ops": 34.361,
            "_file_size": 754.537,
        },
    )
    DEFAULT = IMAGENET1K_V1

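# Illustrative sketch (not part of the original file): each Weights entry bundles
# the checkpoint URL with its evaluation preprocessing and metadata, so the
# matching transforms and category names can be recovered straight from the enum.
def _weights_metadata_sketch() -> None:
    weights = ConvNeXt_Tiny_Weights.DEFAULT
    preprocess = weights.transforms()  # Resize(236) + CenterCrop(224) + normalization
    num_classes = len(weights.meta["categories"])  # 1000 ImageNet-1K categories
    batch = preprocess(torch.randint(0, 256, (3, 300, 300), dtype=torch.uint8))  # (3, 224, 224) float tensor
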
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Tiny model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
        :members:
    """
    weights = ConvNeXt_Tiny_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)

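# Illustrative sketch (not part of the original file): loading the pretrained
# ConvNeXt-Tiny weights and running inference on a single image tensor (the
# checkpoint is downloaded on first use).
def _convnext_tiny_inference_sketch() -> None:
    weights = ConvNeXt_Tiny_Weights.IMAGENET1K_V1
    model = convnext_tiny(weights=weights).eval()
    preprocess = weights.transforms()
    img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)  # stand-in for a real image
    with torch.no_grad():
        probs = model(preprocess(img).unsqueeze(0)).softmax(dim=1)
    top1 = weights.meta["categories"][int(probs.argmax())]
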
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Small model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
        :members:
    """
    weights = ConvNeXt_Small_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 27),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)

@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Base model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
        :members:
    """
    weights = ConvNeXt_Base_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(128, 256, 3),
        CNBlockConfig(256, 512, 3),
        CNBlockConfig(512, 1024, 27),
        CNBlockConfig(1024, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)

@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Large model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
        :members:
    """
    weights = ConvNeXt_Large_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 3),
        CNBlockConfig(768, 1536, 27),
        CNBlockConfig(1536, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)