You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
129 lines
4.6 KiB
129 lines
4.6 KiB
from typing import Optional, Union
|
|
|
|
import torch
|
|
|
|
|
|
class _remote_device:
    """
    Represents a device on a remote worker.

    Args:
        remote_device (str or torch.device): Represents a device on a remote worker.
            The string format should be one of the following:

            1. "<workername>/<device>", where the device field can be parsed as torch.device type.
               E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
               In addition, the device field can be optional and the default value is "cpu".
            2. "rank:<rank>/<device>", where <rank> is the rank of the
               process and device can be parsed as torch.device type.
               E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
            3. <workername> and <rank> are optional and formats like "cpu"
               and "cuda:1", just represent local devices.

    Raises:
        TypeError: if ``remote_device`` is neither a ``str`` nor a ``torch.device``.
        ValueError: if the string has more than one "/", an empty worker name,
            or a malformed "rank:<rank>" prefix.

    Any exception raised by ``torch.device(...)`` for an invalid device field
    is propagated unchanged.
    """

    def __init__(self, remote_device: Union[str, torch.device]):
        PARSE_ERROR = (
            f"Could not parse remote_device: {remote_device}. The valid format is "
            "'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
        )
        self._worker_name: Optional[str] = None
        self._rank: Optional[int] = None
        self._device: Optional[Union[str, int, torch.device]] = None

        if isinstance(remote_device, torch.device):
            self._device = remote_device
        elif isinstance(remote_device, str):
            fields = remote_device.split("/")
            if len(fields) == 2:
                self._worker_name, self._device = fields
            elif len(fields) == 1:
                # A lone field is either a local device ("cpu", "cuda:1") or a
                # worker/rank specifier whose device defaults to "cpu".
                if _remote_device._is_valid_local_device(fields[0]):
                    self._device = fields[0]
                else:
                    self._worker_name = fields[0]
                    self._device = "cpu"
            else:
                raise ValueError(PARSE_ERROR)
        else:
            raise TypeError(f'Invalid type for remote_device: {type(remote_device)}')

        # Do some basic sanity check (no empty string)
        if self._worker_name is not None and not self._worker_name:
            raise ValueError(PARSE_ERROR)

        # Validate the device; torch.device raises on an invalid specification.
        self._device = torch.device(self._device)

        # Check for rank based format.
        if self._worker_name is not None:
            fields = self._worker_name.split(":")
            if len(fields) == 2:
                prefix, rank_str = fields
                # rank:<rank>/device format, extract rank.
                # isdecimal() (rather than isdigit()) is used because isdigit()
                # also accepts characters such as superscript digits that int()
                # cannot convert, which would bypass PARSE_ERROR.
                if prefix == "rank" and rank_str.isdecimal():
                    self._rank = int(rank_str)
                    self._worker_name = None
                else:
                    raise ValueError(PARSE_ERROR)
            elif len(fields) > 2:
                raise ValueError(PARSE_ERROR)

    @staticmethod
    def _is_valid_local_device(device):
        """Return ``True`` if ``torch.device(device)`` succeeds, ``False`` otherwise."""
        # Any failure from torch.device means "not a local device string";
        # the caller then treats the value as a worker name instead.
        try:
            torch.device(device)
            return True
        except Exception:
            return False

    def worker_name(self) -> Optional[str]:
        """Return the name of remote worker representing the remote device and ``None`` if no worker name is available."""
        return self._worker_name

    def rank(self) -> Optional[int]:
        """
        Returns the rank of remote worker representing the remote device.

        Returns ``None`` if no rank is available.
        """
        return self._rank

    def device(self) -> torch.device:
        """Return the local device on the remote worker."""
        return self._device  # type: ignore[return-value]

    def __repr__(self):
        """Return the "<workername>/<device>", "rank:<rank>/<device>" or "<device>" form."""
        if self._device is not None:
            if self._worker_name is not None:
                return f'{self._worker_name}/{self._device}'
            elif self._rank is not None:
                return f'rank:{self._rank}/{self._device}'
            else:
                return str(self._device)
        else:
            # Fallback for an instance without a device set.
            if self._worker_name is not None:
                return f'{self._worker_name}'
            elif self._rank is not None:
                return f'{self._rank}'
            else:
                raise RuntimeError('Invalid state!')

    def __eq__(self, other):
        """Two remote devices are equal when worker name, device and rank all match."""
        if not isinstance(other, _remote_device):
            return False

        return (
            self._worker_name == other._worker_name
            and self._device == other._device
            and self._rank == other._rank
        )

    def __hash__(self):
        """Hash the same fields that ``__eq__`` compares."""
        # A tuple hash avoids the symmetry/cancellation weaknesses of XOR-ing
        # the individual hashes together.
        return hash((self._worker_name, self._device, self._rank))
|