You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
4.1 KiB
104 lines
4.1 KiB
5 months ago
|
from os.path import join
|
||
|
from pathlib import Path
|
||
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
||
|
|
||
|
from PIL import Image
|
||
|
|
||
|
from .utils import check_integrity, download_and_extract_archive, list_dir, list_files
|
||
|
from .vision import VisionDataset
|
||
|
|
||
|
|
||
|
class Omniglot(VisionDataset):
    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``omniglot-py`` exists.
        background (bool, optional): If True, creates dataset from the "background" set, otherwise
            creates from the "evaluation" set. This terminology is defined by the authors.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset zip files from the internet and
            puts them in the root directory. If the zip files are already downloaded, they are not
            downloaded again.

    Raises:
        RuntimeError: If the zip archive of the selected set is missing or fails
            the MD5 check after the optional download step.
    """

    # Sub-directory of the user-supplied ``root`` that holds the archives and
    # the extracted image folders.
    folder = "omniglot-py"
    download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
    # Expected MD5 digest of each zip archive, keyed by the archive's base name.
    zips_md5 = {
        "images_background": "68d2efa1b9178cc56df9314c21c6e718",
        "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
    }

    def __init__(
        self,
        root: Union[str, Path],
        background: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
        self.background = background

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.target_folder = join(self.root, self._get_target_folder())
        self._alphabets = list_dir(self.target_folder)

        # Character directories as "<alphabet>/<character>" paths relative to
        # ``target_folder``. A nested comprehension flattens in one pass;
        # ``sum(lists, [])`` would re-copy the accumulated list at every step.
        self._characters: List[str] = [
            join(alphabet, character)
            for alphabet in self._alphabets
            for character in list_dir(join(self.target_folder, alphabet))
        ]

        # One inner list per character; each entry pairs an image file name with
        # the character's index in ``self._characters`` (the class label).
        self._character_images: List[List[Tuple[str, int]]] = [
            [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
            for idx, character in enumerate(self._characters)
        ]

        # Flat (image_name, class_index) list that backs __len__ / __getitem__.
        self._flat_character_images: List[Tuple[str, int]] = [
            sample for samples in self._character_images for sample in samples
        ]

    def __len__(self) -> int:
        """Return the number of images in the selected set."""
        return len(self._flat_character_images)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = join(self.target_folder, self._characters[character_class], image_name)
        image = Image.open(image_path, mode="r").convert("L")

        # Compare against None explicitly: a transform object that happens to be
        # falsy (e.g. one defining ``__len__``) must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self) -> bool:
        """Check the selected set's zip archive under ``self.root`` against its MD5."""
        zip_filename = self._get_target_folder()
        return check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename])

    def download(self) -> None:
        """Download and extract the selected set's archive unless it already verifies."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        filename = self._get_target_folder()
        zip_filename = filename + ".zip"
        url = self.download_url_prefix + "/" + zip_filename
        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])

    def _get_target_folder(self) -> str:
        """Return the folder (and archive base name) of the selected set."""
        return "images_background" if self.background else "images_evaluation"
|