You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

173 lines
6.9 KiB

5 months ago
from typing import Dict, List, Optional, Union
import torch
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
from . import constants as rpc_contants
# Names exported by ``from ... import *``.
__all__ = ["TensorPipeRpcBackendOptions"]

# Any spelling of a device that is handed to ``torch.device(...)`` in this
# module: an ``int``, a ``str`` such as ``"cuda:0"``, or a ``torch.device``.
DeviceType = Union[int, str, torch.device]
def _to_device(device: DeviceType) -> torch.device:
    """Convert ``device`` to a :class:`torch.device`, accepting only CUDA.

    Args:
        device: an ``int``, ``str`` or ``torch.device`` understood by
            :class:`torch.device`.

    Returns:
        The normalized ``torch.device``.

    Raises:
        ValueError: if the resulting device type is not ``"cuda"``.
    """
    normalized = torch.device(device)
    if normalized.type == "cuda":
        return normalized
    raise ValueError(
        "`set_devices` expect a list of CUDA devices, but got "
        f"device type {normalized.type}."
    )
def _to_device_map(
    device_map: Dict[DeviceType, DeviceType]
) -> Dict[torch.device, torch.device]:
    """Normalize a user-supplied device map and verify it is 1-to-1.

    Keys and values may be anything accepted by :class:`torch.device`
    (``int``, ``str`` or ``torch.device``). Different keys can normalize to
    the same device (e.g. ``"cuda:0"`` and ``torch.device("cuda:0")``); such
    duplicates are accepted only when they agree on the target device.

    Args:
        device_map: mapping from this worker's devices to the callee's
            devices.

    Returns:
        A new dict from normalized source devices to normalized target
        devices.

    Raises:
        ValueError: if one source device is mapped to two different targets,
            or two different source devices are mapped to the same target.
    """
    full_device_map: Dict[torch.device, torch.device] = {}
    reverse_map: Dict[torch.device, torch.device] = {}
    for k, v in device_map.items():
        k, v = torch.device(k), torch.device(v)
        if k in full_device_map:
            # Another spelling of this source device was already seen.
            # Overwriting it would silently drop one of the user's mappings
            # and leave a stale entry in ``reverse_map``.
            if full_device_map[k] != v:
                raise ValueError(
                    "`device_map` only supports 1-to-1 mapping, "
                    f"trying to map {k} to {v} and {full_device_map[k]}"
                )
            # Exact duplicate after normalization: nothing new to record.
            continue
        if v in reverse_map:
            raise ValueError(
                "`device_map` only supports 1-to-1 mapping, "
                f"trying to map {k} and {reverse_map[v]} to {v}"
            )
        full_device_map[k] = v
        reverse_map[v] = k
    return full_device_map
def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
    """Normalize each entry of ``devices`` with ``_to_device``, keeping order.

    Raises:
        ValueError: propagated from ``_to_device`` for the first entry that is
            not a CUDA device.
    """
    return [_to_device(entry) for entry in devices]
