You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
415 lines
15 KiB
415 lines
15 KiB
5 months ago
|
from functools import partial
|
||
|
from typing import Any, Callable, List, Optional, Sequence
|
||
|
|
||
|
import torch
|
||
|
from torch import nn, Tensor
|
||
|
from torch.nn import functional as F
|
||
|
|
||
|
from ..ops.misc import Conv2dNormActivation, Permute
|
||
|
from ..ops.stochastic_depth import StochasticDepth
|
||
|
from ..transforms._presets import ImageClassification
|
||
|
from ..utils import _log_api_usage_once
|
||
|
from ._api import register_model, Weights, WeightsEnum
|
||
|
from ._meta import _IMAGENET_CATEGORIES
|
||
|
from ._utils import _ovewrite_named_param, handle_legacy_interface
|
||
|
|
||
|
|
||
|
# Public API of this module: the model class, one pretrained-weights enum per
# model size, and one builder function per model size.
__all__ = [
    "ConvNeXt",
    # Weights enums, smallest to largest variant.
    "ConvNeXt_Tiny_Weights",
    "ConvNeXt_Small_Weights",
    "ConvNeXt_Base_Weights",
    "ConvNeXt_Large_Weights",
    # Model builders, smallest to largest variant.
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "convnext_large",
]
|
||
|
|
||
|
|
||
|
class LayerNorm2d(nn.LayerNorm):
    """``nn.LayerNorm`` applied to a channels-first 4-D tensor.

    The input is permuted so the channel axis (dim 1) becomes the last axis,
    normalized with ``F.layer_norm`` using this module's ``normalized_shape``,
    affine ``weight``/``bias`` and ``eps``, and then permuted back to the
    original axis order.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Move dim 1 (channels) to the end, where layer_norm normalizes.
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        # Restore the original channels-first layout.
        return normed.permute(0, 3, 1, 2)
|
||
|
|
||
|
|
||
|
class CNBlock(nn.Module):
    """ConvNeXt residual block.

    The residual branch (``self.block``) is:

    - a 7x7 depthwise ``nn.Conv2d`` (``groups=dim``);
    - a permute to channels-last;
    - ``norm_layer(dim)``;
    - an MLP ``Linear(dim, 4*dim) -> GELU -> Linear(4*dim, dim)``;
    - a permute back to channels-first.

    The branch output is multiplied by the learnable ``layer_scale`` parameter
    (shape ``(dim, 1, 1)``), passed through ``StochasticDepth`` in ``"row"``
    mode, and added to the block input.

    Args:
        dim: channel count of the input and output.
        layer_scale: initial value of every element of ``self.layer_scale``.
        stochastic_depth_prob: probability forwarded to ``StochasticDepth``.
        norm_layer: factory called with ``dim``; defaults to
            ``nn.LayerNorm`` with ``eps=1e-6``.
    """

    def __init__(
        self,
        dim: int,
        layer_scale: float,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        make_norm = norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        hidden = 4 * dim

        self.block = nn.Sequential(
            # Depthwise spatial convolution: one 7x7 filter per channel.
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            # Channels-last, so the norm and the Linear layers operate on the channel axis.
            Permute([0, 2, 3, 1]),
            make_norm(dim),
            nn.Linear(in_features=dim, out_features=hidden, bias=True),
            nn.GELU(),
            nn.Linear(in_features=hidden, out_features=dim, bias=True),
            # Back to channels-first for the residual addition.
            Permute([0, 3, 1, 2]),
        )
        # (dim, 1, 1) so the scale broadcasts over the two trailing spatial axes.
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        branch = self.block(input)
        out = self.stochastic_depth(self.layer_scale * branch)
        # In-place residual add onto the freshly computed branch tensor.
        out += input
        return out
|
||
|
|
||
|
|
||
|
class CNBlockConfig:
    """Settings for one ConvNeXt stage (see Section 3 of the ConvNeXt paper).

    Args:
        input_channels: channel width used by every block of the stage.
        out_channels: channel width produced by the downsampling layer placed
            after the stage, or ``None`` when no downsampling layer follows.
        num_layers: number of blocks in the stage.
    """

    def __init__(self, input_channels: int, out_channels: Optional[int], num_layers: int) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        fields = "input_channels={input_channels}, out_channels={out_channels}, num_layers={num_layers}"
        return f"{self.__class__.__name__}({fields.format(**self.__dict__)})"
|
||
|
|
||
|
|
||
|
class ConvNeXt(nn.Module):
    """ConvNeXt image classifier.

    ``self.features`` holds, in order:

    - a stem: ``Conv2dNormActivation`` with 3 input channels, kernel 4,
      stride 4, no padding and no activation;
    - for each ``CNBlockConfig`` in ``block_setting``, an ``nn.Sequential``
      of ``num_layers`` blocks at ``input_channels`` width;
    - after each stage whose ``out_channels`` is not ``None``, a downsampling
      ``nn.Sequential(norm, Conv2d(kernel_size=2, stride=2))``.

    The head is ``nn.AdaptiveAvgPool2d(1)`` followed by
    ``self.classifier = norm -> Flatten -> Linear``.

    Args:
        block_setting: non-empty sequence of ``CNBlockConfig``, one per stage.
        stochastic_depth_prob: stochastic-depth probability reached by the
            last block. Block ``i`` (0-based, counted across all stages) gets
            ``stochastic_depth_prob * i / max(total_blocks - 1, 1)``.
        layer_scale: initial layer-scale value passed to every block.
        num_classes: output size of the final ``nn.Linear``.
        block: block factory called as ``block(channels, layer_scale, sd_prob)``;
            defaults to ``CNBlock``.
        norm_layer: factory called with a channel count; defaults to
            ``LayerNorm2d`` with ``eps=1e-6``.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``block_setting`` is empty (falsy).
        TypeError: if ``block_setting`` is not a sequence of ``CNBlockConfig``.
    """

    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (isinstance(block_setting, Sequence) and all(isinstance(s, CNBlockConfig) for s in block_setting)):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem: kernel == stride == 4 with no padding, i.e. non-overlapping 4x4 patches.
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        # Stochastic depth rises linearly from 0 (first block) to
        # stochastic_depth_prob (last block). Clamp the denominator to 1 so a
        # network with a single block (index 0 -> probability 0) does not
        # divide by zero; for >= 2 blocks the value is unchanged.
        sd_denominator = max(total_stage_blocks - 1.0, 1.0)
        stage_block_id = 0
        for cnf in block_setting:
            # Blocks of this stage, all at cnf.input_channels width.
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                sd_prob = stochastic_depth_prob * stage_block_id / sd_denominator
                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling: norm, then a 2x2 stride-2 conv to out_channels.
                layers.append(
                    nn.Sequential(
                        norm_layer(cnf.input_channels),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Width entering the head: the last downsampling output if present, else the last stage width.
        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
        )

        # Truncated-normal (std=0.02) weights and zero biases for every Conv2d and Linear.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run features, global average pooling and the classifier head."""
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
|
||
|
|
||
|
|
||
|
def _convnext(
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    """Construct a ``ConvNeXt`` and, if ``weights`` is given, load them into it.

    Without weights, the model is built directly from ``block_setting``,
    ``stochastic_depth_prob`` and ``kwargs``.

    With weights:

    - before construction, ``kwargs["num_classes"]`` is passed through
      ``_ovewrite_named_param`` with the number of categories in
      ``weights.meta["categories"]``;
    - after construction, the state dict returned by
      ``weights.get_state_dict(progress=progress, check_hash=True)`` is loaded
      with ``load_state_dict``.
    """
    if weights is None:
        return ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    # The classifier's output size must match the pretrained category count.
    _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
    model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return model
|
||
|
|
||
|
|
||
|
# Metadata shared by every ConvNeXt weights entry in this module. Each entry
# unpacks it with ``**_COMMON_META`` and then adds its own "num_params",
# "_metrics", "_ops" and "_file_size".
_COMMON_META = {
    # NOTE(review): presumably the smallest supported input (H, W) — confirm against the weights API.
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    # Link to the reference classification training scripts for ConvNeXt.
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
    "_docs": """
        These weights improve upon the results of the original paper by using a modified version of TorchVision's
        `new training recipe
        <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
    """,
}
|
||
|
|
||
|
|
||
|
class ConvNeXt_Tiny_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights for convnext_tiny. IMAGENET1K_V1 is also
    # the value convnext_tiny's legacy "pretrained" flag refers to.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
        # ImageClassification preset configured with crop_size=224, resize_size=236.
        transforms=partial(ImageClassification, crop_size=224, resize_size=236),
        meta={
            **_COMMON_META,
            "num_params": 28589128,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.520,
                    "acc@5": 96.146,
                }
            },
            # NOTE(review): units not shown here — presumably GFLOPs and MB; confirm.
            "_ops": 4.456,
            "_file_size": 109.119,
        },
    )
    # Alias for the recommended weights of this variant.
    DEFAULT = IMAGENET1K_V1
