You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

225 lines
8.6 KiB

import collections
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from xml.etree.ElementTree import Element as ET_Element
try:
from defusedxml.ElementTree import parse as ET_parse
except ImportError:
from xml.etree.ElementTree import parse as ET_parse
from PIL import Image
from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
# Archive metadata for each supported Pascal VOC release, keyed by year
# (plus "2007-test" for the separately distributed 2007 test split).
#
# Per entry:
#   url      - location of the tar archive.
#   filename - local name the archive is saved under; always the URL basename
#              so archives of different years never overwrite each other.
#   md5      - checksum handed to the download helper.
#   base_dir - directory, relative to the dataset root, that is joined to the
#              root to locate the extracted release.
DATASET_YEAR_DICT = {
    "2012": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
        "filename": "VOCtrainval_11-May-2012.tar",
        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
        "base_dir": os.path.join("VOCdevkit", "VOC2012"),
    },
    "2011": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
        "filename": "VOCtrainval_25-May-2011.tar",
        "md5": "6c3384ef61512963050cb5d687e5bf1e",
        # Unlike the other years, this entry has an extra leading "TrainVal" component.
        "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
    },
    "2010": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
        "filename": "VOCtrainval_03-May-2010.tar",
        "md5": "da459979d0c395079b5c75ee67908abb",
        "base_dir": os.path.join("VOCdevkit", "VOC2010"),
    },
    "2009": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
        "filename": "VOCtrainval_11-May-2009.tar",
        "md5": "a3e00b113cfcfebf17e343f59da3caa1",
        "base_dir": os.path.join("VOCdevkit", "VOC2009"),
    },
    "2008": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
        # Was "VOCtrainval_11-May-2012.tar" (copy/paste from the 2012 entry),
        # which made the 2008 and 2012 archives share one local file name.
        "filename": "VOCtrainval_14-Jul-2008.tar",
        "md5": "2629fa636546599198acfcfbfcf1904a",
        "base_dir": os.path.join("VOCdevkit", "VOC2008"),
    },
    "2007": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
        "filename": "VOCtrainval_06-Nov-2007.tar",
        "md5": "c52e279531787c972589f7e41ab4ae64",
        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
    },
    "2007-test": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
        "filename": "VOCtest_06-Nov-2007.tar",
        "md5": "b6e924de25625d8de591ea690078ad9f",
        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
    },
}
class _VOCBase(VisionDataset):
    """Common loading logic shared by the Pascal VOC dataset classes.

    The constructor looks up the archive metadata for the requested
    ``year``/``image_set``, optionally downloads and extracts the archive, and
    then reads the split file to build two parallel lists of paths:
    ``self.images`` (JPEG files) and ``self.targets`` (one target file per
    image). Subclasses set the three class attributes below and implement
    ``__getitem__``.

    Args:
        root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, ``"2007"`` to ``"2012"``.
        image_set (string, optional): ``"train"``, ``"trainval"`` or ``"val"``;
            ``"test"`` is additionally accepted when ``year == "2007"``.
        download (bool, optional): If true, calls the download-and-extract
            helper for the selected archive before looking for the data.
        transform (callable, optional): Forwarded to ``VisionDataset``.
        target_transform (callable, optional): Forwarded to ``VisionDataset``.
        transforms (callable, optional): Forwarded to ``VisionDataset``.

    Raises:
        RuntimeError: If the extracted VOC directory for the selected year is
            not present under ``root``.
    """

    # Sub-directory of ``ImageSets`` containing the ``<image_set>.txt`` split files.
    _SPLITS_DIR: str
    # Directory, relative to the VOC root, containing the target files.
    _TARGET_DIR: str
    # Extension (including the leading dot) appended to each name to form a target file name.
    _TARGET_FILE_EXT: str

    def __init__(
        self,
        root: Union[str, Path],
        year: str = "2012",
        image_set: str = "train",
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)

        self.year = verify_str_arg(year, "year", valid_values=[str(yr) for yr in range(2007, 2013)])

        # Only the 2007 release has a "test" archive entry in DATASET_YEAR_DICT.
        valid_image_sets = ["train", "trainval", "val"]
        if year == "2007":
            valid_image_sets.append("test")
        self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)

        # The 2007 test split ships in its own archive with its own metadata.
        key = "2007-test" if year == "2007" and image_set == "test" else year
        dataset_year_dict = DATASET_YEAR_DICT[key]

        self.url = dataset_year_dict["url"]
        self.filename = dataset_year_dict["filename"]
        self.md5 = dataset_year_dict["md5"]

        base_dir = dataset_year_dict["base_dir"]
        voc_root = os.path.join(self.root, base_dir)

        if download:
            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # ``image_set`` was checked against ``valid_image_sets`` above, so it is
        # used as-is to form the split file name.
        splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
        split_f = os.path.join(splits_dir, image_set + ".txt")
        with open(split_f) as f:
            # Skip blank lines: an empty name would otherwise turn into a bogus
            # "<dir>/.jpg" entry that only fails later, inside __getitem__.
            file_names = [name for name in (line.strip() for line in f) if name]

        # Both lists derive from ``file_names``, so they always have equal length
        # and index ``i`` of one corresponds to index ``i`` of the other.
        image_dir = os.path.join(voc_root, "JPEGImages")
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]

        target_dir = os.path.join(voc_root, self._TARGET_DIR)
        self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]

    def __len__(self) -> int:
        """Return the number of entries listed in the selected split."""
        return len(self.images)
