You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
607 lines
22 KiB
607 lines
22 KiB
5 months ago
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
#
|
||
|
# This source code is licensed under the BSD license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
import contextlib
|
||
|
import copy
|
||
|
from abc import ABC, abstractmethod
|
||
|
from typing import (
|
||
|
Any,
|
||
|
Callable,
|
||
|
cast,
|
||
|
Dict,
|
||
|
Generator,
|
||
|
Iterable,
|
||
|
Optional,
|
||
|
Sequence,
|
||
|
Set,
|
||
|
Tuple,
|
||
|
Type,
|
||
|
Union,
|
||
|
)
|
||
|
|
||
|
import torch.nn as nn
|
||
|
|
||
|
# Names exported by ``from <this module> import *``; every entry is defined
# further down in this file.
__all__ = [
    "always_wrap_policy",
    "lambda_auto_wrap_policy",
    "transformer_auto_wrap_policy",
    "size_based_auto_wrap_policy",
    "enable_wrap",
    "wrap",
    "CustomPolicy",
    "ModuleWrapPolicy",
]
|
||
|
|
||
|
|
||
|
# NOTE: We intentionally keep this function simple and isolate the complexity
|
||
|
# to `fn` to enable using this function generically. We may move this to a
|
||
|
# non-FSDP-specific folder and/or make it public in the future.
|
||
|
def _post_order_apply(
|
||
|
root_module: nn.Module,
|
||
|
fn: Callable[[nn.Module], Optional[nn.Module]],
|
||
|
):
|
||
|
"""
|
||
|
This applies ``fn`` to every module in the module tree of ``root_module``
|
||
|
following a post-order traversal. If ``fn`` returns an :class:`nn.Module`,
|
||
|
then this replaces the original module with the newly returned one in the
|
||
|
tree. Otherwise, ``fn`` should return ``None``, in which case the module is
|
||
|
not changed.
|
||
|
"""
|
||
|
# Track visited modules to avoid visiting shared modules multiple times
|
||
|
visited_modules: Set[nn.Module] = {root_module}
|
||
|
|
||
|
def _post_order_apply_inner(
|
||
|
module: nn.Module,
|
||
|
module_name: str,
|
||
|
parent_module: Optional[nn.Module],
|
||
|
):
|
||
|
for child_module_name, child_module in module.named_children():
|
||
|
if child_module not in visited_modules:
|
||
|
visited_modules.add(child_module)
|
||
|
_post_order_apply_inner(child_module, child_module_name, module)
|
||
|
optional_module = fn(module)
|
||
|
if optional_module is not None:
|
||
|
assert isinstance(parent_module, nn.Module), (
|
||
|
"Non-root modules should have their parent module set but got "
|
||
|
f"{parent_module} for {module}"
|
||
|
)
|
||
|
assert module_name, (
|
||
|
"Non-root modules should have their module name set but got "
|
||
|
f"an empty module name for {module}"
|
||
|
)
|
||
|
assert isinstance(
|
||
|
optional_module, nn.Module
|
||
|
), f"fn should return None or an nn.Module but got {optional_module}"
|
||
|
setattr(parent_module, module_name, optional_module)
|
||
|
|
||
|
_post_order_apply_inner(root_module, "", None)
|
||
|
|
||
|
|
||
|
def _construct_wrap_fn(
|
||
|
root_module: nn.Module,
|
||
|
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
|
||
|
fsdp_fn: Callable,
|
||
|
) -> Callable[[nn.Module], Optional[nn.Module]]:
|
||
|
"""
|
||
|
This constructs the "wrap" function to pass to :func:`_post_order_apply`
|
||
|
based on ``target_module_to_kwargs``, which should be constructed from the
|
||
|
wrapping policy.
|
||
|
"""
|
||
|
|
||
|
def fn(module: nn.Module) -> Optional[nn.Module]:
|
||
|
# Explicitly avoid wrapping the root module since for FSDP, it is
|
||
|
# handled by the caller
|
||
|
if module in target_module_to_kwargs and module is not root_module:
|
||
|
kwargs = target_module_to_kwargs[module]
|
||
|
return fsdp_fn(module, **kwargs)
|
||
|
return None
|
||
|
|
||
|
return fn
|
||
|
|
||
|
|
||
|
def _run_mixed_precision_override_policy(
|
||
|
root_module: nn.Module,
|
||
|
module_classes: Iterable[Type[nn.Module]],
|
||
|
ignored_modules: Set[nn.Module],
|
||
|
root_kwargs: Dict[str, Any],
|
||
|
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
|
||
|
):
|
||
|
module_classes_tuple = tuple(set(module_classes))
|
||
|
for module in root_module.modules():
|
||
|
if module in ignored_modules:
|
||
|
continue
|
||
|
elif isinstance(module, module_classes_tuple):
|
||
|
# This policy overrides any existing policy
|
||
|
if module not in target_module_to_kwargs:
|
||
|
# Only inherit from the root kwargs if not already specified
|
||
|
target_module_to_kwargs[module] = root_kwargs
|
||
|
target_module_to_kwargs[module]["mixed_precision"] = None
|
||
|
return target_module_to_kwargs
|
||
|
|
||
|
|
||
|
def always_wrap_policy(*args, **kwargs) -> bool:
    """
    Wrap policy that accepts any arguments and unconditionally answers
    ``True``.

    Used as the policy for :func:`_recursive_wrap`, this means both "keep
    recursing" and "wrap this module" are always answered affirmatively, so
    every submodule ends up wrapped by the wrapper class.
    """
    return True
