You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

397 lines
13 KiB

from typing import Any
import torch
from torch.utils._contextlib import (
_DecoratorContextManager,
_NoParamDecoratorContextManager,
F,
)
# Names exported by a star-import of this module; the underscore-prefixed
# helpers and classes defined further down are intentionally not listed.
__all__ = [
"no_grad",
"enable_grad",
"set_grad_enabled",
"inference_mode",
"set_multithreading_enabled",
]
class no_grad(_NoParamDecoratorContextManager):
r"""Context-manager that disables gradient calculation.
Disabling gradient calculation is useful for inference, when you are sure
that you will not call :meth:`Tensor.backward()`. It will reduce memory
consumption for computations that would otherwise have `requires_grad=True`.
In this mode, the result of every computation will have
`requires_grad=False`, even when the inputs have `requires_grad=True`.
There is an exception! All factory functions, or functions that create
a new Tensor and take a requires_grad kwarg, will NOT be affected by
this mode.
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator.
.. note::
No-grad is one of several mechanisms that can enable or
disable gradients locally see :ref:`locally-disable-grad-doc` for
more information on how they compare.
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
If you want to disable forward AD for a computation, you can unpack
your dual tensors.
Example::
>>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
>>> @torch.no_grad
... def tripler(x):
... return x * 3
>>> z = tripler(x)
>>> z.requires_grad
False
>>> # factory function exception
>>> with torch.no_grad():
... a = torch.nn.Parameter(torch.rand(10))
>>> a.requires_grad
True
"""
def __init__(self) -> None:
if not torch._jit_internal.is_scripting():
super().__init__()
self.prev = False
def __enter__(self) -> None:
self.prev = torch.is_grad_enabled()
torch.set_grad_enabled(False)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch.set_grad_enabled(self.prev)
class enable_grad(_NoParamDecoratorContextManager):
r"""Context-manager that enables gradient calculation.
Enables gradient calculation, if it has been disabled via :class:`~no_grad`
or :class:`~set_grad_enabled`.
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator.
.. note::
enable_grad is one of several mechanisms that can enable or
disable gradients locally see :ref:`locally-disable-grad-doc` for
more information on how they compare.
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
Example::
>>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... with torch.enable_grad():
... y = x * 2
>>> y.requires_grad
True
>>> y.backward()
>>> x.grad
tensor([2.])
>>> @torch.enable_grad()
... def doubler(x):
... return x * 2
>>> with torch.no_grad():
... z = doubler(x)
>>> z.requires_grad
True
>>> @torch.enable_grad
... def tripler(x):
... return x * 3
>>> with torch.no_grad():
... z = tripler(x)
>>> z.requires_grad
True
"""
def __enter__(self) -> None:
self.prev = torch.is_grad_enabled()
torch._C._set_grad_enabled(True)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_grad_enabled(self.prev)
class set_grad_enabled(_DecoratorContextManager):
    r"""Context-manager that switches gradient calculation on or off.

    ``set_grad_enabled`` enables or disables grads according to its
    argument :attr:`mode`. It can be used as a context-manager or as a
    function.

    The effect is thread local; computation in other threads is not
    affected.

    Args:
        mode (bool): Flag whether to enable grad (``True``), or disable
                     (``False``). This can be used to conditionally enable
                     gradients.

    .. note::
        set_grad_enabled is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False
    """

    def __init__(self, mode: bool) -> None:
        # The mode is applied right here in the constructor, which is what
        # makes the plain function-call form ``set_grad_enabled(flag)`` work.
        # The previous mode has to be captured before it is overwritten.
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(mode)
        self.mode = mode

    def __call__(self, orig_func: F) -> F:
        # Decorator use: the constructor has already switched the mode, so
        # undo that first; merely decorating a function must not leave the
        # grad mode changed.
        torch._C._set_grad_enabled(self.prev)
        return super().__call__(orig_func)

    def __enter__(self) -> None:
        torch._C._set_grad_enabled(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Go back to the mode recorded when this instance was constructed.
        torch._C._set_grad_enabled(self.prev)

    def clone(self) -> "set_grad_enabled":
        r"""
        Create a new instance of the same class configured with the same ``mode``
        """
        return type(self)(self.mode)
class inference_mode(_DecoratorContextManager):
r"""Context-manager that enables or disables inference mode.
InferenceMode is a new context manager analogous to :class:`~no_grad`
to be used when you are certain your operations will have no interactions
with autograd (e.g., model training). Code run under this mode gets better
performance by disabling view tracking and version counter bumps. Note that
unlike some other mechanisms that locally enable or disable grad,
entering inference_mode also disables to :ref:`forward-mode AD <forward-mode-ad>`.
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator.
.. note::
Inference mode is one of several mechanisms that can enable or
disable gradients locally see :ref:`locally-disable-grad-doc` for
more information on how they compare.
Args:
mode (bool or function): Either a boolean flag whether to enable or
disable inference mode or a Python function to decorate with
inference mode enabled
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> import torch
>>> x = torch.ones(1, 2, 3, requires_grad=True)
>>> with torch.inference_mode():
... y = x * x
>>> y.requires_grad
False
>>> # xdoctest: +SKIP("want string isnt quite right")
>>> y._version
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Inference tensors do not track version counter.
>>> @torch.inference_mode()
... def func(x):
... return x * x
>>> out = func(x)
>>> out.requires_grad
False
>>> @torch.inference_mode
... def doubler(x):
... return x * 2
>>> out = doubler(x)
>>> out.requires_grad
False
"""
def __init__(self, mode: bool = True) -> None:
if not torch._jit_internal.is_scripting():
super().__init__()
self.mode = mode
def __new__(cls, mode=True):
if isinstance(mode, bool):
return super().__new__(cls)
return cls()(mode)
def __enter__(self) -> None:
self._inference_mode_context = torch._C._InferenceMode(self.mode)
self._inference_mode_context.__enter__()
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
def clone(self) -> "inference_mode":
r"""
Create a copy of this class
"""
return self.__class__(self.mode)
def _enter_inference_mode(mode):
mode_context = torch._C._InferenceMode(mode)
mode_context.__enter__()
return mode_context
def _exit_inference_mode(mode):
mode.__exit__(None, None, None)
class set_multithreading_enabled(_DecoratorContextManager):
r"""Context-manager that sets multithreaded backwards on or off.
``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
This context manager is thread local; it will not affect computation
in other threads.
Args:
mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
(``False``).
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
"""
def __init__(self, mode: bool) -> None:
self.prev = torch._C._is_multithreading_enabled()
torch._C._set_multithreading_enabled(mode)
self.mode = mode
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_multithreading_enabled(self.prev)
def clone(self) -> "set_multithreading_enabled":
r"""
Create a copy of this class
"""
return self.__class__(self.mode)
class _force_original_view_tracking(_DecoratorContextManager):
r"""Context-manager that sets whether or not to always enable view-replay in autograd.
``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
This context manager is thread local; it will not affect computation
in other threads.
When a tensor view is mutated, the autograd engine needs to decide whether or not
to regenerate the "updated view" by either replaying the chain of views from the updated base,
or with a single call to as_strided.
If set_view_replay_enabled is set to True, then autograd will always use view replay.
Otherwise, it will fall back to its existing logic.
Args:
mode (bool): Flag whether to enable view-replay (``True``), or disable
(``False``).
"""
def __init__(self, mode: bool) -> None:
self.prev = torch._C._is_view_replay_enabled()
torch._C._set_view_replay_enabled(mode)
self.mode = mode
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_view_replay_enabled(self.prev)
def clone(self):
return self.__class__(self.mode)
class _unsafe_preserve_version_counter(_DecoratorContextManager):
    r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.

    This context manager can lead to arbitrary silent-correctness issues in any other part of your code
    (even the ones not touched directly by the context manager)!

    Ordinarily, autograd will track mutations to tensors by incrementing its `._version` attribute.
    This is generally important for correctness, as for example, mutating a tensor that autograd has saved
    for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
    and error out in this situation.

    However, there are rare instances where it might be useful to hide mutations from autograd. For example:
    if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate
    the tensor right before it is needed by autograd.

    Args:
        tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
    """

    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor
        # Snapshot of the version counter at construction time; this is the
        # value ``__exit__`` writes back.
        self.prev_version = tensor._version

    def __enter__(self) -> None:
        # Nothing to do on entry: the version was captured in ``__init__``.
        pass

    def __exit__(self, *args) -> None:
        # Reset the tensor's version counter to the recorded value.
        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)

    def clone(self) -> "_unsafe_preserve_version_counter":
        r"""
        Create a new context for the same tensor.

        The constructor requires ``tensor``, so the argument has to be
        forwarded explicitly; a no-argument ``self.__class__()`` would raise
        ``TypeError``. The new instance records the tensor's version counter
        as it is at the time of this call.
        """
        return self.__class__(self.tensor)