class VOCSegmentation(_VOCBase):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
            ``year=="2007"``, can also be ``"test"``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    # Split files are read from ``ImageSets/Segmentation``; each target is a
    # ``.png`` file under ``SegmentationClass``.
    _SPLITS_DIR = "Segmentation"
    _TARGET_DIR = "SegmentationClass"
    _TARGET_FILE_EXT = ".png"

    @property
    def masks(self) -> List[str]:
        """Paths of the segmentation mask files (same list as ``targets``)."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image/mask pair stored at ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        image = Image.open(self.images[index]).convert("RGB")
        mask = Image.open(self.masks[index])

        if self.transforms is None:
            return image, mask

        # Joint transform: receives and returns the (image, mask) pair together.
        image, mask = self.transforms(image, mask)
        return image, mask
class VOCDetection(_VOCBase):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
            ``year=="2007"``, can also be ``"test"``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    # Split files are read from ``ImageSets/Main``; each target is an ``.xml``
    # file under ``Annotations``.
    _SPLITS_DIR = "Main"
    _TARGET_DIR = "Annotations"
    _TARGET_FILE_EXT = ".xml"

    @property
    def annotations(self) -> List[str]:
        """Paths of the XML annotation files (same list as ``targets``)."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image and its parsed annotation stored at ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        image = Image.open(self.images[index]).convert("RGB")
        xml_root = ET_parse(self.annotations[index]).getroot()
        annotation = self.parse_voc_xml(xml_root)

        if self.transforms is None:
            return image, annotation

        # Joint transform: receives and returns the (image, annotation) pair together.
        image, annotation = self.transforms(image, annotation)
        return image, annotation

    @staticmethod
    def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
        """Recursively convert an XML element into nested dictionaries.

        Args:
            node: Element to convert.

        Returns:
            dict: ``{node.tag: value}``. For an element without children,
            ``value`` is its stripped text, and the result is an empty dict
            when the element has no text. For an element with children,
            ``value`` maps each child tag to the child's converted value; a tag
            that occurs several times maps to a list of values. Under an
            ``"annotation"`` element the ``"object"`` entry is always a list.
        """
        children = list(node)

        if not children:
            # Leaf element: only emit an entry when there is text to report.
            if node.text:
                return {node.tag: node.text.strip()}
            return {}

        # Collect the converted children per tag, keeping document order.
        grouped: Dict[str, Any] = collections.defaultdict(list)
        for child in children:
            for tag, value in VOCDetection.parse_voc_xml(child).items():
                grouped[tag].append(value)

        if node.tag == "annotation":
            # Wrapping in one more list makes the single-value unwrapping below
            # return the full object list, even when it has zero or one entries.
            grouped["object"] = [grouped["object"]]

        collapsed = {}
        for tag, values in grouped.items():
            collapsed[tag] = values[0] if len(values) == 1 else values
        return {node.tag: collapsed}