Skip to content

Commit 9e5556c

Browse files
committed
use enums in prototype datasets for demux
1 parent 68f511e commit 9e5556c

File tree

7 files changed

+64
-17
lines changed

7 files changed

+64
-17
lines changed

torchvision/prototype/datasets/_builtin/clevr.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import functools
23
import io
34
import pathlib
@@ -24,6 +25,11 @@
2425
from torchvision.prototype.features import Label
2526

2627

28+
class CLEVRDemux(enum.IntEnum):
29+
IMAGES = 0
30+
SCENES = 1
31+
32+
2733
class CLEVR(Dataset):
2834
def _make_info(self) -> DatasetInfo:
2935
return DatasetInfo(
@@ -43,9 +49,9 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
4349
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
4450
path = pathlib.Path(data[0])
4551
if path.parents[1].name == "images":
46-
return 0
52+
return CLEVRDemux.IMAGES
4753
elif path.parent.name == "scenes":
48-
return 1
54+
return CLEVRDemux.SCENES
4955
else:
5056
return None
5157

torchvision/prototype/datasets/_builtin/coco.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import functools
23
import io
34
import pathlib
@@ -37,6 +38,11 @@
3738
from torchvision.prototype.utils._internal import FrozenMapping
3839

3940

41+
class CocoDemux(enum.IntEnum):
42+
IMAGES_META = 0
43+
ANNS_META = 1
44+
45+
4046
class Coco(Dataset):
4147
def _make_info(self) -> DatasetInfo:
4248
name = "coco"
@@ -144,9 +150,9 @@ def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, an
144150
def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
145151
key, _ = data
146152
if key == "images":
147-
return 0
153+
return CocoDemux.IMAGES_META
148154
elif key == "annotations":
149-
return 1
155+
return CocoDemux.ANNS_META
150156
else:
151157
return None
152158

torchvision/prototype/datasets/_builtin/dtd.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import io
23
import pathlib
34
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -30,6 +31,12 @@
3031
from torchvision.prototype.features import Label
3132

3233

34+
class DTDDemux(enum.IntEnum):
35+
SPLIT = 0
36+
JOINT_CATEGORIES = 1
37+
IMAGES = 2
38+
39+
3340
class DTD(Dataset):
3441
def _make_info(self) -> DatasetInfo:
3542
return DatasetInfo(
@@ -54,11 +61,11 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
5461
path = pathlib.Path(data[0])
5562
if path.parent.name == "labels":
5663
if path.name == "labels_joint_anno.txt":
57-
return 1
64+
return DTDDemux.JOINT_CATEGORIES
5865

59-
return 0
66+
return DTDDemux.SPLIT
6067
elif path.parents[1].name == "images":
61-
return 2
68+
return DTDDemux.IMAGES
6269
else:
6370
return None
6471

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import abc
2+
import enum
23
import functools
34
import io
45
import operator
@@ -232,6 +233,11 @@ def _make_info(self) -> DatasetInfo:
232233
}
233234

234235

236+
class EMNISTDemux(enum.IntEnum):
237+
IMAGES = 0
238+
LABELS = 1
239+
240+
235241
class EMNIST(_MNISTBase):
236242
def _make_info(self) -> DatasetInfo:
237243
return DatasetInfo(
@@ -273,9 +279,9 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) ->
273279
path = pathlib.Path(data[0])
274280
(images_file, _), (labels_file, _) = self._files_and_checksums(config)
275281
if path.name == images_file:
276-
return 0
282+
return EMNISTDemux.IMAGES
277283
elif path.name == labels_file:
278-
return 1
284+
return EMNISTDemux.LABELS
279285
else:
280286
return None
281287

@@ -320,6 +326,7 @@ def _make_datapipe(
320326
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
321327
) -> IterDataPipe[Dict[str, Any]]:
322328
archive_dp = resource_dps[0]
329+
323330
images_dp, labels_dp = Demultiplexer(
324331
archive_dp,
325332
2,

torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import functools
23
import io
34
import pathlib
@@ -24,6 +25,11 @@
2425
from torchvision.prototype.features import Label
2526

2627

28+
class OxfordIITPetDemux(enum.IntEnum):
29+
SPLIT_AND_CLASSIFICATION = 0
30+
SEGMENTATIONS = 1
31+
32+
2733
class OxfordIITPet(Dataset):
2834
def _make_info(self) -> DatasetInfo:
2935
return DatasetInfo(
@@ -51,8 +57,8 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
5157

5258
def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]:
5359
return {
54-
"annotations": 0,
55-
"trimaps": 1,
60+
"annotations": OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION,
61+
"trimaps": OxfordIITPetDemux.SEGMENTATIONS,
5662
}.get(pathlib.Path(data[0]).parent.name)
5763

5864
def _filter_images(self, data: Tuple[str, Any]) -> bool:

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import functools
23
import io
34
import pathlib
@@ -33,6 +34,12 @@
3334
)
3435

3536

37+
class SBDDemux(enum.IntEnum):
38+
SPLIT = 0
39+
IMAGES = 1
40+
ANNS = 2
41+
42+
3643
class SBD(Dataset):
3744
def _make_info(self) -> DatasetInfo:
3845
return DatasetInfo(
@@ -63,12 +70,12 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
6370
parent, grandparent, *_ = path.parents
6471

6572
if parent.name == "dataset":
66-
return 0
73+
return SBDDemux.SPLIT
6774
elif grandparent.name == "dataset":
6875
if parent.name == "img":
69-
return 1
76+
return SBDDemux.IMAGES
7077
elif parent.name == "cls":
71-
return 2
78+
return SBDDemux.ANNS
7279
else:
7380
return None
7481
else:

torchvision/prototype/datasets/_builtin/voc.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import functools
23
import io
34
import pathlib
@@ -31,6 +32,13 @@
3132
hint_shuffling,
3233
)
3334

35+
36+
class VOCDemux(enum.IntEnum):
37+
SPLIT = 0
38+
IMAGES = 1
39+
ANNS = 2
40+
41+
3442
HERE = pathlib.Path(__file__).parent
3543

3644

@@ -75,11 +83,11 @@ def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) ->
7583

7684
def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]:
7785
if self._is_in_folder(data, name="ImageSets", depth=2):
78-
return 0
86+
return VOCDemux.SPLIT
7987
elif self._is_in_folder(data, name="JPEGImages"):
80-
return 1
88+
return VOCDemux.IMAGES
8189
elif self._is_in_folder(data, name=self._ANNS_FOLDER[config.task]):
82-
return 2
90+
return VOCDemux.ANNS
8391
else:
8492
return None
8593

0 commit comments

Comments
 (0)