You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

159 lines
5.5 KiB

import csv
import os
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
from PIL import Image
from .utils import download_and_extract_archive
from .vision import VisionDataset
class Kitti(VisionDataset):
    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark>`_ Dataset.

    It corresponds to the "left color images of object" dataset, for object detection.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
            Expects the following folder structure if download=False:

            .. code::

                <root>
                    └── Kitti
                        └── raw
                            ├── training
                            |   ├── image_2
                            |   └── label_2
                            └── testing
                                └── image_2
        train (bool, optional): Use ``train`` split if true, else ``test`` split.
            Defaults to ``train``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample
            and its target as entry and returns a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
    resources = [
        "data_object_image_2.zip",
        "data_object_label_2.zip",
    ]
    image_dir_name = "image_2"
    labels_dir_name = "label_2"

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        download: bool = False,
    ):
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
            transforms=transforms,
        )
        # Absolute paths of the image files, and (train split only) of the matching
        # label files; both lists are index-aligned.
        self.images: List[str] = []
        self.targets: List[str] = []
        self.train = train
        self._location = "training" if self.train else "testing"

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found. You may use download=True to download it.")

        image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
        labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name)
        # os.listdir returns entries in an arbitrary, filesystem-dependent order; sort
        # so that a given index refers to the same sample on every machine and run.
        for img_file in sorted(os.listdir(image_dir)):
            self.images.append(os.path.join(image_dir, img_file))
            if self.train:
                # The label file shares the image's stem: <stem>.<ext> -> <stem>.txt.
                # splitext only strips the final extension, unlike split(".")[0].
                stem = os.path.splitext(img_file)[0]
                self.targets.append(os.path.join(labels_dir, f"{stem}.txt"))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Get item at a given index.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target), where
            target is a list of dictionaries with the following keys:

            - type: str
            - truncated: float
            - occluded: int
            - alpha: float
            - bbox: float[4]
            - dimensions: float[3]
            - location: float[3]
            - rotation_y: float

            For the test split (``train=False``) no labels are read and target is ``None``.
        """
        image = Image.open(self.images[index])
        target = self._parse_target(index) if self.train else None
        if self.transforms:
            image, target = self.transforms(image, target)
        return image, target

    def _parse_target(self, index: int) -> List:
        """Parse the label file for sample ``index`` into a list of per-object dicts.

        Each non-blank, space-separated line of the label file becomes one dict;
        columns are mapped positionally as shown below.
        """
        target = []
        # The csv module documentation requires newline="" for files given to csv.reader.
        with open(self.targets[index], newline="") as inp:
            content = csv.reader(inp, delimiter=" ")
            for line in content:
                # csv.reader yields [] for a blank line (and only empty fields for a
                # whitespace-only line); indexing such a row would raise, so skip it.
                if not any(line):
                    continue
                target.append(
                    {
                        "type": line[0],
                        "truncated": float(line[1]),
                        "occluded": int(line[2]),
                        "alpha": float(line[3]),
                        "bbox": [float(x) for x in line[4:8]],
                        "dimensions": [float(x) for x in line[8:11]],
                        "location": [float(x) for x in line[11:14]],
                        "rotation_y": float(line[14]),
                    }
                )
        return target

    def __len__(self) -> int:
        """Return the number of images in the selected split."""
        return len(self.images)

    @property
    def _raw_folder(self) -> str:
        """Directory holding the extracted archives: ``<root>/<ClassName>/raw``."""
        return os.path.join(self.root, self.__class__.__name__, "raw")

    def _check_exists(self) -> bool:
        """Check if the data directory exists."""
        folders = [self.image_dir_name]
        if self.train:
            # Labels are only required (and only read) for the training split.
            folders.append(self.labels_dir_name)
        return all(
            os.path.isdir(os.path.join(self._raw_folder, self._location, fname))
            for fname in folders
        )

    def download(self) -> None:
        """Download the KITTI data if it doesn't exist already."""
        if self._check_exists():
            return

        os.makedirs(self._raw_folder, exist_ok=True)

        # download files
        for fname in self.resources:
            download_and_extract_archive(
                url=f"{self.data_url}{fname}",
                download_root=self._raw_folder,
                filename=fname,
            )