You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123 lines
3.7 KiB

5 months ago
# Mypy will not try inferring the types of any 3rd party libraries installed.
# mypy: ignore-errors
import io
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional, Union
import fsspec
from fsspec import AbstractFileSystem
from fsspec.core import url_to_fs
from torch.distributed.checkpoint.filesystem import (
FileSystemBase,
FileSystemReader,
FileSystemWriter,
)
# Public names exported by this module: the checkpoint writer and reader
# classes defined below.
__all__ = [
"FsspecWriter",
"FsspecReader",
]
class FileSystem(FileSystemBase):
def __init__(self) -> None:
self.fs: Optional[AbstractFileSystem] = None
@contextmanager
def create_stream(
self, path: Union[str, os.PathLike], mode: str
) -> Generator[io.IOBase, None, None]:
assert self.fs is not None
with self.fs.transaction:
with fsspec.open(str(path), mode) as stream:
yield stream
def concat_path(
self, path: Union[str, os.PathLike], suffix: str
) -> Union[str, os.PathLike]:
return os.path.join(path, suffix)
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
self.fs, _ = url_to_fs(path)
return path
def rename(
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
) -> None:
self.fs.rename(path, new_path)
def mkdir(self, path: [str, os.PathLike]) -> None:
self.fs.makedirs(path, exist_ok=True)
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
if isinstance(checkpoint_id, Path):
return False
try:
url_to_fs(checkpoint_id)
except ValueError as e:
return False
return True
class FsspecWriter(FileSystemWriter):
    """
    Basic StorageWriter implementation built on fsspec.

    The implementation relies on these assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic.

    A checkpoint consists of one file per write request, plus a
    `.metadata` file holding the serialized metadata.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
    ) -> None:
        """
        Set up a writer that targets `path`.

        Args:
            path: directory the checkpoint is written to.
            single_file_per_rank: write one file per rank rather than one
                file per tensor/blob. Defaults to True.
            sync_files: force files to be synced to permanent storage.
                Defaults to True.
            thread_count: number of IO threads used for writing.
                Defaults to 1.
            per_thread_copy_ahead: number of bytes to copy from the GPU
                ahead of saving them. Defaults to 10_000_000 (10 MB).

        N.B. With sync_files disabled there is no guarantee that the
        checkpoint is consistent if a failure occurs.
        """
        super().__init__(
            path,
            single_file_per_rank,
            sync_files,
            thread_count,
            per_thread_copy_ahead,
        )
        # Swap in the fsspec-backed filesystem and let it resolve the path.
        fsspec_fs = FileSystem()
        self.fs = fsspec_fs
        self.path = fsspec_fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Return the verdict of ``FileSystem.validate_checkpoint_id`` for ``checkpoint_id``."""
        is_supported = FileSystem.validate_checkpoint_id(checkpoint_id)
        return is_supported
class FsspecReader(FileSystemReader):
    """``FileSystemReader`` variant that reads through an fsspec ``FileSystem``."""

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        """
        Set up a reader that targets `path`.

        Args:
            path: location of the checkpoint to read.
        """
        super().__init__(path)
        # Swap in the fsspec-backed filesystem and let it resolve the path.
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Return the verdict of ``FileSystem.validate_checkpoint_id`` for ``checkpoint_id``."""
        # FileSystem defines validate_checkpoint_id, not check; this is the
        # same delegation FsspecWriter performs.
        return FileSystem.validate_checkpoint_id(checkpoint_id)