|
||
|
|
||
|
|
||
|
class _Policy(ABC):
|
||
|
"""
|
||
|
This defines an abstract base class that represents a policy for applying
|
||
|
a module-level API.
|
||
|
"""
|
||
|
|
||
|
@abstractmethod
|
||
|
def _run_policy(
|
||
|
self,
|
||
|
root_module: nn.Module,
|
||
|
ignored_modules: Set[nn.Module],
|
||
|
root_kwargs: Dict[str, Any],
|
||
|
) -> Dict[nn.Module, Dict[str, Any]]:
|
||
|
"""
|
||
|
This should return a dict ``target_module_to_kwargs`` that maps from
|
||
|
each target module to wrap to its kwargs.
|
||
|
"""
|
||
|
...
|
||
|
|
||
|
|
||
|
def _module_wrap_policy(
|
||
|
module: nn.Module,
|
||
|
recurse: bool,
|
||
|
nonwrapped_numel: int,
|
||
|
module_classes: Set[Type[nn.Module]],
|
||
|
) -> bool:
|
||
|
"""
|
||
|
This auto wrap policy wraps every module that is an instance of any type in
|
||
|
``module_classes`` as its own FSDP instance. The root module given by
|
||
|
``module`` is always wrapped as an FSDP instance regardless. Since the
|
||
|
wrapping proceeds bottom up, each FSDP instance manages the parameters in
|
||
|
its subtree excluding any already managed by a child FSDP instance.
|
||
|
|
||
|
Args:
|
||
|
module (nn.Module): Current module being considered.
|
||
|
recurse (bool): If ``False``, then this function must decide whether
|
||
|
``module`` should be wrapped as an FSDP instance or not. If
|
||
|
``True``, then the function is still recursing down the module
|
||
|
tree as a part of the DFS.
|
||
|
nonwrapped_numel (int): Parameter numel not yet wrapped.
|
||
|
module_classes (Set[Type[nn.Module]]): Set of module classes that are
|
||
|
wrapped as FSDP instances.
|
||
|
|
||
|
Returns:
|
||
|
``True`` if ``recurse=True``, and whether ``module`` should be wrapped
|
||
|
if ``recurse=False``.
|
||
|
"""
|
||
|
if recurse:
|
||
|
return True # always recurse
|
||
|
return isinstance(module, tuple(module_classes))
|
||
|
|
||
|
|
||
|
class ModuleWrapPolicy(_Policy):
    """
    Policy that targets every module that is an instance of one of the given
    module classes and hands each of them a copy of the root's kwargs.
    """

    def __init__(self, module_classes: Iterable[Type[nn.Module]]):
        # De-duplicate once up front; the string form is cached for __repr__
        self._module_classes = set(module_classes)
        self._module_classes_str = str(self._module_classes)

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """
        Map each non-ignored module under ``root_module`` whose type matches
        the configured classes to its own shallow copy of ``root_kwargs``.
        """
        wrap_types = tuple(self._module_classes)
        # Each target gets a separate shallow copy so that later changes to
        # one module's kwargs do not leak into another's
        return {
            submodule: copy.copy(root_kwargs)
            for submodule in root_module.modules()
            if submodule not in ignored_modules and isinstance(submodule, wrap_types)
        }

    def __call__(self, module, recurse, *args, **kwargs):
        # Callable form of the policy; any extra arguments are dropped and a
        # placeholder numel is forwarded because nonwrapped_numel is not used.
        return _module_wrap_policy(
            module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes
        )

    def __repr__(self) -> str:
        return f"{super().__repr__()}({self._module_classes_str})"
