You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

92 lines
2.7 KiB

5 months ago
from typing import Optional
import torch
from torch.overrides import TorchFunctionMode
from torch.utils._contextlib import context_decorator
import functools
# Module-level record of the device installed by DeviceContext.__enter__;
# DeviceContext.__exit__ restores whatever value was there before.
# None until a DeviceContext has been entered through its Python __enter__.
CURRENT_DEVICE: Optional[torch.device] = None
@functools.lru_cache(1)
def _device_constructors():
return {
# standard ones
torch.empty,
torch.empty_permuted,
torch.empty_strided,
torch.empty_quantized,
torch.ones,
torch.arange,
torch.bartlett_window,
torch.blackman_window,
torch.eye,
torch.fft.fftfreq,
torch.fft.rfftfreq,
torch.full,
torch.fill,
torch.hamming_window,
torch.hann_window,
torch.kaiser_window,
torch.linspace,
torch.logspace,
torch.nested.nested_tensor,
# This function doesn't actually take a device argument
# torch.normal,
torch.ones,
torch.rand,
torch.randn,
torch.randint,
torch.randperm,
torch.range,
torch.sparse_coo_tensor,
torch.sparse_compressed_tensor,
torch.sparse_csr_tensor,
torch.sparse_csc_tensor,
torch.sparse_bsr_tensor,
torch.sparse_bsc_tensor,
torch.tril_indices,
torch.triu_indices,
torch.vander,
torch.zeros,
torch.asarray,
# weird ones
torch.tensor,
torch.as_tensor,
torch.scalar_tensor,
torch.asarray,
}
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
    """Torch function mode that makes ``device`` the default for factory calls.

    While the mode is active, any call to a function in
    ``_device_constructors()`` that omits ``device`` (or passes
    ``device=None``) is redirected to ``self.device``. Entering through
    ``__enter__`` also records ``self.device`` in the module-level
    ``CURRENT_DEVICE``; the matching ``__exit__`` restores the previous value.
    """

    def __init__(self, device):
        super().__init__()
        self.device = torch.device(device)
        # One saved CURRENT_DEVICE per active __enter__. A stack (rather than a
        # single attribute) lets the same instance be entered re-entrantly,
        # e.g. nested ``with ctx:``, with each __exit__ restoring the value
        # that was current at its matching __enter__.
        self._saved_devices = []

    def __enter__(self):
        global CURRENT_DEVICE
        # Push the mode first so a failure here leaves CURRENT_DEVICE untouched.
        result = super().__enter__()
        self._saved_devices.append(CURRENT_DEVICE)
        CURRENT_DEVICE = self.device
        return result

    def __exit__(self, exc_type, exc_val, exc_tb):
        global CURRENT_DEVICE
        CURRENT_DEVICE = self._saved_devices.pop()
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Copy so the kwargs mapping handed to us is never mutated in place.
        kwargs = dict(kwargs) if kwargs else {}
        if func in _device_constructors() and kwargs.get("device") is None:
            kwargs["device"] = self.device
        return func(*args, **kwargs)
# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
    """Wrap ``func`` so every call runs inside ``with device:``.

    ``context_decorator`` receives a zero-argument factory that returns
    ``device``. The factory is invoked afresh for each call of the wrapper.
    """

    def _make_context():
        return device

    return context_decorator(_make_context, func)
def set_device(device):
    """
    Set the default device inside of the wrapped function by decorating it with this function.

    If you would like to use this as a context manager, use device as a
    context manager directly, e.g., ``with torch.device(device)``.

    ``device`` is converted with ``torch.device`` only when the returned
    decorator is applied to a function, not when ``set_device`` is called.
    """

    def _decorate(func):
        return device_decorator(torch.device(device), func)

    return _decorate