You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

177 lines
5.1 KiB

from typing import Iterable, List, Union
import torch
from .. import Tensor
from . import _lazy_call, _lazy_init, current_device, device_count
def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor:
    r"""Return the random number generator state of one XPU device as a ByteTensor.

    Args:
        device (torch.device, str or int, optional): The device whose RNG state
            is returned. A ``str`` is converted with ``torch.device(...)``; an
            ``int`` is treated as an XPU device index. Default: ``'xpu'``
            (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    if isinstance(device, int):
        device = torch.device("xpu", device)
    elif isinstance(device, str):
        device = torch.device(device)
    # Only the index is consulted (device.type is not checked); a device with
    # no explicit index, such as plain "xpu", resolves to the current device.
    index = current_device() if device.index is None else device.index
    return torch.xpu.default_generators[index].get_state()
def get_rng_state_all() -> List[Tensor]:
    r"""Return the random number generator states of all XPU devices.

    Returns:
        List[Tensor]: One ByteTensor per device, in device-index order
        (``0`` through ``device_count() - 1``), each obtained from
        :func:`get_rng_state`. The list is empty when ``device_count()``
        reports no devices.

    .. warning::
        Each per-device lookup goes through :func:`get_rng_state`, which
        eagerly initializes XPU.
    """
    return [get_rng_state(i) for i in range(device_count())]
def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the random number generator state of one XPU device.

    The state tensor is copied right away. Installing it on the generator is
    handed to ``_lazy_call`` rather than done inline.

    Args:
        new_state (torch.ByteTensor): The desired state.
        device (torch.device, str or int, optional): The device whose RNG state
            is set. A ``str`` is converted with ``torch.device(...)``; an
            ``int`` is treated as an XPU device index. Default: ``'xpu'``
            (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    # Take a contiguous private copy (with functorch disabled) so the deferred
    # callback installs exactly the state passed in, even if the caller later
    # mutates ``new_state``.
    with torch._C._DisableFuncTorch():
        state_snapshot = new_state.clone(memory_format=torch.contiguous_format)
    if isinstance(device, int):
        device = torch.device("xpu", device)
    elif isinstance(device, str):
        device = torch.device(device)

    def _install_state():
        # Only the index is consulted; no index means the current device.
        target = device.index
        if target is None:
            target = current_device()
        torch.xpu.default_generators[target].set_state(state_snapshot)

    _lazy_call(_install_state)
def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Set the random number generator state of every XPU device.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired states; the
            ``k``-th item is applied to XPU device index ``k`` via
            :func:`set_rng_state`.
    """
    for device_index, per_device_state in enumerate(new_states):
        set_rng_state(new_state=per_device_state, device=device_index)
def manual_seed(seed: int) -> None:
    r"""Seed the random number generator of the current XPU device.

    It's safe to call this function if XPU is not available; in that case,
    it is silently ignored.

    Args:
        seed (int): The desired seed (converted with ``int()`` up front).

    .. warning::
        Only the current device is seeded, which is not enough for determinism
        in a multi-device model. To seed all devices, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def _seed_current_device():
        torch.xpu.default_generators[current_device()].manual_seed(seed)

    # NOTE(review): the ``seed=True`` flag is forwarded to ``_lazy_call``;
    # presumably it marks this as a seeding request for deferred replay — confirm.
    _lazy_call(_seed_current_device, seed=True)
def manual_seed_all(seed: int) -> None:
    r"""Seed the random number generator of every XPU device with the same value.

    It's safe to call this function if XPU is not available; in that case,
    it is silently ignored.

    Args:
        seed (int): The desired seed (converted with ``int()`` up front).
    """
    seed = int(seed)

    def _seed_every_device():
        for device_index in range(device_count()):
            torch.xpu.default_generators[device_index].manual_seed(seed)

    # NOTE(review): ``seed_all=True`` is forwarded to ``_lazy_call``; presumably
    # it marks an all-device seeding request for deferred replay — confirm.
    _lazy_call(_seed_every_device, seed_all=True)
def seed() -> None:
    r"""Reseed the current XPU device's generator with a random number.

    It's safe to call this function if XPU is not available; in that case,
    it is silently ignored.

    .. warning::
        Only the current device is reseeded. To reseed every device,
        use :func:`seed_all`.
    """

    def _reseed_current_device():
        torch.xpu.default_generators[current_device()].seed()

    _lazy_call(_reseed_current_device)
def seed_all() -> None:
    r"""Reseed every XPU device with one shared random number.

    Device 0 draws a fresh random seed. That same seed is then applied to
    every other device, so all devices end up with identical seeds.
    It's safe to call this function if XPU is not available; in that case,
    it is silently ignored.
    """

    def _reseed_every_device():
        total = device_count()
        if total < 1:
            return
        first_generator = torch.xpu.default_generators[0]
        first_generator.seed()
        shared_seed = first_generator.initial_seed()
        for device_index in range(1, total):
            torch.xpu.default_generators[device_index].manual_seed(shared_seed)

    _lazy_call(_reseed_every_device)
def initial_seed() -> int:
    r"""Return the initial seed of the current XPU device's generator.

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    generator = torch.xpu.default_generators[current_device()]
    return generator.initial_seed()
# Public names exported by ``from <this module> import *``: the XPU RNG
# state getters/setters and seeding helpers defined in this module.
__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]