Skip to content

Commit 7da9afe

Browse files
prabhat00155fmassa
andauthored
Added KITTI dataset (#3640)
* Added KITTI dataset * Addressed review comments * Changed type of target to List[Dict] and corrected the data types of the returned values. * Updated unit test to rely on ImageDatasetTestCase * Added kitti to dataset documentation * Cleaned up test and some minor changes * Made data_url a string instead of a list * Removed unnecessary try and print Co-authored-by: Francisco Massa <[email protected]>
1 parent 1d0b43e commit 7da9afe

File tree

4 files changed

+207
-1
lines changed

4 files changed

+207
-1
lines changed

docs/source/datasets.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,13 @@ Kinetics-400
149149
:members: __getitem__
150150
:special-members:
151151

152+
KITTI
153+
~~~~~~~~~
154+
155+
.. autoclass:: Kitti
156+
:members: __getitem__
157+
:special-members:
158+
152159
KMNIST
153160
~~~~~~~~~~~~~
154161

test/test_datasets.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,5 +1702,41 @@ def test_classes(self, config):
17021702
self.assertSequenceEqual(dataset.classes, info["classes"])
17031703

17041704

1705+
class KittiTestCase(datasets_utils.ImageDatasetTestCase):
1706+
DATASET_CLASS = datasets.Kitti
1707+
FEATURE_TYPES = (PIL.Image.Image, (list, type(None))) # test split returns None as target
1708+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
1709+
1710+
def inject_fake_data(self, tmpdir, config):
1711+
kitti_dir = os.path.join(tmpdir, "Kitti", "raw")
1712+
os.makedirs(kitti_dir)
1713+
1714+
split_to_num_examples = {
1715+
True: 1,
1716+
False: 2,
1717+
}
1718+
1719+
# We need to create all folders(training and testing).
1720+
for is_training in (True, False):
1721+
num_examples = split_to_num_examples[is_training]
1722+
1723+
datasets_utils.create_image_folder(
1724+
root=kitti_dir,
1725+
name=os.path.join("training" if is_training else "testing", "image_2"),
1726+
file_name_fn=lambda image_idx: f"{image_idx:06d}.png",
1727+
num_examples=num_examples,
1728+
)
1729+
if is_training:
1730+
for image_idx in range(num_examples):
1731+
target_file_dir = os.path.join(kitti_dir, "training", "label_2")
1732+
os.makedirs(target_file_dir)
1733+
target_file_name = os.path.join(target_file_dir, f"{image_idx:06d}.txt")
1734+
target_contents = "Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48 1.20 1.84 1.47 8.41 0.01\n" # noqa
1735+
with open(target_file_name, "w") as target_file:
1736+
target_file.write(target_contents)
1737+
1738+
return split_to_num_examples[config["train"]]
1739+
1740+
17051741
if __name__ == "__main__":
17061742
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .hmdb51 import HMDB51
2525
from .ucf101 import UCF101
2626
from .places365 import Places365
27+
from .kitti import Kitti
2728

2829
__all__ = ('LSUN', 'LSUNClass',
2930
'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -34,4 +35,5 @@
3435
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
3536
'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
3637
'VisionDataset', 'USPS', 'Kinetics400', 'HMDB51', 'UCF101',
37-
'Places365')
38+
'Places365', 'Kitti',
39+
)

