You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
257 lines
11 KiB
257 lines
11 KiB
import functools
|
|
import inspect
|
|
import warnings
|
|
from collections import OrderedDict
|
|
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
|
|
|
|
from torch import nn
|
|
|
|
from .._utils import sequence_to_str
|
|
from ._api import WeightsEnum
|
|
|
|
|
|
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Children registered after the last requested layer are not kept, so the
    wrapper only runs the model up to (and including) that layer.

    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).

    Raises:
        ValueError: if a key of ``return_layers`` is not the name of a direct
            child of ``model``. The message lists the offending keys.

    Examples::

        >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        [('feat1', torch.Size([1, 64, 56, 56])),
         ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        child_names = {name for name, _ in model.named_children()}
        # Report every unknown key (in the caller's order) instead of a bare "not present".
        missing = [name for name in return_layers if name not in child_names]
        if missing:
            raise ValueError(f"return_layers are not present in model: {missing}")

        orig_return_layers = return_layers
        # String-keyed working copy: tracks which requested layers are still outstanding
        # without mutating the caller's dict.
        remaining = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            remaining.pop(name, None)
            if not remaining:
                # Every requested layer has been collected; later children are never run.
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        """Run the retained children in registration order and collect requested activations.

        Returns:
            OrderedDict mapping each requested output name (the values of ``return_layers``)
            to the output of the corresponding child, in registration order.
        """
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out
|
|
|
|
|
|
def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
|
|
"""
|
|
This function is taken from the original tf repo.
|
|
It ensures that all layers have a channel number that is divisible by 8
|
|
It can be seen here:
|
|
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
|
"""
|
|
if min_value is None:
|
|
min_value = divisor
|
|
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
|
# Make sure that round down does not go down by more than 10%.
|
|
if new_v < 0.9 * v:
|
|
new_v += divisor
|
|
return new_v
|
|
|
|
|
|
D = TypeVar("D")  # return type of the function decorated by `kwonly_to_pos_or_kw`


def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
    """Decorates a function that uses keyword only parameters to also allow them being passed as positionals.

    For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:

    .. code::

        def old_fn(foo, bar, baz=None):
            ...

        def new_fn(foo, *, bar, baz=None):
            ...

    Calling ``old_fn("foo", "bar", "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep
    BC and at the same time warn the user of the deprecation, this decorator can be used:

    .. code::

        @kwonly_to_pos_or_kw
        def new_fn(foo, *, bar, baz=None):
            ...

        new_fn("foo", "bar", "baz")

    Surplus positional arguments are mapped, in declaration order, onto the keyword-only parameters of ``fn``, and a
    deprecation warning naming them is emitted.

    Args:
        fn: Function with at least one keyword-only parameter.

    Returns:
        A wrapper carrying the metadata of ``fn`` (via :func:`functools.wraps`).

    Raises:
        TypeError: At decoration time if ``fn`` has no keyword-only parameter. At call time if more positional
            arguments are given than there are positional plus keyword-only parameters, or if a keyword-only
            parameter is given both positionally and by keyword (mirroring what Python raises for a plain call).
    """
    params = inspect.signature(fn).parameters

    try:
        keyword_only_start_idx = next(
            idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
        )
    except StopIteration:
        raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None

    # Only real keyword-only parameters may absorb surplus positionals; a trailing `**kwargs` must not.
    keyword_only_params = tuple(name for name, param in params.items() if param.kind == param.KEYWORD_ONLY)
    max_positional = keyword_only_start_idx + len(keyword_only_params)

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> D:
        args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
        if keyword_only_args:
            if len(keyword_only_args) > len(keyword_only_params):
                # zip() below would otherwise silently drop the extra arguments.
                raise TypeError(
                    f"{fn.__name__}() takes at most {max_positional} positional argument(s) "
                    f"but {len(args) + len(keyword_only_args)} were given"
                )
            keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
            duplicates = [name for name in keyword_only_kwargs if name in kwargs]
            if duplicates:
                # kwargs.update() below would otherwise silently overwrite the keyword value.
                raise TypeError(f"{fn.__name__}() got multiple values for argument '{duplicates[0]}'")
            warnings.warn(
                f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
                f"parameter(s) is deprecated since 0.13 and may be removed in the future. Please use keyword parameter(s) "
                f"instead."
            )
            kwargs.update(keyword_only_kwargs)

        return fn(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
W = TypeVar("W", bound=WeightsEnum)  # type variable restricted to WeightsEnum subclasses (weights values)
|
|
M = TypeVar("M", bound=nn.Module)  # type variable restricted to nn.Module subclasses (model types built by builders)
|
|
V = TypeVar("V")  # unconstrained type variable for arbitrary parameter values
|
|
|
|
|
|
def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
    """Make a model builder with the new ``weights`` interface also accept the pre-0.13 calling conventions.

    The returned decorator takes care of two things:

    1. Positional arguments are accepted again (through :func:`kwonly_to_pos_or_kw`), with a deprecation warning.
    2. A legacy ``pretrained``-style argument is translated into the new weights argument. A truthy value becomes
       the default weights and a falsy value becomes ``None``. A deprecation warning explaining the new interface
       is emitted.

    Args:
        **weights: Maps each new weights parameter name to a tuple ``(legacy_param_name, default)``. ``default`` is
            the weights value used when the legacy argument is truthy. It may also be a callable, which is then
            called with the keyword-argument dictionary and must return the weights. Only the legacy parameter name
            is guaranteed to be a key of that dictionary; any other parameter should be read with :meth:`dict.get`.

    Returns:
        A decorator to apply to the model builder.

    Raises:
        ValueError: Raised by the decorated builder when a truthy legacy argument is given but ``default`` does not
            resolve to a ``WeightsEnum`` member.
    """

    def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
        @kwonly_to_pos_or_kw
        @functools.wraps(builder)
        def inner_wrapper(*args: Any, **kwargs: Any) -> M:
            # `None` is itself a valid weights value, so a private marker is needed to detect "not passed".
            missing = object()

            for weights_param, (pretrained_param, default) in weights.items():  # type: ignore[union-attr]
                # Nothing to translate when neither the new nor the legacy name was supplied...
                if weights_param not in kwargs and pretrained_param not in kwargs:
                    continue

                weights_value = kwargs.get(weights_param, missing)
                # ...or when the weights argument already uses the new style. The string "legacy" is the only
                # string value that still goes through the translation below.
                if (
                    weights_value is None
                    or isinstance(weights_value, WeightsEnum)
                    or (isinstance(weights_value, str) and weights_value != "legacy")
                ):
                    continue

                # Any other value under the weights name is a legacy `pretrained` value. A positionally passed
                # `pretrained` lands there, because `kwonly_to_pos_or_kw` names positionals after the *current*
                # signature. Move it under the legacy name so a callable `default` can always read it from there.
                pretrained_positional = weights_value is not missing
                if pretrained_positional:
                    pretrained_value = kwargs.pop(weights_param)
                    kwargs[pretrained_param] = pretrained_value
                else:
                    pretrained_value = kwargs[pretrained_param]

                use_pretrained = bool(pretrained_value)
                if use_pretrained:
                    resolved_weights = default(kwargs) if callable(default) else default
                    if not isinstance(resolved_weights, WeightsEnum):
                        raise ValueError(f"No weights available for model {builder.__name__}")
                else:
                    resolved_weights = None

                # Only complain about the legacy parameter name when it was actually used by name.
                if not pretrained_positional:
                    warnings.warn(
                        f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
                        f"please use '{weights_param}' instead."
                    )

                msg = (
                    f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated since 0.13 and "
                    f"may be removed in the future. "
                    f"The current behavior is equivalent to passing `{weights_param}={resolved_weights}`."
                )
                if use_pretrained:
                    msg = (
                        f"{msg} You can also use `{weights_param}={type(resolved_weights).__name__}.DEFAULT` "
                        f"to get the most up-to-date weights."
                    )
                warnings.warn(msg)

                # Hand the builder only the new-style argument.
                del kwargs[pretrained_param]
                kwargs[weights_param] = resolved_weights

            return builder(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper
|
|
|
|
|
|
def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
    """Pin ``kwargs[param]`` to ``new_value``, modifying ``kwargs`` in place.

    An absent ``param`` is inserted with ``new_value``. A present ``param`` must already compare equal to it.

    Raises:
        ValueError: If ``kwargs`` already holds a different value for ``param``.
    """
    if param not in kwargs:
        kwargs[param] = new_value
    elif kwargs[param] != new_value:
        raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
|
|
|
|
|
|
def _ovewrite_value_param(param: str, actual: Optional[V], expected: V) -> V:
    """Return ``expected``, after checking that an explicitly given ``actual`` agrees with it.

    ``actual=None`` means "not specified" and is always accepted.

    Raises:
        ValueError: If ``actual`` is not ``None`` and differs from ``expected``.
    """
    if actual is not None and actual != expected:
        raise ValueError(f"The parameter '{param}' expected value {expected} but got {actual} instead.")
    return expected
|
|
|
|
|
|
class _ModelURLs(dict):
|
|
def __getitem__(self, item):
|
|
warnings.warn(
|
|
"Accessing the model URLs via the internal dictionary of the module is deprecated since 0.13 and may "
|
|
"be removed in the future. Please access them via the appropriate Weights Enum instead."
|
|
)
|
|
return super().__getitem__(item)
|