|
||
|
|
||
|
|
||
|
class CustomPolicy(_Policy):
    """
    Policy driven by a user function that is called on each ``nn.Module`` and
    answers with ``False``, ``True``, or a kwarg dictionary.

    - ``False`` or an empty dictionary: the API is not applied to the module.
    - ``True``: the API is applied to the module with the root's kwargs.
    - A non-empty dictionary: the API is applied to the module with the
      root's kwargs overridden by the dictionary's entries.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = init_transformer_model(...)
        >>> def lambda_fn(module: nn.Module):
        >>>     if module is model.lm_head:
        >>>         return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
        >>>     elif isinstance(module, TransformerBlock):
        >>>         return True
        >>>     return False
        >>> policy = CustomPolicy(lambda_fn)
        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
    """

    def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
        self._lambda_fn = lambda_fn

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """
        Evaluate the user function on every non-ignored module under
        ``root_module`` and collect the kwargs for the selected ones.

        Raises:
            ValueError: If the user function returns something that is neither
                a ``bool`` nor a ``dict``.
        """
        selected: Dict[nn.Module, Dict[str, Any]] = {}
        for submodule in root_module.modules():
            if submodule in ignored_modules:
                continue
            decision = self._lambda_fn(submodule)
            if isinstance(decision, dict):
                overrides = decision
            elif isinstance(decision, bool):
                overrides = {}
            else:
                raise ValueError(
                    "The lambda_fn passed to CustomPolicy should return "
                    f"False/True or a kwarg dict, but it returned {decision}"
                )
            # ``False`` and ``{}`` both mean "do not apply to this module"
            if not decision:
                continue
            # Start from a shallow copy of the root kwargs and let the entries
            # returned by the lambda function take precedence
            merged = copy.copy(root_kwargs)
            merged.update(overrides)
            selected[submodule] = merged
        return selected
|
||
|
|
||
|
|
||
|
def lambda_auto_wrap_policy(
    module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
) -> bool:
    """
    Auto wrap policy that defers the wrapping decision to an arbitrary user
    function: if ``lambda_fn(submodule)`` returns ``True``, the submodule is
    wrapped as a ``wrapper_cls`` unit.

    The first three parameters are required by :func:`_recursive_wrap`.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``True``, the caller is asking whether to keep
            descending into ``module``'s children; if ``False``, the caller
            is asking whether ``module`` itself should be wrapped.
        nonwrapped_numel (int): Parameter numel not yet wrapped. Not used by
            this policy.
        lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``,
            then this module will be wrapped.

    Returns:
        ``True`` whenever ``recurse`` is truthy (this policy always recurses);
        otherwise the result of ``lambda_fn(module)``.
    """
    return True if recurse else lambda_fn(module)