torchvision/datasets/kitti.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import csv
2+
import os
3+
from typing import Any, Callable, List, Optional, Tuple
4+
5+
from PIL import Image
6+
7+
from .utils import download_and_extract_archive
8+
from .vision import VisionDataset
9+
10+
11+
class Kitti(VisionDataset):
12+
"""`KITTI <http://www.cvlibs.net/datasets/kitti>`_ Dataset.
13+
14+
Args:
15+
root (string): Root directory where images are downloaded to.
16+
Expects the following folder structure if download=False:
17+
18+
.. code::
19+
20+
<root>
21+
└── Kitti
22+
└─ raw
23+
├── training
24+
| ├── image_2
25+
| └── label_2
26+
└── testing
27+
└── image_2
28+
train (bool, optional): Use ``train`` split if true, else ``test`` split.
29+
Defaults to ``train``.
30+
transform (callable, optional): A function/transform that takes in a PIL image
31+
and returns a transformed version. E.g, ``transforms.ToTensor``
32+
target_transform (callable, optional): A function/transform that takes in the
33+
target and transforms it.
34+
transforms (callable, optional): A function/transform that takes input sample
35+
and its target as entry and returns a transformed version.
36+
download (bool, optional): If true, downloads the dataset from the internet and
37+
puts it in root directory. If dataset is already downloaded, it is not
38+
downloaded again.
39+
40+
"""
41+
42+
data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
43+
resources = [
44+
"data_object_image_2.zip",
45+
"data_object_label_2.zip",
46+
]
47+
image_dir_name = "image_2"
48+
labels_dir_name = "label_2"
49+
50+
def __init__(
51+
self,
52+
root: str,
53+
train: bool = True,
54+
transform: Optional[Callable] = None,
55+
target_transform: Optional[Callable] = None,
56+
transforms: Optional[Callable] = None,
57+
download: bool = False,
58+
):
59+
super().__init__(
60+
root,
61+
transform=transform,
62+
target_transform=target_transform,
63+
transforms=transforms,
64+
)
65+
self.images = []
66+
self.targets = []
67+
self.root = root
68+
self.train = train
69+
self._location = "training" if self.train else "testing"
70+
71+
if download:
72+
self.download()
73+
if not self._check_exists():
74+
raise RuntimeError(
75+
"Dataset not found. You may use download=True to download it."
76+
)
77+
78+
image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
79+
if self.train:
80+
labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name)
81+
for img_file in os.listdir(image_dir):
82+
self.images.append(os.path.join(image_dir, img_file))
83+
if self.train:
84+
self.targets.append(
85+
os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt")
86+
)
87+
88+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
89+
"""Get item at a given index.
90+
91+
Args:
92+
index (int): Index
93+
Returns:
94+
tuple: (image, target), where
95+
target is a list of dictionaries with the following keys:
96+
97+
- type: str
98+
- truncated: float
99+
- occluded: int
100+
- alpha: float
101+
- bbox: float[4]
102+
- dimensions: float[3]
103+
- locations: float[3]
104+
- rotation_y: float
105+
106+
"""
107+
image = Image.open(self.images[index])
108+
target = self._parse_target(index) if self.train else None
109+
if self.transforms:
110+
image, target = self.transforms(image, target)
111+
return image, target
112+
113+
def _parse_target(self, index: int) -> List:
114+
target = []
115+
with open(self.targets[index]) as inp:
116+
content = csv.reader(inp, delimiter=" ")
117+
for line in content:
118+
target.append({
119+
"type": line[0],
120+
"truncated": float(line[1]),
121+
"occluded": int(line[2]),
122+
"alpha": float(line[3]),
123+
"bbox": [float(x) for x in line[4:8]],
124+
"dimensions": [float(x) for x in line[8:11]],
125+
"location": [float(x) for x in line[11:14]],
126+
"rotation_y": float(line[14]),
127+
})
128+
return target
129+
130+
def __len__(self) -> int:
131+
return len(self.images)
132+
133+
@property
134+
def _raw_folder(self) -> str:
135+
return os.path.join(self.root, self.__class__.__name__, "raw")
136+
137+
def _check_exists(self) -> bool:
138+
"""Check if the data directory exists."""
139+
folders = [self.image_dir_name]
140+
if self.train:
141+
folders.append(self.labels_dir_name)
142+
return all(
143+
os.path.isdir(os.path.join(self._raw_folder, self._location, fname))
144+
for fname in folders
145+
)
146+
147+
def download(self) -> None:
148+
"""Download the KITTI data if it doesn't exist already."""
149+
150+
if self._check_exists():
151+
return
152+
153+
os.makedirs(self._raw_folder, exist_ok=True)
154+
155+
# download files
156+
for fname in self.resources:
157+
download_and_extract_archive(
158+
url=f"{self.data_url}{fname}",
159+
download_root=self._raw_folder,
160+
filename=fname,
161+
)

0 commit comments

Comments
 (0)