import os
from os import path
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib.parse import urljoin

from .folder import default_loader
from .utils import check_integrity, download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class Places365(VisionDataset):
    r"""`Places365 <http://places2.csail.mit.edu/index.html>`_ classification dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of the Places365 dataset.
        split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
            ``val``.
        small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
            high resolution ones.
        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
            downloaded archives are not downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset

    Raises:
        RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
        RuntimeError: If ``download is True`` and the image archive is already extracted.
    """

    _SPLITS = ("train-standard", "train-challenge", "val")
    _BASE_URL = "http://data.csail.mit.edu/places/places365/"
    # {variant: (archive, md5)}
    _DEVKIT_META = {
        "standard": ("filelist_places365-standard.tar", "35a0585fee1fa656440f3ab298f8479c"),
        "challenge": ("filelist_places365-challenge.tar", "70a8307e459c3de41690a7c76c931734"),
    }
    # (file, md5)
    _CATEGORIES_META = ("categories_places365.txt", "06c963b85866bd0649f97cb43dd16673")
    # {split: (file, md5)}
    _FILE_LIST_META = {
        "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
        "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
        "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
    }
    # {(split, small): (file, md5)}
    _IMAGES_META = {
        ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
        ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
        ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
        ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
        ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
        ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
    }

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train-standard",
        small: bool = False,
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.split = self._verify_split(split)
        self.small = small
        self.loader = loader

        # The meta files are loaded first; the image archive is only fetched
        # afterwards, and only when explicitly requested.
        self.classes, self.class_to_idx = self.load_categories(download)
        self.imgs, self.targets = self.load_file_list(download)

        if download:
            self.download_images()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the sample at ``index``.

        Args:
            index (int): Position in ``self.imgs``.

        Returns:
            tuple: ``(image, target)``, where ``image`` is whatever ``self.loader`` returns for the image path and
            ``target`` is the class index, both passed through ``self.transforms`` when it is set.
        """
        file, target = self.imgs[index]
        image = self.loader(file)

        # ``self.transforms`` is not defined in this class; it is expected to
        # come from the ``VisionDataset`` base class.
        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of entries in the loaded file list."""
        return len(self.imgs)

    @property
    def variant(self) -> str:
        """Devkit variant of the split: ``"challenge"`` if the split name contains it, else ``"standard"``."""
        return "challenge" if "challenge" in self.split else "standard"

    @property
    def images_dir(self) -> str:
        """Directory below ``root`` that holds the images of the selected split and size."""
        size = "256" if self.small else "large"
        if self.split.startswith("train"):
            dirname = f"data_{size}_{self.variant}"
        else:
            dirname = f"{self.split}_{size}"
        return path.join(self.root, dirname)

    def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]:
        """Parse the categories file into class names and a name-to-index mapping.

        Each non-blank line is expected to hold a class name and an integer index separated by whitespace.

        Args:
            download (bool, optional): If ``True``, a missing or corrupted categories file triggers a devkit
                download instead of an error.

        Returns:
            tuple: ``(classes, class_to_idx)``, where ``classes`` is the alphabetically sorted list of class names.

        Raises:
            RuntimeError: If the file is missing or corrupted and ``download`` is ``False``.
        """

        def process(line: str) -> Tuple[str, int]:
            cls, idx = line.split()
            return cls, int(idx)

        file, md5 = self._CATEGORIES_META
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        # Explicit encoding keeps parsing independent of the platform locale;
        # blank lines are skipped so that they cannot break the unpacking.
        with open(file, encoding="utf-8") as fh:
            class_to_idx = dict(process(line) for line in fh if line.strip())

        return sorted(class_to_idx), class_to_idx

    def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
        """Parse the file list of the selected split.

        Each non-blank line is expected to hold a ``/``-separated relative image path and an integer class index
        separated by whitespace.

        Args:
            download (bool, optional): If ``True``, a missing or corrupted file list triggers a devkit download
                instead of an error.

        Returns:
            tuple: ``(images, targets)``, where ``images`` is a list of ``(absolute image path, class_index)`` tuples
            and ``targets`` the list of class indices in the same order. Both are empty for an empty file list.

        Raises:
            RuntimeError: If the file is missing or corrupted and ``download`` is ``False``.
        """
        # ``images_dir`` is a computed property; resolve it once rather than
        # once per line of a potentially very long file list.
        images_dir = self.images_dir

        def process(line: str, sep="/") -> Tuple[str, int]:
            image, idx = line.split()
            # Drop the leading separator so that the entry is joined below
            # ``images_dir``, and switch to the native path separator.
            return path.join(images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)

        file, md5 = self._FILE_LIST_META[self.split]
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        with open(file, encoding="utf-8") as fh:
            images = [process(line) for line in fh if line.strip()]

        # A comprehension, unlike ``zip(*images)``, also works for an empty list.
        targets = [target for _, target in images]
        return images, targets

    def download_devkit(self) -> None:
        """Download the devkit archive of the current variant and extract it into ``root``."""
        file, md5 = self._DEVKIT_META[self.variant]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)

    def download_images(self) -> None:
        """Download the image archive of the selected split and size and extract it into ``root``.

        Raises:
            RuntimeError: If ``images_dir`` already exists.
        """
        if path.exists(self.images_dir):
            raise RuntimeError(
                f"The directory {self.images_dir} already exists. If you want to re-download or re-extract the images, "
                f"delete the directory."
            )

        file, md5 = self._IMAGES_META[(self.split, self.small)]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)

        if self.split.startswith("train"):
            # The rename source is ``images_dir`` without its trailing
            # ``_<variant>`` part, i.e. the train archives are expected to
            # extract into a directory that does not name the variant.
            os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)

    def extra_repr(self) -> str:
        """Return the split and size settings as two lines of text."""
        return f"Split: {self.split}\nSmall: {self.small}"

    def _verify_split(self, split: str) -> str:
        """Validate ``split`` against ``_SPLITS`` via ``verify_str_arg`` and return the result."""
        return verify_str_arg(split, "split", self._SPLITS)

    def _check_integrity(self, file: str, md5: str, download: bool) -> bool:
        """Check ``file`` against ``md5``.

        Args:
            file (str): Path of the file to check.
            md5 (str): Expected MD5 checksum.
            download (bool): Whether the caller is allowed to download the file if the check fails.

        Returns:
            bool: Result of the integrity check. ``False`` can only be returned if ``download`` is ``True``.

        Raises:
            RuntimeError: If the check fails and ``download`` is ``False``.
        """
        integrity = check_integrity(file, md5=md5)
        if not integrity and not download:
            raise RuntimeError(
                f"The file {file} does not exist or is corrupted. You can set download=True to download it."
            )
        return integrity