import collections
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from xml.etree.ElementTree import Element as ET_Element

# Prefer the hardened defusedxml parser when it is installed; otherwise fall
# back to the standard-library parser.
try:
    from defusedxml.ElementTree import parse as ET_parse
except ImportError:
    from xml.etree.ElementTree import parse as ET_parse

from PIL import Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset

# Per-release download metadata. ``base_dir`` is the directory, relative to the
# dataset root, that is expected to exist once the archive has been extracted.
# The key "2007-test" is selected for year "2007" with image_set "test".
DATASET_YEAR_DICT = {
    "2012": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
        "filename": "VOCtrainval_11-May-2012.tar",
        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
        "base_dir": os.path.join("VOCdevkit", "VOC2012"),
    },
    "2011": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
        "filename": "VOCtrainval_25-May-2011.tar",
        "md5": "6c3384ef61512963050cb5d687e5bf1e",
        "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
    },
    "2010": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
        "filename": "VOCtrainval_03-May-2010.tar",
        "md5": "da459979d0c395079b5c75ee67908abb",
        "base_dir": os.path.join("VOCdevkit", "VOC2010"),
    },
    "2009": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
        "filename": "VOCtrainval_11-May-2009.tar",
        "md5": "a3e00b113cfcfebf17e343f59da3caa1",
        "base_dir": os.path.join("VOCdevkit", "VOC2009"),
    },
    "2008": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
        # The local filename must match the 2008 archive; reusing the 2012
        # name would make the two archives collide inside a shared root.
        "filename": "VOCtrainval_14-Jul-2008.tar",
        "md5": "2629fa636546599198acfcfbfcf1904a",
        "base_dir": os.path.join("VOCdevkit", "VOC2008"),
    },
    "2007": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
        "filename": "VOCtrainval_06-Nov-2007.tar",
        "md5": "c52e279531787c972589f7e41ab4ae64",
        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
    },
    "2007-test": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
        "filename": "VOCtest_06-Nov-2007.tar",
        "md5": "b6e924de25625d8de591ea690078ad9f",
        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
    },
}


class _VOCBase(VisionDataset):
    """Shared loader logic for the Pascal VOC datasets.

    Subclasses set the three class attributes below; this base class then
    resolves the split file and builds parallel lists of image paths
    (``self.images``) and target paths (``self.targets``).
    """

    # Sub-directory of ``ImageSets`` that holds the ``<image_set>.txt`` split files.
    _SPLITS_DIR: str
    # Directory (relative to the VOC root) that holds the target files.
    _TARGET_DIR: str
    # File extension (including the dot) appended to each sample name for its target.
    _TARGET_FILE_EXT: str

    def __init__(
        self,
        root: Union[str, Path],
        year: str = "2012",
        image_set: str = "train",
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        """Validate the arguments, optionally download, and index the split.

        Raises:
            RuntimeError: if the expected dataset directory does not exist
                under ``root`` (after the optional download).
        """
        super().__init__(root, transforms, transform, target_transform)

        self.year = verify_str_arg(year, "year", valid_values=[str(yr) for yr in range(2007, 2013)])

        # Only the 2007 release has a "test" entry in DATASET_YEAR_DICT.
        valid_image_sets = ["train", "trainval", "val"]
        if year == "2007":
            valid_image_sets.append("test")
        self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)

        key = "2007-test" if year == "2007" and image_set == "test" else year
        dataset_year_dict = DATASET_YEAR_DICT[key]

        self.url = dataset_year_dict["url"]
        self.filename = dataset_year_dict["filename"]
        self.md5 = dataset_year_dict["md5"]

        base_dir = dataset_year_dict["base_dir"]
        voc_root = os.path.join(self.root, base_dir)

        if download:
            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # The split file lists one sample name per line.
        splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
        with open(split_f) as f:
            file_names = [x.strip() for x in f]

        image_dir = os.path.join(voc_root, "JPEGImages")
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]

        target_dir = os.path.join(voc_root, self._TARGET_DIR)
        self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]

        # Internal invariant: both lists are derived from the same names.
        assert len(self.images) == len(self.targets)

    def __len__(self) -> int:
        """Return the number of samples listed in the selected split."""
        return len(self.images)


class VOCSegmentation(_VOCBase):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
            ``year=="2007"``, can also be ``"test"``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    _SPLITS_DIR = "Segmentation"
    _TARGET_DIR = "SegmentationClass"
    _TARGET_FILE_EXT = ".png"

    @property
    def masks(self) -> List[str]:
        """Paths of the segmentation mask files (alias of ``self.targets``)."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert("RGB")
        # The mask is returned as opened, without any mode conversion.
        target = Image.open(self.masks[index])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target


class VOCDetection(_VOCBase):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
            ``year=="2007"``, can also be ``"test"``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    _SPLITS_DIR = "Main"
    _TARGET_DIR = "Annotations"
    _TARGET_FILE_EXT = ".xml"

    @property
    def annotations(self) -> List[str]:
        """Paths of the XML annotation files (alias of ``self.targets``)."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        img = Image.open(self.images[index]).convert("RGB")
        target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    @staticmethod
    def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
        """Recursively convert an XML element into nested dictionaries.

        Args:
            node: The element to convert.

        Returns:
            A dictionary with at most one key, ``node.tag``:

            * For an element with children, the value is a dictionary mapping
              each child tag to its converted value. Child tags that occur
              more than once map to a list of values; a tag that occurs once
              maps to the single value. Under an ``annotation`` element the
              ``"object"`` entry is always a list, even for zero or one object.
            * For a leaf element with non-empty ``text``, the value is the
              stripped text.
            * For a leaf element whose ``text`` is ``None`` or empty, the
              returned dictionary is empty.
        """
        voc_dict: Dict[str, Any] = {}
        children = list(node)
        if children:
            # Group the converted children by tag so repeated tags accumulate.
            def_dic: Dict[str, Any] = collections.defaultdict(list)
            for dc in map(VOCDetection.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            if node.tag == "annotation":
                # Wrapping in an outer list makes the length-1 unwrapping
                # below yield the full object list regardless of its size.
                def_dic["object"] = [def_dic["object"]]
            voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict