You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

68 lines
2.4 KiB

from typing import Any, Callable, Optional, Tuple
import torch
from .. import transforms
from .vision import VisionDataset
class FakeData(VisionDataset):
    """A fake dataset that returns randomly generated images as PIL images.

    Every sample is generated deterministically from its index (plus
    ``random_offset``), so accessing ``dataset[i]`` repeatedly yields the same
    image and target. Generation uses a private random generator and does not
    touch the global torch RNG state.

    Args:
        size (int, optional): Size of the dataset. Default: 1000 images
        image_size (tuple, optional): Size of the returned images, given as
            ``(channels, height, width)`` before conversion to PIL.
            Default: (3, 224, 224)
        num_classes (int, optional): Number of classes in the dataset. Default: 10
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it. It receives the target as a 0-dim
            ``torch.long`` tensor.
        random_offset (int, optional): Offsets the index-based random seed used to
            generate each image. Default: 0
    """

    def __init__(
        self,
        size: int = 1000,
        image_size: Tuple[int, int, int] = (3, 224, 224),
        num_classes: int = 10,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        random_offset: int = 0,
    ) -> None:
        super().__init__(transform=transform, target_transform=target_transform)
        self.size = size
        self.num_classes = num_classes
        self.image_size = image_size
        self.random_offset = random_offset

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index. Negative values count from the end, as for
                Python sequences.

        Returns:
            tuple: (image, target) where target is class_index of the target class.
            If ``target_transform`` returns something other than a
            single-element tensor, that value is returned unchanged.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        # Python-style negative indexing; previously negative indices slipped
        # past the bounds check and produced an unrelated sample.
        if index < 0:
            index += len(self)
        if not 0 <= index < len(self):
            raise IndexError(f"{self.__class__.__name__} index out of range")

        # Create a random image that is consistent with the index id, using a
        # private generator. torch.manual_seed() would also reseed the CUDA
        # generators while get/set_rng_state() only restores the CPU state, so
        # seeding the global RNG clobbered global CUDA RNG state (and raced with
        # other threads drawing from the global generator). A fresh CPU
        # generator with the same seed yields the same values.
        seed = int(index + self.random_offset)
        generator = torch.Generator().manual_seed(seed)
        img = torch.randn(*self.image_size, generator=generator, device="cpu")
        target = torch.randint(
            0,
            self.num_classes,
            size=(1,),
            dtype=torch.long,
            generator=generator,
            device="cpu",
        )[0]

        # convert to PIL Image
        img = transforms.ToPILImage()(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        # Unwrap scalar tensors to a plain Python number; leave any other value
        # produced by target_transform (an int, a one-hot tensor, ...) as-is
        # instead of failing on `.item()`.
        if isinstance(target, torch.Tensor) and target.numel() == 1:
            target = target.item()
        return img, target

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return self.size