|
||
|
|
||
|
|
||
|
class ConvNeXt_Small_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights for convnext_small. IMAGENET1K_V1 is also
    # the value convnext_small's legacy "pretrained" flag refers to.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
        # ImageClassification preset configured with crop_size=224, resize_size=230.
        transforms=partial(ImageClassification, crop_size=224, resize_size=230),
        meta={
            **_COMMON_META,
            "num_params": 50223688,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.616,
                    "acc@5": 96.650,
                }
            },
            # NOTE(review): units not shown here — presumably GFLOPs and MB; confirm.
            "_ops": 8.684,
            "_file_size": 191.703,
        },
    )
    # Alias for the recommended weights of this variant.
    DEFAULT = IMAGENET1K_V1
|
||
|
|
||
|
|
||
|
class ConvNeXt_Base_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights for convnext_base. IMAGENET1K_V1 is also
    # the value convnext_base's legacy "pretrained" flag refers to.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
        # ImageClassification preset configured with crop_size=224, resize_size=232.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88591464,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.062,
                    "acc@5": 96.870,
                }
            },
            # NOTE(review): units not shown here — presumably GFLOPs and MB; confirm.
            "_ops": 15.355,
            "_file_size": 338.064,
        },
    )
    # Alias for the recommended weights of this variant.
    DEFAULT = IMAGENET1K_V1
