You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

54 lines
2.2 KiB

5 months ago
import inspect
import torch
def skip_init(module_cls, *args, **kwargs):
    r"""Instantiate a module without initializing its parameters / buffers.

    Given a module class object and args / kwargs, construct the module on the
    ``meta`` device and then move it to the requested device with
    :meth:`torch.nn.Module.to_empty`. This can be useful if initialization is
    slow or if custom initialization will be performed, making the default
    initialization unnecessary. There are some caveats to this, due to the way
    this function is implemented:

    1. The module must accept a `device` arg in its constructor that is passed
       to any parameters or buffers created during construction.

    2. The module must not perform any computation on parameters in its
       constructor except initialization (i.e. functions from
       :mod:`torch.nn.init`).

    If these conditions are satisfied, the module can be instantiated with
    parameter / buffer values uninitialized, as if having been created using
    :func:`torch.empty`.

    Args:
        module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
        args: args to pass to the module's constructor
        kwargs: kwargs to pass to the module's constructor. A ``device`` entry,
            if present, selects where the final module lives; when it is
            omitted or ``None`` the module is placed on ``'cpu'``.

    Returns:
        Instantiated module with uninitialized parameters / buffers

    Raises:
        RuntimeError: if ``module_cls`` is not a subclass of
            :class:`torch.nn.Module` (including when it is not a class at
            all), or if its constructor has no ``device`` parameter.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> import torch
        >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
        >>> m.weight
        Parameter containing:
        tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
               requires_grad=True)
        >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
        >>> m2.weight
        Parameter containing:
        tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24,
                  4.5915e-41]], requires_grad=True)

    """
    # issubclass() raises TypeError for non-class arguments (for example an
    # already-constructed module instance), so check isclass() first to make
    # every invalid input produce the same descriptive RuntimeError.
    if not (inspect.isclass(module_cls) and issubclass(module_cls, torch.nn.Module)):
        raise RuntimeError(f'Expected a Module; got {module_cls}')
    if 'device' not in inspect.signature(module_cls).parameters:
        raise RuntimeError("Module must support a 'device' arg to skip initialization")

    # Treat an explicit device=None like an omitted device; forwarding None to
    # to_empty() would not move the tensors off the meta device.
    final_device = kwargs.pop('device', None)
    if final_device is None:
        final_device = 'cpu'

    # Build on the meta device so no real storage is allocated or initialized,
    # then materialize uninitialized storage on the requested device.
    kwargs['device'] = 'meta'
    return module_cls(*args, **kwargs).to_empty(device=final_device)