You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

110 lines
4.1 KiB

import os.path
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
from PIL import Image
from .vision import VisionDataset
class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: Union[str, Path],
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # Imported here rather than at module level, so pycocotools is only
        # needed when a dataset instance is actually constructed.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Dataset index i maps to the i-th smallest COCO image id (stable order).
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, id: int) -> Image.Image:
        """Load the image with COCO image id ``id`` from ``root`` as an RGB PIL image."""
        path = self.coco.loadImgs(id)[0]["file_name"]
        # ``convert`` returns a new, fully loaded image (a copy even when the mode is
        # already RGB), so the handle opened by ``Image.open`` can be released here
        # deterministically instead of being left to Pillow / garbage collection.
        with Image.open(os.path.join(self.root, path)) as img:
            return img.convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        """Return the list of annotation objects for COCO image id ``id``."""
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for the sample at position ``index``.

        Raises:
            ValueError: if ``index`` is not an ``int``.
            IndexError: if ``index`` is out of range for ``self.ids``.
        """
        if not isinstance(index, int):
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")

        image_id = self.ids[index]
        image = self._load_image(image_id)
        target = self._load_target(image_id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of images (one sample per COCO image id)."""
        return len(self.ids)
class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Each sample is an ``(image, captions)`` pair, where ``captions`` is the list of
    caption strings taken from the image's annotations.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: torch.Size([3, 427, 640])
            ['A plane emitting smoke stream flying over a mountain.',
            'A plane darts across a bright blue sky behind a mountain covered in snow',
            'A plane leaves a contrail above the snowy mountain top.',
            'A mountain that has a plane flying overheard in the distance.',
            'A mountain view with a plume of smoke in the background']
    """

    def _load_target(self, id: int) -> List[str]:
        """Return only the ``"caption"`` field of each annotation for image ``id``."""
        anns = super()._load_target(id)
        return [entry["caption"] for entry in anns]