You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

77 lines
2.7 KiB

5 months ago
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
import PIL.Image
from .utils import download_and_extract_archive
from .vision import VisionDataset
class SUN397(VisionDataset):
    """`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.

    The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
    397 categories with 108'754 images.

    Samples are indexed in sorted path order, so a given index refers to the same image
    regardless of the order in which the filesystem happens to list directory entries.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the ``SUN397`` directory does not exist under ``root``
            (checked after the optional download step).
    """

    _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
    _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._data_dir = Path(self.root) / "SUN397"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Each line of ClassName.txt names one class; the first three characters are
        # dropped (presumably a "/<letter>/" prefix -- confirm against the file).
        with open(self._data_dir / "ClassName.txt") as f:
            self.classes = [c[3:].strip() for c in f]

        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

        # Directory traversal order is not guaranteed, so sort the paths to keep the
        # index -> sample mapping reproducible across runs and machines.
        self._image_files = sorted(self._data_dir.rglob("sun_*.jpg"))

        # The class name is rebuilt from the path components between the first
        # directory under the data dir and the file name, joined with "/".
        self._labels = [
            self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
        ]

    def __len__(self) -> int:
        """Return the number of images found under the data directory."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load the sample at position ``idx``.

        Args:
            idx (int): Index into the list of image files.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the file opened with PIL and converted
            to RGB (then passed through ``transform`` if set), and ``label`` is the class index
            (then passed through ``target_transform`` if set).
        """
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def _check_exists(self) -> bool:
        """Return True if the ``SUN397`` data directory exists."""
        return self._data_dir.is_dir()

    def _download(self) -> None:
        """Download and extract the archive into ``root`` unless the data directory already exists."""
        if self._check_exists():
            return
        download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)