You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

512 lines
19 KiB

from functools import partial
from typing import Any, cast, Dict, List, Optional, Union
import torch
import torch.nn as nn
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
__all__ = [
"VGG",
"VGG11_Weights",
"VGG11_BN_Weights",
"VGG13_Weights",
"VGG13_BN_Weights",
"VGG16_Weights",
"VGG16_BN_Weights",
"VGG19_Weights",
"VGG19_BN_Weights",
"vgg11",
"vgg11_bn",
"vgg13",
"vgg13_bn",
"vgg16",
"vgg16_bn",
"vgg19",
"vgg19_bn",
]
class VGG(nn.Module):
def __init__(
self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
) -> None:
super().__init__()
_log_api_usage_once(self)
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(p=dropout),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(p=dropout),
nn.Linear(4096, num_classes),
)
if init_weights:
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
layers: List[nn.Module] = []
in_channels = 3
for v in cfg:
if v == "M":
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
v = cast(int, v)
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
cfgs: Dict[str, List[Union[str, int]]] = {
"A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
"E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}
def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
if weights is not None:
kwargs["init_weights"] = False
if weights.meta["categories"] is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
_COMMON_META = {
"min_size": (32, 32),
"categories": _IMAGENET_CATEGORIES,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"_docs": """These weights were trained from scratch by using a simplified training recipe.""",
}
class VGG11_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11-8a719046.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 132863336,
"_metrics": {
"ImageNet-1K": {
"acc@1": 69.020,
"acc@5": 88.628,
}
},
"_ops": 7.609,
"_file_size": 506.84,
},
)
DEFAULT = IMAGENET1K_V1
class VGG11_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 132868840,
"_metrics": {
"ImageNet-1K": {
"acc@1": 70.370,
"acc@5": 89.810,
}
},
"_ops": 7.609,
"_file_size": 506.881,
},
)
DEFAULT = IMAGENET1K_V1
class VGG13_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13-19584684.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 133047848,
"_metrics": {
"ImageNet-1K": {
"acc@1": 69.928,
"acc@5": 89.246,
}
},
"_ops": 11.308,
"_file_size": 507.545,
},
)
DEFAULT = IMAGENET1K_V1
class VGG13_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 133053736,
"_metrics": {
"ImageNet-1K": {
"acc@1": 71.586,
"acc@5": 90.374,
}
},
"_ops": 11.308,
"_file_size": 507.59,
},
)
DEFAULT = IMAGENET1K_V1
class VGG16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16-397923af.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 138357544,
"_metrics": {
"ImageNet-1K": {
"acc@1": 71.592,
"acc@5": 90.382,
}
},
"_ops": 15.47,
"_file_size": 527.796,
},
)
IMAGENET1K_FEATURES = Weights(
# Weights ported from https://github.com/amdegroot/ssd.pytorch/
url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
transforms=partial(
ImageClassification,
crop_size=224,
mean=(0.48235, 0.45882, 0.40784),
std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
),
meta={
**_COMMON_META,
"num_params": 138357544,
"categories": None,
"recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
"_metrics": {
"ImageNet-1K": {
"acc@1": float("nan"),
"acc@5": float("nan"),
}
},
"_ops": 15.47,
"_file_size": 527.802,
"_docs": """
These weights can't be used for classification because they are missing values in the `classifier`
module. Only the `features` module has valid values and can be used for feature extraction. The weights
were trained using the original input standardization method as described in the paper.
""",
},
)
DEFAULT = IMAGENET1K_V1
class VGG16_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 138365992,
"_metrics": {
"ImageNet-1K": {
"acc@1": 73.360,
"acc@5": 91.516,
}
},
"_ops": 15.47,
"_file_size": 527.866,
},
)
DEFAULT = IMAGENET1K_V1
class VGG19_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 143667240,
"_metrics": {
"ImageNet-1K": {
"acc@1": 72.376,
"acc@5": 90.876,
}
},
"_ops": 19.632,
"_file_size": 548.051,
},
)
DEFAULT = IMAGENET1K_V1
class VGG19_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 143678248,
"_metrics": {
"ImageNet-1K": {
"acc@1": 74.218,
"acc@5": 91.842,
}
},
"_ops": 19.632,
"_file_size": 548.143,
},
)
DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (:class:`~torchvision.models.VGG11_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG11_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG11_Weights
:members:
"""
weights = VGG11_Weights.verify(weights)
return _vgg("A", False, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG11_BN_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG11_BN_Weights
:members:
"""
weights = VGG11_BN_Weights.verify(weights)
return _vgg("A", True, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (:class:`~torchvision.models.VGG13_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG13_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG13_Weights
:members:
"""
weights = VGG13_Weights.verify(weights)
return _vgg("B", False, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG13_BN_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG13_BN_Weights
:members:
"""
weights = VGG13_BN_Weights.verify(weights)
return _vgg("B", True, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (:class:`~torchvision.models.VGG16_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG16_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG16_Weights
:members:
"""
weights = VGG16_Weights.verify(weights)
return _vgg("D", False, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG16_BN_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG16_BN_Weights
:members:
"""
weights = VGG16_BN_Weights.verify(weights)
return _vgg("D", True, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (:class:`~torchvision.models.VGG19_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG19_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG19_Weights
:members:
"""
weights = VGG19_Weights.verify(weights)
return _vgg("E", False, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-19_BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG19_BN_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG19_BN_Weights
:members:
"""
weights = VGG19_BN_Weights.verify(weights)
return _vgg("E", True, weights, progress, **kwargs)