import os
import pathlib
from typing import Any, Callable, Optional, Tuple, Union

import PIL.Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class DTD(VisionDataset):
    """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.

            .. note::

                The partition only changes which split each image belongs to. Thus, regardless of the selected
                partition, combining all splits will result in all images.

        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.

    Raises:
        ValueError: If ``partition`` is not an integer in the range ``1 <= partition <= 10``.
        RuntimeError: If the dataset folder is not found under ``root`` (after the optional download).
    """

    _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
    _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        partition: int = 1,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))

        # Reject the value if EITHER check fails. The type test must come first so
        # that the range comparison is never evaluated on a non-integer (which
        # could raise TypeError instead of the documented ValueError).
        if not isinstance(partition, int) or not (1 <= partition <= 10):
            raise ValueError(
                f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
                f"but got {partition} instead"
            )
        self._partition = partition

        super().__init__(root, transform=transform, target_transform=target_transform)

        # On-disk layout: <root>/dtd/dtd/{labels,images}
        self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
        self._data_folder = self._base_folder / "dtd"
        self._meta_folder = self._data_folder / "labels"
        self._images_folder = self._data_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Each non-empty line of the label file has the form "<class>/<file name>".
        self._image_files = []
        classes = []
        with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
            for line in file:
                entry = line.strip()
                if not entry:
                    # Tolerate blank lines (e.g. a trailing newline at the end of the file).
                    continue
                cls, name = entry.split("/")
                self._image_files.append(self._images_folder.joinpath(cls, name))
                classes.append(cls)

        # Class indices follow the alphabetical order of the class names.
        self.classes = sorted(set(classes))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self._labels = [self.class_to_idx[cls] for cls in classes]

    def __len__(self) -> int:
        """Return the number of images in the selected split/partition."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load the sample at position ``idx``.

        Args:
            idx (int): Index into the list of image files.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the RGB PIL image (after ``transform``, if set)
            and ``label`` is the class index (after ``target_transform``, if set).
        """
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        """Return the split and partition as a string for the dataset's repr."""
        return f"split={self._split}, partition={self._partition}"

    def _check_exists(self) -> bool:
        """Return True if the extracted data folder is present on disk."""
        # ``isdir`` is already False for paths that do not exist.
        return os.path.isdir(self._data_folder)

    def _download(self) -> None:
        """Download and extract the archive unless the data folder already exists."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)