import cProfile
import inspect
import io
import itertools
import os
import warnings
from contextlib import contextmanager
from functools import wraps
from pstats import Stats
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._tensor import DTensor

from .api import (
    _is_wrapped_exception,
    _wrap_exception,
    CheckpointException,
    WRAPPED_EXCEPTION,
)
from .metadata import MetadataIndex, STATE_DICT_TYPE

__all__ = ["find_tensor_shard", "find_state_dict_object"]

T = TypeVar("T")
R = TypeVar("R")

def _get_failure_dict(
    results: List[Union[T, WRAPPED_EXCEPTION]]
) -> Dict[int, WRAPPED_EXCEPTION]:
    return cast(
        Dict[int, WRAPPED_EXCEPTION],
        {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
    )

def _all_gather_keys(
    local_dict: Dict[Any, Any], group: Optional[dist.ProcessGroup] = None
) -> List[Any]:
    """Gather the keys of ``local_dict`` across all ranks and return them sorted."""
    keys = list(local_dict.keys())
    gathered_keys: List[List[Any]] = [None] * dist.get_world_size()  # type: ignore[list-item]

    dist.all_gather_object(gathered_keys, keys, group=group)
    return sorted(set(itertools.chain.from_iterable(gathered_keys)))

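# Example for ``_all_gather_keys`` (illustrative sketch, assuming
# torch.distributed is initialized with a two-rank group and each rank holds a
# different local dict):
#
#   rank 0: _all_gather_keys({"a": 1, "b": 2})  ->  ["a", "b", "c"]
#   rank 1: _all_gather_keys({"b": 5, "c": 6})  ->  ["a", "b", "c"]
#
# Every rank receives the same sorted union of keys.
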
class _DistWrapper:
    """
    This is a wrapper around a process group (PG) that provides a series of features around object collectives.

    It works even when torch.distributed is not initialized; in that case most collectives turn into no-ops.

    All variants that take functions are exception robust, meaning that if one or more
    ranks raise errors, all ranks will observe those errors.
    """

    def __init__(
        self,
        group: Optional[dist.ProcessGroup],
        use_dist: bool,
        coordinator_rank: int,
    ):
        self.group = group
        self.use_dist = use_dist
        self.coordinator_rank = coordinator_rank
        if self.use_dist:
            self.rank = dist.get_rank(group)
            self.is_coordinator = self.rank == coordinator_rank
        else:
            self.rank = 0
            self.is_coordinator = True

    def get_rank(self) -> int:
        return self.rank

    def get_world_size(self) -> int:
        if self.use_dist:
            return dist.get_world_size(self.group)
        return 1

    def broadcast_object(self, object: Optional[T]) -> T:
        """Implement functionality similar to c10d::broadcast_object_list but without requiring distributed to be enabled."""
        object_list = [object]
        if self.use_dist:
            dist.broadcast_object_list(
                object_list=object_list,
                group=self.group,
                src=self.coordinator_rank,
            )
        return cast(T, object_list[0])

    def gather_object(self, object: T) -> Optional[List[T]]:
        """Implement functionality similar to c10d::gather_object but without requiring distributed to be enabled."""
        if self.use_dist:
            gather_objs = (
                cast(List[T], [None] * dist.get_world_size(self.group))
                if self.is_coordinator
                else None
            )

            dist.gather_object(
                obj=object,
                object_gather_list=gather_objs if self.is_coordinator else None,
                dst=self.coordinator_rank,
                group=self.group,
            )
            result = gather_objs
        else:
            result = [object]
        return result

    def all_gather_object(self, object: T) -> List[T]:
        """Implement functionality similar to c10d::all_gather_object but without requiring distributed to be enabled."""
        if self.use_dist:
            gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))

            dist.all_gather_object(
                object_list=gather_objs, obj=object, group=self.group
            )
        else:
            gather_objs = [object]
        return gather_objs

    def scatter_object(self, object_list: Optional[List[T]]) -> T:
        """Implement functionality similar to c10d::scatter_object but without requiring distributed to be enabled."""
        if self.use_dist:
            gather_result = cast(List[T], [None])
            dist.scatter_object_list(
                scatter_object_output_list=gather_result,
                scatter_object_input_list=object_list if self.is_coordinator else None,
                src=self.coordinator_rank,
                group=self.group,
            )

            local_reply = gather_result[0]
        else:
            assert object_list is not None
            local_reply = object_list[0]
        return local_reply

    def reduce_scatter(
        self,
        step: str,
        map_fun: Callable[[], T],
        reduce_fun: Callable[[List[T]], List[R]],
    ) -> R:
        """
        Compute a value on each rank, then do a centralized reduce on a single rank, followed by a scatter.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            Gather results on rank 0
            Call ``reduce_fun`` on all those values
            Scatter to each rank part of the result.
        """
        local_data: Union[WRAPPED_EXCEPTION, T]
        try:
            local_data = map_fun()
        except BaseException as e:
            local_data = _wrap_exception(e)

        all_data = self.gather_object(local_data)
        all_results: Optional[List[Union[R, CheckpointException]]] = None
        if self.is_coordinator:
            assert all_data is not None
            node_failures = _get_failure_dict(all_data)

            if len(node_failures) == 0:
                try:
                    # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
                    all_results = cast(
                        List[Union[R, CheckpointException]],
                        reduce_fun(cast(List[T], all_data)),
                    )
                except BaseException as e:
                    node_failures[self.rank] = _wrap_exception(e)

            if len(node_failures) > 0:
                all_results = [
                    CheckpointException(step, node_failures)
                ] * self.get_world_size()

        result = self.scatter_object(all_results)
        if isinstance(result, CheckpointException):
            raise result
        return result

    def all_reduce(
        self,
        step: str,
        map_fun: Callable[[], T],
        reduce_fun: Callable[[List[T]], R],
    ) -> R:
        """
        Compute a value on each rank, then do a centralized reduce on a single rank, followed by a broadcast.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            Gather results on rank 0
            Call ``reduce_fun`` on all those values
            Broadcast the reduced value to all ranks.
        """
        local_data: Union[T, WRAPPED_EXCEPTION]
        try:
            local_data = map_fun()
        except BaseException as e:
            local_data = _wrap_exception(e)

        all_data = self.gather_object(local_data)
        result: Optional[Union[R, CheckpointException]] = None
        if self.is_coordinator:
            assert all_data is not None
            node_failures = _get_failure_dict(all_data)
            if len(node_failures) == 0:
                try:
                    result = reduce_fun(cast(List[T], all_data))
                except BaseException as e:
                    node_failures[self.rank] = _wrap_exception(e)

            if len(node_failures) > 0:
                result = CheckpointException(step, node_failures)

        final_result = self.broadcast_object(result)
        if isinstance(final_result, CheckpointException):
            raise final_result
        return cast(R, final_result)

    def all_gather(
        self,
        step: str,
        map_fun: Callable[[], T],
    ) -> List[T]:
        """
        Compute a value on each rank, then all_gather them.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            all_gather the values to all ranks
        """
        result: Union[T, WRAPPED_EXCEPTION]
        try:
            result = map_fun()
        except BaseException as e:
            result = _wrap_exception(e)

        all_results = self.all_gather_object(result)

        node_failures = _get_failure_dict(all_results)
        if len(node_failures) > 0:
            raise CheckpointException(step, node_failures)
        return cast(List[T], all_results)

    def broadcast(
        self,
        step: str,
        map_fun: Callable[[], T],
    ) -> T:
        """
        Compute a value on rank 0 and broadcast it.

        This method operates in the following way:
            Run ``map_fun`` on rank 0
            broadcast the value
        """
        result: Optional[Union[T, CheckpointException]] = None
        if self.is_coordinator:
            try:
                result = map_fun()
            except BaseException as e:
                result = CheckpointException(step, {self.rank: _wrap_exception(e)})
        final_result = self.broadcast_object(result)
        if isinstance(final_result, CheckpointException):
            raise final_result
        return cast(T, final_result)

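# Illustrative usage sketch for ``_DistWrapper`` (hypothetical step functions;
# the real callers are the checkpoint planner/storage layers):
#
#   dist_wrapper = _DistWrapper(group=None, use_dist=dist.is_initialized(), coordinator_rank=0)
#
#   def local_step():             # runs on every rank
#       return {"rank": dist_wrapper.rank}
#
#   def global_step(all_values):  # runs only on the coordinator
#       return [dict(v, approved=True) for v in all_values]
#
#   my_value = dist_wrapper.reduce_scatter("plan", local_step, global_step)
#
# If any rank raises inside ``local_step`` (or the coordinator raises inside
# ``global_step``), every rank raises the same CheckpointException.
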
def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
    if index.offset is None:
        raise ValueError(
            f"Cannot lookup {index.fqn} since it's a ShardedTensor and no offset was provided"
        )

    shards = tensor.local_shards()
    # index fast path
    if index.index is not None:
        if (
            len(shards) > index.index
            and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset
        ):
            return shards[index.index]

    for shard in shards:
        if torch.Size(shard.metadata.shard_offsets) == index.offset:
            return shard
    raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")

def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
    if isinstance(tensor, DTensor):
        return tensor.to_local()
    if isinstance(tensor, ShardedTensor):
        return _find_shard(tensor, index).tensor
    if index.offset is not None:
        # special case looking up a tensor by origin
        if index.offset == torch.Size([0] * len(tensor.size())):
            return tensor
        raise ValueError(
            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
        )
    return tensor

def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
    if index.fqn not in state_dict:
        raise ValueError(f"Could not find FQN: '{index.fqn}'")
    obj = state_dict[index.fqn]

    if isinstance(obj, torch.Tensor):
        return find_tensor_shard(obj, index)
    elif index.offset is not None:
        raise ValueError(
            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
        )
    return obj

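# Example for ``find_state_dict_object`` (illustrative sketch with hypothetical
# values, assuming ``MetadataIndex`` accepts ``fqn``/``offset`` keywords as
# defined in the metadata module):
#
#   state_dict = {"weight": torch.zeros(4, 4)}
#   idx = MetadataIndex(fqn="weight", offset=torch.Size([0, 0]))
#   find_state_dict_object(state_dict, idx)  # returns the full 4x4 tensor
#
# For a ShardedTensor the offset must match one of the local shards; for a
# DTensor only the local tensor is returned.
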
def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
    return [i_a + i_b for i_a, i_b in zip(a, b)]


def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
    return [i_a - i_b for i_a, i_b in zip(a, b)]

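# Illustrative examples for the element-wise helpers:
#
#   _element_wise_add([1, 2, 3], [10, 20, 30])  ->  [11, 22, 33]
#   _element_wise_sub([10, 20], [1, 2])         ->  [9, 18]
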
class _ReaderView(io.IOBase):
    # A read-only view over the byte range [offset, offset + len) of ``base_stream``.
    def __init__(self, base_stream: io.IOBase, offset: int, len: int):
        super().__init__()
        self.offset = offset
        self.len = len
        self.base_stream = base_stream
        self.seek(0)

    def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
        if __whence == os.SEEK_SET:
            __offset = self.offset + __offset
        elif __whence == os.SEEK_END:
            __whence = os.SEEK_SET
            __offset = (self.offset + self.len) - __offset
        return self.base_stream.seek(__offset, __whence)

    def tell(self) -> int:
        return self.base_stream.tell() - self.offset

    def readable(self) -> bool:
        return self.base_stream.readable()

    def seekable(self) -> bool:
        return self.base_stream.seekable()

    def readinto(self, b):
        return self.base_stream.readinto(b)  # type: ignore[attr-defined]

    def read(self, size=-1):
        return self.base_stream.read(size)

def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase:
    # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader
    return _ReaderView(file, offset, length)

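# Example for ``_create_file_view`` (illustrative sketch, hypothetical file name):
#
#   with open("checkpoint.bin", "rb") as f:
#       view = _create_file_view(f, offset=128, length=64)
#       payload = view.read(64)  # bytes 128..191 of the underlying file
#
# The view seeks the base stream directly, so interleaving reads on ``f`` and
# on the view affects both.
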
def _normalize_device_info(device_type: str, device_id: int) -> str:
    """Device info normalization."""
    if device_type == "cpu":
        return "cpu"
    return f"{device_type}:{device_id}"

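# Illustrative examples for ``_normalize_device_info``:
#
#   _normalize_device_info("cpu", 0)   ->  "cpu"
#   _normalize_device_info("cuda", 1)  ->  "cuda:1"
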
# TODO: integrate with distributed logging flag
ENABLE_PROFILE = False

@contextmanager
def _profile():
    # Only log the profiling results when profiling is enabled and we are on
    # rank 0, or when dist is not available.
    if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0):
        profiler = cProfile.Profile()
        profiler.enable()
        try:
            yield
        finally:
            profiler.disable()
            stats = Stats(profiler)
            stats.sort_stats("time").print_stats(10)
    else:
        yield

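# Illustrative usage sketch for ``_profile`` (profiling is a no-op unless
# ENABLE_PROFILE is flipped to True, e.g. while debugging locally;
# ``expensive_checkpoint_step`` is a hypothetical call):
#
#   with _profile():
#       expensive_checkpoint_step()
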
def _api_bc_check(func):
    @wraps(func)
    def inner_func(*args, **kwargs) -> Any:
        if len(args) == 2:
            warnings.warn(
                f"The argument order of {func.__name__} has been changed. "
                "Please check the documentation to avoid future breakages."
            )
            sig = inspect.signature(func)
            kwonlyargs = [
                p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY
            ]
            if "storage_writer" in kwonlyargs:
                assert "storage_writer" not in kwargs, (args, kwargs)
                kwargs["storage_writer"] = args[1]
            elif "storage_reader" in kwonlyargs:
                assert "storage_reader" not in kwargs, (args, kwargs)
                kwargs["storage_reader"] = args[1]
            else:
                raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}")
            return func(args[0], **kwargs)
        else:
            return func(*args, **kwargs)

    return inner_func
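
# Illustrative usage sketch for ``_api_bc_check`` (hypothetical ``save``
# function and ``writer`` object; the real decorated APIs are the checkpoint
# save/load entry points):
#
#   @_api_bc_check
#   def save(state_dict, *, storage_writer=None, planner=None):
#       ...
#
#   # An old-style positional call still works: it emits a warning and is
#   # rewritten to save(state_dict, storage_writer=writer).
#   save(state_dict, writer)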