You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

94 lines
3.7 KiB

import json
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
import PIL.Image
from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
class Food101(VisionDataset):
    """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.

    The Food-101 is a challenging data set of 101 food categories with 101,000 images.
    For each class, 250 manually reviewed test images are provided as well as 750 training images.
    On purpose, the training images were not cleaned, and thus still contain some amount of noise.
    This comes mostly in the form of intense colors and sometimes wrong labels. All images were
    rescaled to have a maximum side length of 512 pixels.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.

    Raises:
        RuntimeError: If the ``meta`` and ``images`` folders are not both present under
            ``<root>/food-101`` after the optional download step.
    """

    _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
    _MD5 = "85eeb15f3717b99a5da872d97d918f87"

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = Path(self.root) / "food-101"
        self._meta_folder = self._base_folder / "meta"
        self._images_folder = self._base_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # The split file maps each class name to a list of image paths, given
        # relative to the images folder and without the ".jpg" extension.
        # JSON is UTF-8 by specification, so do not rely on the locale default.
        with open(self._meta_folder / f"{split}.json", encoding="utf-8") as f:
            metadata = json.load(f)

        # Class indices follow the alphabetical order of the class names.
        self.classes = sorted(metadata)
        self.class_to_idx = {name: index for index, name in enumerate(self.classes)}

        self._labels = []
        self._image_files = []
        for class_label, im_rel_paths in metadata.items():
            self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
            # Relative paths use "/" separators; split them so the resulting
            # path is assembled from individual components.
            self._image_files += [
                self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
            ]

    def __len__(self) -> int:
        """Return the number of images in the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load one sample.

        Args:
            idx (int): Index into the list of images of the selected split.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the RGB-converted image after
            ``transform`` (if given) and ``label`` is the class index after
            ``target_transform`` (if given).
        """
        image_file, label = self._image_files[idx], self._labels[idx]

        # ``convert`` returns a fully loaded image, so the underlying file can
        # be closed deterministically as soon as the conversion is done.
        with PIL.Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        # Compare against None explicitly: a callable object may define
        # ``__len__``/``__bool__`` and evaluate falsy, and must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        """Return the split description used in the dataset's repr."""
        return f"split={self._split}"

    def _check_exists(self) -> bool:
        """Return True when both the ``meta`` and ``images`` folders exist as directories."""
        # ``is_dir`` is already False for a missing path; no separate ``exists`` needed.
        return all(folder.is_dir() for folder in (self._meta_folder, self._images_folder))

    def _download(self) -> None:
        """Download and extract the archive into ``root`` unless the dataset is already present."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)