from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional, Sequence, Union import torch from torch.distributed.checkpoint.stateful import StatefulT __all__ = [ "ChunkStorageMetadata", "TensorStorageMetadata", "BytesStorageMetadata", "Metadata", "MetadataIndex", "TensorProperties", ] @dataclass class ChunkStorageMetadata: """ Each chunk is expected to have the same properties of the TensorStorageMetadata that includes it. """ offsets: torch.Size sizes: torch.Size class _MEM_FORMAT_ENCODING(Enum): """Describe the memory format of a tensor.""" TORCH_CONTIGUOUS_FORMAT = 0 TORCH_CHANNELS_LAST = 1 TORCH_PRESERVE_FORMAT = 2 @dataclass class TensorProperties: """Properties used to create :class:`Tensor`""" # Regular tensor fields dtype: torch.dtype = field(default_factory=torch.get_default_dtype) # This field is deprecated. layout: torch.layout = field(default=torch.strided) # This field is deprecated. requires_grad: bool = False # This field is deprecated. memory_format: torch.memory_format = field(default=torch.contiguous_format) # This field is deprecated. pin_memory: bool = False def __getstate__(self): # Since torch.memory_format cannot be pickled! memory_format = self.memory_format if memory_format == torch.contiguous_format: mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT elif memory_format == torch.channels_last: mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST elif memory_format == torch.preserve_format: mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT else: raise RuntimeError(f"Invalid torch.memory_format: {memory_format}") return ( self.dtype, self.layout, self.requires_grad, mem_format_encoding, self.pin_memory, ) def __setstate__( self, state, ): ( self.dtype, self.layout, self.requires_grad, mem_format_encoding, self.pin_memory, ) = state if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: memory_format = torch.contiguous_format elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST: memory_format = torch.channels_last elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: memory_format = torch.preserve_format else: raise RuntimeError( f"Invalid torch.memory_format encoding: {mem_format_encoding}" ) self.memory_format = memory_format @staticmethod def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": return TensorProperties( dtype=tensor.dtype, layout=tensor.layout, requires_grad=tensor.requires_grad, memory_format=torch.contiguous_format, pin_memory=tensor.is_pinned(), ) @dataclass class TensorStorageMetadata: properties: TensorProperties size: torch.Size chunks: List[ChunkStorageMetadata] @dataclass class BytesStorageMetadata: pass STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata] STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]] @dataclass class Metadata: """This class represents the metadata of the checkpoint.""" # Keys are the same from the `state_dict` used. state_dict_metadata: Dict[str, STORAGE_TYPES] # It is the responsibility of the planner and storage plugins to ensure # backward compatibility of the planner_data and storage_data. DCP will # also ensure the backward compatibility of the metadata in this file and # the metadata of the built-in planner and storage plugins. planner_data: Any = None storage_data: Any = None @dataclass(frozen=True) class MetadataIndex: """This class represents a lookup key for items in a state dict or Metadata.""" fqn: str """Fully Qualified Name of the object""" offset: Optional[torch.Size] = None """If the object is a tensor, offset into the tensor we're looking for""" index: Optional[int] = field(hash=False, compare=False, default=None) """ Index hint when searching for tensor chunk to speedup lookups (optional) A common representation of a sharded tensor is as a list of chunks so to find the index in such a list you need to linear search it. When constructing an instance of MetadataIndex that points to that list, one can provide the index as a hint and it will be probed first before the linear search and thus making it significantly faster. """ def __init__( self, fqn: str, offset: Optional[Sequence[int]] = None, index: Optional[int] = None, ): # We must use object.__setattr__ due to frozen=True object.__setattr__(self, "fqn", fqn) object.__setattr__(self, "index", index) if offset is not None: object.__setattr__(self, "offset", torch.Size(offset))