You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
2.2 KiB
54 lines
2.2 KiB
5 months ago
|
import inspect
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def skip_init(module_cls, *args, **kwargs):
    r"""
    Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers.

    This can be useful if initialization is slow or if custom initialization will
    be performed, making the default initialization unnecessary. There are some caveats to this, due to
    the way this function is implemented:

    1. The module must accept a `device` arg in its constructor that is passed to any parameters
    or buffers created during construction.

    2. The module must not perform any computation on parameters in its constructor except
    initialization (i.e. functions from :mod:`torch.nn.init`).

    If these conditions are satisfied, the module can be instantiated with parameter / buffer values
    uninitialized, as if having been created using :func:`torch.empty`.

    The module is first constructed on the ``meta`` device and then materialized with
    :meth:`torch.nn.Module.to_empty` on the requested device. The requested device may be
    given positionally or as a keyword. If it is omitted or ``None``, the module is
    materialized on the CPU.

    Args:
        module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
        args: args to pass to the module's constructor
        kwargs: kwargs to pass to the module's constructor

    Returns:
        Instantiated module with uninitialized parameters / buffers

    Raises:
        RuntimeError: if ``module_cls`` is not a :class:`torch.nn.Module` subclass, or if its
            constructor does not accept a ``device`` argument.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> import torch
        >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
        >>> m.weight
        Parameter containing:
        tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
               requires_grad=True)
        >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
        >>> m2.weight
        Parameter containing:
        tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24,
                  4.5915e-41]], requires_grad=True)

    """
    if not issubclass(module_cls, torch.nn.Module):
        raise RuntimeError(f'Expected a Module; got {module_cls}')
    signature = inspect.signature(module_cls)
    if 'device' not in signature.parameters:
        raise RuntimeError('Module must support a \'device\' arg to skip initialization')

    # Bind the caller's arguments to the constructor signature so a ``device``
    # passed positionally is found and replaced, just like a keyword one.
    # Forcing a separate ``device=`` keyword would otherwise collide with it.
    bound = signature.bind_partial(*args, **kwargs)
    final_device = bound.arguments.get('device')
    if final_device is None:
        # ``to_empty(device=None)`` keeps tensors on their current (meta)
        # device, so an explicit ``None`` must fall back like an omitted arg.
        final_device = 'cpu'

    # Construct on the meta device: no real storage is allocated, so no
    # initialization work is done. Then allocate uninitialized storage.
    bound.arguments['device'] = 'meta'
    return module_cls(*bound.args, **bound.kwargs).to_empty(device=final_device)