|
||
|
|
||
|
|
||
|
def transformer_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    transformer_layer_cls: Set[Type[nn.Module]],
) -> bool:
    """
    Thin alias of :func:`_module_wrap_policy` in which
    ``transformer_layer_cls`` plays the role of ``module_classes``; see that
    function for the meaning of the arguments and the return value.

    Note that shared parameters must be wrapped in the same FSDP instance, so
    this auto wrap policy can help wrap shared embeddings into the same FSDP
    instance for transformer models.
    """
    return _module_wrap_policy(
        module=module,
        recurse=recurse,
        nonwrapped_numel=nonwrapped_numel,
        module_classes=transformer_layer_cls,
    )
|
||
|
|
||
|
|
||
|
def _wrap_module_cls_individually(
|
||
|
module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs
|
||
|
):
|
||
|
if recurse:
|
||
|
# always recurse
|
||
|
return True
|
||
|
else:
|
||
|
# if not recursing, decide whether we should wrap based on whether the type of module
|
||
|
# is in `module_classes`.
|
||
|
return isinstance(module, tuple(module_classes))
|
||
|
|
||
|
|
||
|
def _or_policy(
|
||
|
module: nn.Module,
|
||
|
recurse: bool,
|
||
|
nonwrapped_numel: int,
|
||
|
policies,
|
||
|
) -> bool:
|
||
|
"""
|
||
|
A policy that wraps ``module`` if any policy in the passed in iterable of
|
||
|
``policies`` returns ``True``.
|
||
|
"""
|
||
|
return any(
|
||
|
policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
|
||
|
for policy in policies
|
||
|
)
|
||
|
|
||
|
|
||
|
def size_based_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    # Additional custom arguments
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
) -> bool:
    """
    Auto wrap policy based on the amount of not-yet-wrapped parameters.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``True``, the caller is asking whether to keep
            descending into ``module``'s children; if ``False``, the caller
            is asking whether ``module`` itself should be wrapped.
        nonwrapped_numel (int): Parameter numel not yet wrapped.
        min_num_params (int): Size threshold, in units of numel, that
            ``nonwrapped_numel`` has to reach for the answer to be ``True``.
        force_leaf_modules (Set[Type[nn.Module]]): Module types that are kept
            as leaves, i.e. never recursed into. Defaults to
            ``size_based_auto_wrap_policy.FORCE_LEAF_MODULES`` when ``None``.
        exclude_wrap_modules (Set[Type[nn.Module]]): Module types that are
            never wrapped themselves. Defaults to
            ``size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES`` when ``None``.

    Returns:
        With ``recurse=True``: whether to recurse into ``module``. With
        ``recurse=False``: whether ``module`` should be wrapped.
    """
    leaf_types = force_leaf_modules
    if leaf_types is None:
        leaf_types = size_based_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore[attr-defined]
    excluded_types = exclude_wrap_modules
    if excluded_types is None:
        excluded_types = size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore[attr-defined]

    # The name `min_num_params` is kept for BC, but the value is compared
    # against the non-wrapped *numel*
    meets_threshold = nonwrapped_numel >= min_num_params

    # Recursing is vetoed by the forced-leaf types, wrapping by the excluded
    # types; in both cases the size threshold must be met as well
    vetoing_types = leaf_types if recurse else excluded_types
    return meets_threshold and not isinstance(module, tuple(vetoing_types))