|
||
|
|
||
|
|
||
|
class ConvNeXt_Large_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights for convnext_large. IMAGENET1K_V1 is also
    # the value convnext_large's legacy "pretrained" flag refers to.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
        # ImageClassification preset configured with crop_size=224, resize_size=232.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 197767336,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.414,
                    "acc@5": 96.976,
                }
            },
            # NOTE(review): units not shown here — presumably GFLOPs and MB; confirm.
            "_ops": 34.361,
            "_file_size": 754.537,
        },
    )
    # Alias for the recommended weights of this variant.
    DEFAULT = IMAGENET1K_V1
|
||
|
|
||
|
|
||
|
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Tiny model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Stages (input_channels, out_channels, num_layers):
    (96, 192, 3), (192, 384, 3), (384, 768, 9), (768, None, 3).

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class. ``stochastic_depth_prob`` defaults to 0.1 for this variant.

    .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
        :members:
    """
    weights = ConvNeXt_Tiny_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    # Variant-specific default; a caller-supplied value in kwargs takes precedence.
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
|
||
|
|
||
|
|
||
|
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Small model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Stages (input_channels, out_channels, num_layers):
    (96, 192, 3), (192, 384, 3), (384, 768, 27), (768, None, 3).

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class. ``stochastic_depth_prob`` defaults to 0.4 for this variant.

    .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
        :members:
    """
    weights = ConvNeXt_Small_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 27),
        CNBlockConfig(768, None, 3),
    ]
    # Variant-specific default; a caller-supplied value in kwargs takes precedence.
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
|
||
|
|
||
|
|
||
|
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Base model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Stages (input_channels, out_channels, num_layers):
    (128, 256, 3), (256, 512, 3), (512, 1024, 27), (1024, None, 3).

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class. ``stochastic_depth_prob`` defaults to 0.5 for this variant.

    .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
        :members:
    """
    weights = ConvNeXt_Base_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(128, 256, 3),
        CNBlockConfig(256, 512, 3),
        CNBlockConfig(512, 1024, 27),
        CNBlockConfig(1024, None, 3),
    ]
    # Variant-specific default; a caller-supplied value in kwargs takes precedence.
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
|
||
|
|
||
|
|
||
|
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Large model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Stages (input_channels, out_channels, num_layers):
    (192, 384, 3), (384, 768, 3), (768, 1536, 27), (1536, None, 3).

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class. ``stochastic_depth_prob`` defaults to 0.5 for this variant.

    .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
        :members:
    """
    weights = ConvNeXt_Large_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 3),
        CNBlockConfig(768, 1536, 27),
        CNBlockConfig(1536, None, 3),
    ]
    # Variant-specific default; a caller-supplied value in kwargs takes precedence.
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
|