Skip to content

Commit d743c79

Browse files
Vincent Moensfacebook-github-bot
Vincent Moens
authored andcommitted
[fbsync] use enums in prototype datasets for demux (#5189)
Summary: * use enums in prototype datasets for demux * use enum for category generation * revert enum usage for single use constants Reviewed By: NicolasHug Differential Revision: D33618173 fbshipit-source-id: a4ab9349905806f2cd0c701c4b59bc1ab0ad14ae
1 parent cf9ee41 commit d743c79

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

torchvision/prototype/datasets/_builtin/dtd.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import io
23
import pathlib
34
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -30,6 +31,12 @@
3031
from torchvision.prototype.features import Label
3132

3233

34+
class DTDDemux(enum.IntEnum):
35+
SPLIT = 0
36+
JOINT_CATEGORIES = 1
37+
IMAGES = 2
38+
39+
3340
class DTD(Dataset):
3441
def _make_info(self) -> DatasetInfo:
3542
return DatasetInfo(
@@ -54,11 +61,11 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
5461
path = pathlib.Path(data[0])
5562
if path.parent.name == "labels":
5663
if path.name == "labels_joint_anno.txt":
57-
return 1
64+
return DTDDemux.JOINT_CATEGORIES
5865

59-
return 0
66+
return DTDDemux.SPLIT
6067
elif path.parents[1].name == "images":
61-
return 2
68+
return DTDDemux.IMAGES
6269
else:
6370
return None
6471

@@ -122,7 +129,7 @@ def _make_datapipe(
122129
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
123130

124131
def _filter_images(self, data: Tuple[str, Any]) -> bool:
125-
return self._classify_archive(data) == 2
132+
return self._classify_archive(data) == DTDDemux.IMAGES
126133

127134
def _generate_categories(self, root: pathlib.Path) -> List[str]:
128135
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)

torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import functools
23
import io
34
import pathlib
@@ -24,6 +25,11 @@
2425
from torchvision.prototype.features import Label
2526

2627

28+
class OxfordIITPetDemux(enum.IntEnum):
29+
SPLIT_AND_CLASSIFICATION = 0
30+
SEGMENTATIONS = 1
31+
32+
2733
class OxfordIITPet(Dataset):
2834
def _make_info(self) -> DatasetInfo:
2935
return DatasetInfo(
@@ -51,8 +57,8 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
5157

5258
def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]:
5359
return {
54-
"annotations": 0,
55-
"trimaps": 1,
60+
"annotations": OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION,
61+
"trimaps": OxfordIITPetDemux.SEGMENTATIONS,
5662
}.get(pathlib.Path(data[0]).parent.name)
5763

5864
def _filter_images(self, data: Tuple[str, Any]) -> bool:
@@ -135,7 +141,7 @@ def _make_datapipe(
135141
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
136142

137143
def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
138-
return self._classify_anns(data) == 0
144+
return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION
139145

140146
def _generate_categories(self, root: pathlib.Path) -> List[str]:
141147
config = self.default_config

0 commit comments

Comments
 (0)