Skip to content

Commit 055708d

Browse files
authored
add prototypes for Caltech(101|256) datasets (#4510)
* add prototype for `Caltech256` dataset * silence mypy
1 parent 4bf6086 commit 055708d

File tree

9 files changed

+606
-6
lines changed

9 files changed

+606
-6
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ def run(self):
495495
# Package info
496496
packages=find_packages(exclude=('test',)),
497497
package_data={
498-
package_name: ['*.dll', '*.dylib', '*.so']
498+
package_name: ['*.dll', '*.dylib', '*.so', '*.categories']
499499
},
500500
zip_safe=False,
501501
install_requires=requirements,

torchvision/prototype/datasets/_api.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torchvision.prototype.datasets.decoder import pil
99
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
1010
from torchvision.prototype.datasets.utils._internal import add_suggestion
11-
11+
from . import _builtin
1212

1313
DATASETS: Dict[str, Dataset] = {}
1414

@@ -17,6 +17,16 @@ def register(dataset: Dataset) -> None:
1717
DATASETS[dataset.name] = dataset
1818

1919

20+
# Instantiate and register every public ``Dataset`` subclass exposed by ``_builtin``.
for name, obj in vars(_builtin).items():
    # Names with a leading underscore are treated as private and ignored.
    if name.startswith("_"):
        continue
    # Only classes deriving from ``Dataset`` qualify; the base class itself is skipped.
    if not isinstance(obj, type) or not issubclass(obj, Dataset) or obj is Dataset:
        continue
    register(obj())
28+
29+
2030
# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
2131
def _list() -> List[str]:
    """Return the names of all registered datasets in sorted order."""
    names = sorted(DATASETS)
    return names
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .caltech import Caltech101, Caltech256
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
import io
2+
import pathlib
3+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4+
import re
5+
6+
import numpy as np
7+
8+
import torch
9+
from torch.utils.data import IterDataPipe
10+
from torch.utils.data.datapipes.iter import (
11+
Mapper,
12+
TarArchiveReader,
13+
Shuffler,
14+
Filter,
15+
)
16+
17+
from torchdata.datapipes.iter import KeyZipper
18+
from torchvision.prototype.datasets.utils import (
19+
Dataset,
20+
DatasetConfig,
21+
DatasetInfo,
22+
HttpResource,
23+
OnlineResource,
24+
)
25+
from torchvision.prototype.datasets.utils._internal import create_categories_file, INFINITE_BUFFER_SIZE, read_mat
26+
27+
# Directory containing this module; the ``*.categories`` files are looked up relative to it.
HERE = pathlib.Path(__file__).parent
28+
29+
30+
class Caltech101(Dataset):
    """Prototype Caltech 101 dataset that pairs every image with its annotation file.

    Images and annotations ship in two separate archives. They are joined on a
    ``(category, id)`` key built from the parent folder name and the number
    embedded in the file name.
    """

    # File name layouts inside the two archives; the ``id`` group is one half of the join key.
    _IMAGES_NAME_PATTERN = re.compile(r"image_(?P<id>\d+)[.]jpg")
    _ANNS_NAME_PATTERN = re.compile(r"annotation_(?P<id>\d+)[.]mat")
    # Annotation folder names that have to be translated before they are used as join key.
    _ANNS_CATEGORY_MAP = {
        "Faces_2": "Faces",
        "Faces_3": "Faces_easy",
        "Motorbikes_16": "Motorbikes",
        "Airplanes_Side_2": "airplanes",
    }

    @property
    def info(self) -> DatasetInfo:
        """Static metadata: dataset name, categories file and homepage."""
        categories_file = HERE / "caltech101.categories"
        return DatasetInfo(
            "caltech101",
            categories=categories_file,
            homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the image archive followed by the annotation archive."""
        return [
            HttpResource(
                "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
                sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
            ),
            HttpResource(
                "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
                sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8",
            ),
        ]

    def _is_not_background_image(self, data: Tuple[str, Any]) -> bool:
        """Tell whether the file lives outside of the ``BACKGROUND_Google`` folder."""
        folder = pathlib.Path(data[0]).parent.name
        return folder != "BACKGROUND_Google"

    def _is_ann(self, data: Tuple[str, Any]) -> bool:
        """Tell whether the file name starts like an annotation file name."""
        file_name = pathlib.Path(data[0]).name
        return self._ANNS_NAME_PATTERN.match(file_name) is not None

    def _images_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
        """Build the ``(category, id)`` join key of an image file."""
        image_path = pathlib.Path(data[0])
        match = self._IMAGES_NAME_PATTERN.match(image_path.name)
        image_id = match.group("id")  # type: ignore[union-attr]
        return image_path.parent.name, image_id

    def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
        """Build the ``(category, id)`` join key of an annotation file."""
        ann_path = pathlib.Path(data[0])
        folder = ann_path.parent.name
        # Folders without an entry in the map are used as category unchanged.
        category = self._ANNS_CATEGORY_MAP.get(folder, folder)
        match = self._ANNS_NAME_PATTERN.match(ann_path.name)
        ann_id = match.group("id")  # type: ignore[union-attr]
        return category, ann_id

    def _collate_and_decode_sample(
        self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
    ) -> Dict[str, Any]:
        """Merge one joined ``(key, image, annotation)`` triple into a sample dictionary.

        The image buffer is passed through ``decoder`` when one is given and is
        returned untouched otherwise. ``label`` is the position of the category
        in ``self.info.categories``.
        """
        (category, _), (image_path, image_buffer), (ann_path, ann_buffer) = data

        label = self.info.categories.index(category)

        if decoder:
            image = decoder(image_buffer)
        else:
            image = image_buffer

        ann = read_mat(ann_buffer)
        return {
            "category": category,
            "label": label,
            "image": image,
            "image_path": image_path,
            "bbox": torch.as_tensor(ann["box_coord"].astype(np.int64)),
            "contour": torch.as_tensor(ann["obj_contour"]),
            "ann_path": ann_path,
        }

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Assemble the sample datapipe from the image and annotation archive datapipes."""
        images_archive, anns_archive = resource_dps

        images = Filter(TarArchiveReader(images_archive), self._is_not_background_image)
        # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved
        # images = Shuffler(images, buffer_size=INFINITE_BUFFER_SIZE)

        anns = Filter(TarArchiveReader(anns_archive), self._is_ann)

        joined = KeyZipper(
            images,
            anns,
            key_fn=self._images_key_fn,
            ref_key_fn=self._anns_key_fn,
            buffer_size=INFINITE_BUFFER_SIZE,
            keep_key=True,
        )
        return Mapper(joined, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
        """Write the categories file from the folder names found in the image archive.

        Files inside of ``BACKGROUND_Google`` do not contribute a category.
        """
        images_resource = self.resources(self.default_config)[0]
        archive = images_resource.to_datapipe(pathlib.Path(root) / self.name)
        files = Filter(TarArchiveReader(archive), self._is_not_background_image)

        categories = set()
        for file_path, _ in files:
            categories.add(pathlib.Path(file_path).parent.name)
        create_categories_file(HERE, self.name, sorted(categories))
145+
146+
147+
class Caltech256(Dataset):
    """Prototype Caltech 256 dataset built from a single image archive.

    Category and label of a sample are parsed from the name of the folder that
    contains the image, which is expected to look like ``<number>.<category>``.
    """

    @property
    def info(self) -> DatasetInfo:
        """Static metadata: dataset name, categories file and homepage."""
        return DatasetInfo(
            "caltech256",
            categories=HERE / "caltech256.categories",
            homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the image archive as the only resource."""
        return [
            HttpResource(
                "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
                sha256="08ff01b03c65566014ae88eb0490dbe4419fc7ac4de726ee1163e39fd809543e",
            )
        ]

    def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
        """Tell whether the archive member is anything but the stray ``RENAME2`` file."""
        path = pathlib.Path(data[0])
        return path.name != "RENAME2"

    def _collate_and_decode_sample(
        self,
        data: Tuple[str, io.IOBase],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        """Turn one ``(path, buffer)`` pair of the archive into a sample dictionary.

        Args:
            data: Path of the image inside the archive and its file buffer.
            decoder: Optional callable applied to the buffer. Without it the raw
                buffer is returned as ``image``.

        Returns:
            Dictionary with the keys ``label``, ``category`` and ``image``.
            ``label`` is a zero-based tensor, i.e. the numeric folder prefix minus one.

        Raises:
            ValueError: If the folder name has no ``.`` separator or its prefix
                is not an integer.
        """
        path, buffer = data

        dir_name = pathlib.Path(path).parent.name
        # Split only once so that the parsing matches generate_categories_file().
        label_str, category = dir_name.split(".", 1)
        # The folder prefix is one-based while the categories file written by
        # generate_categories_file() is a zero-based list. Shift the prefix so that
        # the label can be used as index into that list, like the label of Caltech101.
        # NOTE(review): assumes the prefixes are contiguous and start at 1 - confirm
        # against the archive.
        label = torch.tensor(int(label_str) - 1)

        return dict(label=label, category=category, image=decoder(buffer) if decoder else buffer)

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Assemble the sample datapipe from the image archive datapipe."""
        dp = resource_dps[0]
        dp = TarArchiveReader(dp)
        dp = Filter(dp, self._is_not_rogue_file)
        # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved
        # dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
        """Write the categories file from the folder names found in the image archive.

        The folder names are sorted as strings and stripped of their numeric prefix.
        """
        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
        dp = TarArchiveReader(dp)
        dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
        # NOTE(review): string sorting only agrees with the numeric order of the
        # prefixes if they are zero-padded to a common width - confirm against the archive.
        categories = [name.split(".", 1)[1] for name in sorted(dir_names)]
        create_categories_file(HERE, self.name, categories)
202+
203+
204+
if __name__ == "__main__":
    from torchvision.prototype.datasets import home

    # Regenerate the categories files of both datasets below the datasets home directory.
    root = home()
    for dataset_cls in (Caltech101, Caltech256):
        dataset_cls().generate_categories_file(root)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
Faces
2+
Faces_easy
3+
Leopards
4+
Motorbikes
5+
accordion
6+
airplanes
7+
anchor
8+
ant
9+
barrel
10+
bass
11+
beaver
12+
binocular
13+
bonsai
14+
brain
15+
brontosaurus
16+
buddha
17+
butterfly
18+
camera
19+
cannon
20+
car_side
21+
ceiling_fan
22+
cellphone
23+
chair
24+
chandelier
25+
cougar_body
26+
cougar_face
27+
crab
28+
crayfish
29+
crocodile
30+
crocodile_head
31+
cup
32+
dalmatian
33+
dollar_bill
34+
dolphin
35+
dragonfly
36+
electric_guitar
37+
elephant
38+
emu
39+
euphonium
40+
ewer
41+
ferry
42+
flamingo
43+
flamingo_head
44+
garfield
45+
gerenuk
46+
gramophone
47+
grand_piano
48+
hawksbill
49+
headphone
50+
hedgehog
51+
helicopter
52+
ibis
53+
inline_skate
54+
joshua_tree
55+
kangaroo
56+
ketch
57+
lamp
58+
laptop
59+
llama
60+
lobster
61+
lotus
62+
mandolin
63+
mayfly
64+
menorah
65+
metronome
66+
minaret
67+
nautilus
68+
octopus
69+
okapi
70+
pagoda
71+
panda
72+
pigeon
73+
pizza
74+
platypus
75+
pyramid
76+
revolver
77+
rhino
78+
rooster
79+
saxophone
80+
schooner
81+
scissors
82+
scorpion
83+
sea_horse
84+
snoopy
85+
soccer_ball
86+
stapler
87+
starfish
88+
stegosaurus
89+
stop_sign
90+
strawberry
91+
sunflower
92+
tick
93+
trilobite
94+
umbrella
95+
watch
96+
water_lilly
97+
wheelchair
98+
wild_cat
99+
windsor_chair
100+
wrench
101+
yin_yang

0 commit comments

Comments
 (0)