class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
    r"""
    The backend options for
    :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
    :class:`~torch.distributed.rpc.RpcBackendOptions`.

    Args:
        num_worker_threads (int, optional): The number of threads in the
            thread-pool used by
            :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
            requests (default: 16).
        rpc_timeout (float, optional): The default timeout, in seconds,
            for RPC requests (default: 60 seconds). If the RPC has not
            completed in this timeframe, an exception indicating so will
            be raised. Callers can override this timeout for individual
            RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
            :meth:`~torch.distributed.rpc.rpc_async` if necessary.
        init_method (str, optional): The URL to initialize the distributed
            store used for rendezvous. It takes any value accepted for the
            same argument of :meth:`~torch.distributed.init_process_group`
            (default: ``env://``).
        device_maps (Dict[str, Dict], optional): Device placement mappings
            from this worker to the callee. Each key is a callee worker name
            and each value is a dictionary (``Dict`` of ``int``, ``str``, or
            ``torch.device``) that maps this worker's devices to the callee
            worker's devices. (default: ``None``)
        devices (List[int, str, or ``torch.device``], optional): all local
            CUDA devices used by the RPC agent. By default, it will be
            initialized to all local devices from its own ``device_maps`` and
            the corresponding devices from its peers' ``device_maps``. When
            processing CUDA RPC requests, the agent will properly synchronize
            CUDA streams for all devices in this ``List``.
    """

    def __init__(
        self,
        *,
        num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS,
        rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = rpc_contants.DEFAULT_INIT_METHOD,
        device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
        devices: Optional[List[DeviceType]] = None,
        _transports: Optional[List] = None,
        _channels: Optional[List] = None,
    ):
        # Normalize every per-callee map (and validate it is 1-to-1) before
        # anything is handed to the base class.
        if device_maps is None:
            normalized_maps = {}
        else:
            normalized_maps = {
                callee: _to_device_map(mapping)
                for callee, mapping in device_maps.items()
            }

        # Every listed device must be a CUDA device; an absent list is
        # forwarded as empty.
        if devices is None:
            normalized_devices = []
        else:
            normalized_devices = _to_device_list(devices)

        # The base constructor is called positionally; note that its order
        # differs from this signature (_transports/_channels come 2nd/3rd).
        super().__init__(
            num_worker_threads,
            _transports,
            _channels,
            rpc_timeout,
            init_method,
            normalized_maps,
            normalized_devices,
        )

    def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
        r"""
        Set device mapping between each RPC caller and callee pair. This
        function can be called multiple times to incrementally add
        device placement configurations.

        Args:
            to (str): Callee name.
            device_map (Dict of int, str, or torch.device): Device placement
                mappings from this worker to the callee. This map must be
                invertible.

        Raises:
            ValueError: if ``device_map`` maps two source devices to the same
                target, or maps a source device that already has a mapping
                for ``to`` to a different target.

        Example:
            >>> # xdoctest: +SKIP("distributed")
            >>> # both workers
            >>> def add(x, y):
            >>>     print(x)  # tensor([1., 1.], device='cuda:1')
            >>>     return x + y, (x + y).to(2)
            >>>
            >>> # on worker 0
            >>> options = TensorPipeRpcBackendOptions(
            >>>     num_worker_threads=8,
            >>>     device_maps={"worker1": {0: 1}}
            >>>     # maps worker0's cuda:0 to worker1's cuda:1
            >>> )
            >>> options.set_device_map("worker1", {1: 2})
            >>> # maps worker0's cuda:1 to worker1's cuda:2
            >>>
            >>> rpc.init_rpc(
            >>>     "worker0",
            >>>     rank=0,
            >>>     world_size=2,
            >>>     backend=rpc.BackendType.TENSORPIPE,
            >>>     rpc_backend_options=options
            >>> )
            >>>
            >>> x = torch.ones(2)
            >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
            >>> # The first argument will be moved to cuda:1 on worker1. When
            >>> # sending the return value back, it will follow the inverse of
            >>> # the device map, and hence will be moved back to cuda:0 and
            >>> # cuda:1 on worker0
            >>> print(rets[0])  # tensor([2., 2.], device='cuda:0')
            >>> print(rets[1])  # tensor([2., 2.], device='cuda:1')
        """
        requested = _to_device_map(device_map)
        existing_maps = super().device_maps
        if to in existing_maps:
            current = existing_maps[to]
            # First source device whose new target disagrees with the target
            # it is already mapped to for this callee, if any.
            clash = next(
                (
                    (src, dst)
                    for src, dst in requested.items()
                    if src in current and dst != current[src]
                ),
                None,
            )
            if clash is not None:
                src, dst = clash
                raise ValueError(
                    "`set_device_map` only supports 1-to-1 mapping, trying"
                    f" to map {src} to {dst} and {current[src]}"
                )
        super()._set_device_map(to, requested)

    def set_devices(self, devices: List[DeviceType]):
        r"""
        Set local devices used by the TensorPipe RPC agent. When processing
        CUDA RPC requests, the TensorPipe RPC agent will properly synchronize
        CUDA streams for all devices in this ``List``.

        Args:
            devices (List of int, str, or torch.device): local devices used by
                the TensorPipe RPC agent.

        Raises:
            ValueError: if any entry is not a CUDA device.
        """
        cuda_devices = _to_device_list(devices)
        self.devices = cuda_devices