You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
4.1 KiB
110 lines
4.1 KiB
import os.path
|
|
from pathlib import Path
|
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
|
|
|
from PIL import Image
|
|
|
|
from .vision import VisionDataset
|
|
|
|
|
|
class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (str): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: Union[str, Path],
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # Imported here rather than at module level so that importing this module
        # does not require pycocotools; only constructing the dataset does.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sort the image ids so that dataset index -> image id is deterministic.
        # ``sorted`` already returns a new list, so no extra ``list()`` copy is needed.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, id: int) -> Image.Image:
        """Load the image with COCO image id ``id`` from ``root`` as an RGB PIL image.

        Args:
            id (int): COCO image id (an element of ``self.ids``).

        Returns:
            PIL.Image.Image: The image converted to ``"RGB"`` mode.
        """
        path = self.coco.loadImgs(id)[0]["file_name"]
        # The context manager guarantees the underlying file handle is released even
        # if decoding or conversion raises; ``convert`` returns a new, fully loaded
        # image, so it remains valid after the source image is closed.
        with Image.open(os.path.join(self.root, path)) as img:
            return img.convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        """Return the annotations attached to the image with COCO image id ``id``.

        Args:
            id (int): COCO image id (an element of ``self.ids``).

        Returns:
            list: The annotation objects returned by ``COCO.loadAnns`` for the
            annotation ids that ``COCO.getAnnIds`` reports for this image.
        """
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair at position ``index``.

        Args:
            index (int): Position in the dataset (an index into ``self.ids``).

        Returns:
            tuple: ``(image, target)``, passed through ``self.transforms`` when it is set.

        Raises:
            ValueError: If ``index`` is not an ``int``.
        """
        # Reject slices and other non-int keys up front: ``self.ids[slice]`` would
        # silently return a list of ids instead of a single sample.
        if not isinstance(index, int):
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")

        id = self.ids[index]
        image = self._load_image(id)
        target = self._load_target(id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of images (one sample per COCO image id)."""
        return len(self.ids)
|
|
|
|
|
|
class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Each sample's target is the list of caption strings annotated for that image.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (str): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root='dir where images are',
                                    annFile='json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: torch.Size([3, 427, 640])
            ['A plane emitting smoke stream flying over a mountain.',
             'A plane darts across a bright blue sky behind a mountain covered in snow',
             'A plane leaves a contrail above the snowy mountain top.',
             'A mountain that has a plane flying overheard in the distance.',
             'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id: int) -> List[str]:
        """Return only the ``"caption"`` field of each annotation for image ``id``.

        Args:
            id (int): COCO image id (an element of ``self.ids``).

        Returns:
            list of str: One caption per annotation, in the order the parent
            class's ``_load_target`` returns the annotations.
        """
        return [ann["caption"] for ann in super()._load_target(id)]
|