Skip to content

Commit adf8466

Browse files
yiwen-songpmeier
andauthored
Adding fvgc_aircraft dataset (#5178)
* add fvgc_aircraft dataset * add docstring & remove useless import * resolve lint issue * address comments * adding more annotation level * nit * address comments * Apply suggestions from code review * unify format * remove useless line Co-authored-by: Philip Meier <[email protected]>
1 parent 1feb637 commit adf8466

File tree

4 files changed

+168
-0
lines changed

4 files changed

+168
-0
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
5050
FlyingChairs
5151
FlyingThings3D
5252
Food101
53+
FGVCAircraft
5354
GTSRB
5455
HD1K
5556
HMDB51

test/test_datasets.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2206,6 +2206,57 @@ def inject_fake_data(self, tmpdir: str, config):
22062206
return len(sampled_classes * n_samples_per_class)
22072207

22082208

2209+
class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
2210+
DATASET_CLASS = datasets.FGVCAircraft
2211+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
2212+
split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer")
2213+
)
2214+
2215+
def inject_fake_data(self, tmpdir: str, config):
2216+
split = config["split"]
2217+
annotation_level = config["annotation_level"]
2218+
annotation_level_to_file = {
2219+
"variant": "variants.txt",
2220+
"family": "families.txt",
2221+
"manufacturer": "manufacturers.txt",
2222+
}
2223+
2224+
root_folder = pathlib.Path(tmpdir) / "fgvc-aircraft-2013b"
2225+
data_folder = root_folder / "data"
2226+
2227+
classes = ["707-320", "Hawk T1", "Tornado"]
2228+
num_images_per_class = 5
2229+
2230+
datasets_utils.create_image_folder(
2231+
data_folder,
2232+
"images",
2233+
file_name_fn=lambda idx: f"{idx}.jpg",
2234+
num_examples=num_images_per_class * len(classes),
2235+
)
2236+
2237+
annotation_file = data_folder / annotation_level_to_file[annotation_level]
2238+
with open(annotation_file, "w") as file:
2239+
file.write("\n".join(classes))
2240+
2241+
num_samples_per_class = 4 if split == "trainval" else 2
2242+
images_classes = []
2243+
for i in range(len(classes)):
2244+
images_classes.extend(
2245+
[
2246+
f"{idx} {classes[i]}"
2247+
for idx in random.sample(
2248+
range(i * num_images_per_class, (i + 1) * num_images_per_class), num_samples_per_class
2249+
)
2250+
]
2251+
)
2252+
2253+
images_annotation_file = data_folder / f"images_{annotation_level}_{split}.txt"
2254+
with open(images_annotation_file, "w") as file:
2255+
file.write("\n".join(images_classes))
2256+
2257+
return len(classes * num_samples_per_class)
2258+
2259+
22092260
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
22102261
DATASET_CLASS = datasets.SUN397
22112262

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .dtd import DTD
1010
from .fakedata import FakeData
1111
from .fer2013 import FER2013
12+
from .fgvc_aircraft import FGVCAircraft
1213
from .flickr import Flickr8k, Flickr30k
1314
from .flowers102 import Flowers102
1415
from .folder import ImageFolder, DatasetFolder
@@ -95,4 +96,5 @@
9596
"CLEVRClassification",
9697
"OxfordIIITPet",
9798
"Country211",
99+
"FGVCAircraft",
98100
)

torchvision/datasets/fgvc_aircraft.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from typing import Any, Callable, Optional, Tuple
5+
6+
import PIL.Image
7+
8+
from .utils import download_and_extract_archive, verify_str_arg
9+
from .vision import VisionDataset
10+
11+
12+
class FGVCAircraft(VisionDataset):
13+
"""`FGVC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.
14+
15+
The dataset contains 10,200 images of aircraft, with 100 images for each of 102
16+
different aircraft model variants, most of which are airplanes.
17+
Aircraft models are organized in a three-levels hierarchy. The three levels, from
18+
finer to coarser, are:
19+
20+
- ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually
21+
indistinguishable into one class. The dataset comprises 102 different variants.
22+
- ``family``, e.g. Boeing 737. The dataset comprises 70 different families.
23+
- ``manufacturer``, e.g. Boeing. The dataset comprises 41 different manufacturers.
24+
25+
Args:
26+
root (string): Root directory of the FGVC Aircraft dataset.
27+
split (string, optional): The dataset split, supports ``train``, ``val``,
28+
``trainval`` and ``test``.
29+
download (bool, optional): If True, downloads the dataset from the internet and
30+
puts it in root directory. If dataset is already downloaded, it is not
31+
downloaded again.
32+
annotation_level (str, optional): The annotation level, supports ``variant``,
33+
``family`` and ``manufacturer``.
34+
transform (callable, optional): A function/transform that takes in an PIL image
35+
and returns a transformed version. E.g, ``transforms.RandomCrop``
36+
target_transform (callable, optional): A function/transform that takes in the
37+
target and transforms it.
38+
"""
39+
40+
_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
41+
42+
def __init__(
43+
self,
44+
root: str,
45+
split: str = "trainval",
46+
download: bool = False,
47+
annotation_level: str = "variant",
48+
transform: Optional[Callable] = None,
49+
target_transform: Optional[Callable] = None,
50+
) -> None:
51+
super().__init__(root, transform=transform, target_transform=target_transform)
52+
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
53+
self._annotation_level = verify_str_arg(
54+
annotation_level, "annotation_level", ("variant", "family", "manufacturer")
55+
)
56+
57+
self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
58+
if download:
59+
self._download()
60+
61+
if not self._check_exists():
62+
raise RuntimeError("Dataset not found. You can use download=True to download it")
63+
64+
annotation_file = os.path.join(
65+
self._data_path,
66+
"data",
67+
{
68+
"variant": "variants.txt",
69+
"family": "families.txt",
70+
"manufacturer": "manufacturers.txt",
71+
}[self._annotation_level],
72+
)
73+
with open(annotation_file, "r") as f:
74+
self.classes = [line.strip() for line in f]
75+
76+
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
77+
78+
image_data_folder = os.path.join(self._data_path, "data", "images")
79+
labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt")
80+
81+
self._image_files = []
82+
self._labels = []
83+
84+
with open(labels_file, "r") as f:
85+
for line in f:
86+
image_name, label_name = line.strip().split(" ", 1)
87+
self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
88+
self._labels.append(self.class_to_idx[label_name])
89+
90+
def __len__(self) -> int:
91+
return len(self._image_files)
92+
93+
def __getitem__(self, idx) -> Tuple[Any, Any]:
94+
image_file, label = self._image_files[idx], self._labels[idx]
95+
image = PIL.Image.open(image_file).convert("RGB")
96+
97+
if self.transform:
98+
image = self.transform(image)
99+
100+
if self.target_transform:
101+
label = self.target_transform(label)
102+
103+
return image, label
104+
105+
def _download(self) -> None:
106+
"""
107+
Download the FGVC Aircraft dataset archive and extract it under root.
108+
"""
109+
if self._check_exists():
110+
return
111+
download_and_extract_archive(self._URL, self.root)
112+
113+
def _check_exists(self) -> bool:
114+
return os.path.exists(self._data_path) and os.path.isdir(self._data_path)

0 commit comments

Comments
 (0)