# Default type sets, stored as attributes on the policy function itself so
# that they can be reached (and imported) together with it.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore[attr-defined]
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore[attr-defined]
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
def enable_wrap(
    *, wrapper_cls: Any, **wrapper_kwargs: Any
) -> Generator[None, None, None]:
    """
    Context manager inside which :func:`wrap` actually wraps modules.

    This is handy when the same configuration arguments should apply to all
    child modules that you wrap. A particularly important use case is wrapping
    large layers so that they get sharded (in-place) during initialization, to
    avoid running out of system memory. Large layers can indicate that they
    should be sharded via the ``wrap`` annotation and this context manager can
    provide the exact configuration for these nested instances.

    Usage::

        with enable_wrap(wrapper_cls, **params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        wrapper_cls:
            Class that `wrap` annotation will `wrap` modules with, such as
            `FullyShardedDataParallel`.
        **wrapper_kwargs:
            Configuration settings that will be passed to all ``wrap``
            instances inside the context
    """
    # The wrapper class travels together with the remaining settings as one
    # keyword set; the config context manager separates them again on enter
    with _ConfigAutoWrap(wrapper_cls=wrapper_cls, **wrapper_kwargs):
        yield
|
||
|
|
||
|
|
||
|
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Mark a module as "to be wrapped". The wrapping only happens while an
    :func:`enable_wrap` context manager is active; outside of one, ``module``
    is returned untouched. This allows a module to be initialized both with
    and without a wrapper without code change.

    The wrapper is the ``wrapper_cls`` that was passed to ``enable_wrap``.
    Both ``enable_wrap`` and ``wrap`` accept kwargs for constructing the
    ``wrapper_cls`` instance; when the same kwarg is given to both, the value
    passed to ``wrap`` wins.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: configuration overrides that will take priority over
            the values provided by the :func:`enable_wrap` context
    """
    if not _ConfigAutoWrap.in_autowrap_context:
        return module

    assert _ConfigAutoWrap.wrapper_cls is not None

    # Context-wide settings first, so that per-call overrides replace them
    effective_kwargs = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
    return _wrap(module, _ConfigAutoWrap.wrapper_cls, **effective_kwargs)
|
||
|
|
||
|
|
||
|
def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
|
||
|
assert wrapper_cls is not None
|
||
|
if hasattr(module, "_wrap_overrides"):
|
||
|
# If module has a _wrap_overrides attribute, we force overriding the
|
||
|
# FSDP config with these attributes for this module. Currently this
|
||
|
# is only used to disable mixed precision for BatchNorm when
|
||
|
# auto_wrapping.
|
||
|
overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type]
|
||
|
return wrapper_cls(module, **overrides)
|
||
|
|
||
|
return wrapper_cls(module, **kwargs)
|
||
|
|
||
|
|
||
|
def _recursive_wrap(
|
||
|
module: nn.Module,
|
||
|
auto_wrap_policy: Callable,
|
||
|
wrapper_cls: Callable,
|
||
|
ignored_modules: Set[nn.Module],
|
||
|
ignored_params: Set[nn.Parameter],
|
||
|
only_wrap_children: bool = False,
|
||
|
**kwargs: Any,
|
||
|
) -> Tuple[nn.Module, int]:
|
||
|
"""
|
||
|
Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
|
||
|
``True`` with ``wrapper_cls``.
|
||
|
|
||
|
Args:
|
||
|
module (nn.Module): Module to recursively wrap.
|
||
|
auto_wrap_policy (Callable): A callable representing a policy that
|
||
|
determines which modules to recursively wrap with ``wrapper_cls``.
|
||
|
ignored_modules (Set[torch.nn.Module]): Modules to ignore when
|
||
|
wrapping.
|
||
|
ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
|
||
|
wrapping; these should be the parameters contained in the modules
|
||
|
in ``ignored_modules``.
|
||
|
Returns:
|
||
|
(nn.Module, int):
|
||
|
``module`` after wrapping and the numel recursively wrapped.
|
||
|
"""
|
||
|
assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
|
||
|
assert wrapper_cls is not None, "Must specify wrapper_cls"
|
||
|
# Make sure no child is already wrapped.
|
||
|
for _, child in module.named_modules():
|
||
|
if child in ignored_modules:
|
||
|
continue
|
||
|
try:
|
||
|
assert not isinstance(child, cast(type, wrapper_cls))
|
||
|
except TypeError:
|
||
|
# wrapper_cls is a function as opposed to a class type, just bypass above check.
|
||
|
pass
|
||
|
|
||
|
# We count all params, assuming none of them are already wrapped.
|
||
|
nonwrapped_numel = sum(
|
||
|
p.numel() for p in module.parameters() if p not in ignored_params
|
||
|
)
|
||
|
|
||
|
assert auto_wrap_policy is not None
|
||
|
if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
|
||
|
total_wrapped_numel = 0
|
||
|
# Iterate through the children, recursively wrap if necessary
|
||
|
for name, child in module.named_children():
|
||
|
if child in ignored_modules:
|
||
|
continue
|
||
|
wrapped_child, num_wrapped_params = _recursive_wrap(
|
||
|
module=child,
|
||
|
auto_wrap_policy=auto_wrap_policy,
|
||
|
wrapper_cls=wrapper_cls,
|
||
|
ignored_modules=ignored_modules,
|
||
|
ignored_params=ignored_params,
|
||
|
**kwargs,
|
||
|
)
|
||
|
setattr(module, name, wrapped_child)
|
||
|
# Keep track of how many parameters have been wrapped
|
||
|
total_wrapped_numel += num_wrapped_params
|
||
|
# decide if we need to wrap the current module,
|
||
|
# since the left over parameters exceed the number of params to wrap
|
||
|
remainder = nonwrapped_numel - total_wrapped_numel
|
||
|
if not only_wrap_children and auto_wrap_policy(
|
||
|
module=module, recurse=False, nonwrapped_numel=remainder
|
||
|
):
|
||
|
# Leaf node or final wrapping of the remainder both happen here.
|
||
|
return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
|
||
|
else:
|
||
|
return module, total_wrapped_numel
|
||
|
return module, 0
|
||
|
|
||
|
|
||
|
class _ConfigAutoWrap:
|
||
|
"""
|
||
|
Helper class to wrap modules based on default config args via a context manager.
|
||
|
See :func:`enable_wrap` for more information.
|
||
|
"""
|
||
|
|
||
|
in_autowrap_context: bool = False # Context flag
|
||
|
wrapper_cls: Optional[Callable] = None # The wrapper class
|
||
|
kwargs: Dict[str, Any] = {} # Wrapper's args
|
||
|
|
||
|
def __init__(self, **kwargs: Dict[str, Any]):
|
||
|
self.kwargs = kwargs
|
||
|
|
||
|
@staticmethod
|
||
|
def enable_autowrap_context(kwargs: Any) -> None:
|
||
|
if _ConfigAutoWrap.in_autowrap_context:
|
||
|
raise NotImplementedError(
|
||
|
"You are already within an autowrap context and we currently do not supported nested autowrap."
|
||
|
)
|
||
|
_ConfigAutoWrap.in_autowrap_context = True
|
||
|
# Get and save the wrapper cls for the context.
|
||
|
assert (
|
||
|
"wrapper_cls" in kwargs.keys()
|
||
|
), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
|
||
|
_ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
|
||
|
del kwargs["wrapper_cls"]
|
||
|
# Save the rest.
|
||
|
_ConfigAutoWrap.kwargs = kwargs
|
||
|
|
||
|
@staticmethod
|
||
|
def disable_autowrap_context() -> None:
|
||
|
_ConfigAutoWrap.in_autowrap_context = False
|
||
|
_ConfigAutoWrap.wrapper_cls = None
|
||
|
_ConfigAutoWrap.kwargs = {}
|
||
|
|
||
|
def __enter__(self) -> None:
|
||
|
self.enable_autowrap_context(self.kwargs)
|
||
|
|
||
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||
|
self.disable_autowrap_context()
|