from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Union

import torch
from torch.distributed.checkpoint.stateful import StatefulT


__all__ = [
    "ChunkStorageMetadata",
    "TensorStorageMetadata",
    "BytesStorageMetadata",
    "Metadata",
    "MetadataIndex",
    "TensorProperties",
]


@dataclass
class ChunkStorageMetadata:
    """
    Each chunk is expected to have the same properties as the TensorStorageMetadata
    that includes it.
    """

    offsets: torch.Size
    sizes: torch.Size
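
# Illustrative example (not part of the original file): the second row-shard of a
# (4, 4) tensor split along dim 0 would be described by offsets/sizes such as
#
#   ChunkStorageMetadata(offsets=torch.Size([2, 0]), sizes=torch.Size([2, 4]))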


class _MEM_FORMAT_ENCODING(Enum):
    """Describe the memory format of a tensor."""

    TORCH_CONTIGUOUS_FORMAT = 0
    TORCH_CHANNELS_LAST = 1
    TORCH_PRESERVE_FORMAT = 2


@dataclass
class TensorProperties:
    """Properties used to create :class:`Tensor`"""

    # Regular tensor fields
    dtype: torch.dtype = field(default_factory=torch.get_default_dtype)
    # This field is deprecated.
    layout: torch.layout = field(default=torch.strided)
    # This field is deprecated.
    requires_grad: bool = False
    # This field is deprecated.
    memory_format: torch.memory_format = field(default=torch.contiguous_format)
    # This field is deprecated.
    pin_memory: bool = False

    def __getstate__(self):
        # torch.memory_format cannot be pickled, so encode it as an enum member.
        memory_format = self.memory_format
        if memory_format == torch.contiguous_format:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
        elif memory_format == torch.channels_last:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
        elif memory_format == torch.preserve_format:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT
        else:
            raise RuntimeError(f"Invalid torch.memory_format: {memory_format}")

        return (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        )

    def __setstate__(
        self,
        state,
    ):
        (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        ) = state
        if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
            memory_format = torch.contiguous_format
        elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST:
            memory_format = torch.channels_last
        elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT:
            memory_format = torch.preserve_format
        else:
            raise RuntimeError(
                f"Invalid torch.memory_format encoding: {mem_format_encoding}"
            )

        self.memory_format = memory_format

    @staticmethod
    def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
        return TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        )
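
# Illustrative example (not part of the original file): capturing the properties of an
# existing tensor and round-tripping them through pickle, which exercises the
# __getstate__/__setstate__ pair above (memory_format is mapped to
# _MEM_FORMAT_ENCODING because torch.memory_format itself is not picklable).
#
#   import pickle
#   props = TensorProperties.create_from_tensor(torch.zeros(2, 2))
#   restored = pickle.loads(pickle.dumps(props))
#   assert restored.dtype == props.dtype
#   assert restored.memory_format == props.memory_format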


@dataclass
class TensorStorageMetadata:
    properties: TensorProperties
    size: torch.Size
    chunks: List[ChunkStorageMetadata]
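
# Illustrative example (not part of the original file): metadata for a 1-D tensor of
# 8 float32 elements saved as two shards of 4 elements each. All values are made up.
#
#   TensorStorageMetadata(
#       properties=TensorProperties(dtype=torch.float32),
#       size=torch.Size([8]),
#       chunks=[
#           ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([4])),
#           ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([4])),
#       ],
#   )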


@dataclass
class BytesStorageMetadata:
    pass


STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]]


@dataclass
class Metadata:
    """This class represents the metadata of the checkpoint."""

    # Keys are the same as in the `state_dict` that was used.
    state_dict_metadata: Dict[str, STORAGE_TYPES]
    # It is the responsibility of the planner and storage plugins to ensure
    # backward compatibility of the planner_data and storage_data. DCP will
    # also ensure the backward compatibility of the metadata in this file and
    # the metadata of the built-in planner and storage plugins.
    planner_data: Any = None
    storage_data: Any = None
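
# Illustrative example (not part of the original file): a checkpoint with one sharded
# tensor and one opaque bytes entry might carry metadata shaped like
#
#   Metadata(
#       state_dict_metadata={
#           "model.weight": TensorStorageMetadata(...),
#           "optimizer_extra_state": BytesStorageMetadata(),
#       }
#   )
#
# The keys above are made up; real keys come from the state_dict being saved.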


@dataclass(frozen=True)
class MetadataIndex:
    """This class represents a lookup key for items in a state dict or Metadata."""

    fqn: str
    """Fully Qualified Name of the object"""

    offset: Optional[torch.Size] = None
    """If the object is a tensor, offset into the tensor we're looking for"""

    index: Optional[int] = field(hash=False, compare=False, default=None)
    """
    Index hint when searching for a tensor chunk to speed up lookups (optional)

    A common representation of a sharded tensor is a list of chunks, so finding
    the right chunk in such a list requires a linear search.

    When constructing an instance of MetadataIndex that points into that list,
    one can provide the index as a hint; it will be probed before falling back
    to the linear search, which makes lookups significantly faster.
    """

    def __init__(
        self,
        fqn: str,
        offset: Optional[Sequence[int]] = None,
        index: Optional[int] = None,
    ):
        # We must use object.__setattr__ due to frozen=True
        object.__setattr__(self, "fqn", fqn)
        object.__setattr__(self, "index", index)

        if offset is not None:
            object.__setattr__(self, "offset", torch.Size(offset))
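
# Illustrative example (not part of the original file): a key pointing at the chunk of
# "model.weight" that starts at offset (2, 0), with a hint that it is likely element 1
# of that tensor's chunk list so the linear search can usually be skipped.
#
#   key = MetadataIndex(fqn="model.weight", offset=[2, 0], index=1)
#   assert key.offset == torch.Size([2, 0])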