Skip to content

use enums in prototype datasets for demux #5189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions torchvision/prototype/datasets/_builtin/dtd.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple
Expand Down Expand Up @@ -30,6 +31,12 @@
from torchvision.prototype.features import Label


class DTDDemux(enum.IntEnum):
SPLIT = 0
JOINT_CATEGORIES = 1
IMAGES = 2


class DTD(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
Expand All @@ -54,11 +61,11 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
if path.parent.name == "labels":
if path.name == "labels_joint_anno.txt":
return 1
return DTDDemux.JOINT_CATEGORIES

return 0
return DTDDemux.SPLIT
elif path.parents[1].name == "images":
return 2
return DTDDemux.IMAGES
else:
return None

Expand Down Expand Up @@ -122,7 +129,7 @@ def _make_datapipe(
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

def _filter_images(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == 2
return self._classify_archive(data) == DTDDemux.IMAGES

def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
Expand Down
12 changes: 9 additions & 3 deletions torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import functools
import io
import pathlib
Expand All @@ -24,6 +25,11 @@
from torchvision.prototype.features import Label


class OxfordIITPetDemux(enum.IntEnum):
SPLIT_AND_CLASSIFICATION = 0
SEGMENTATIONS = 1


class OxfordIITPet(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
Expand Down Expand Up @@ -51,8 +57,8 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:

def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]:
return {
"annotations": 0,
"trimaps": 1,
"annotations": OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION,
"trimaps": OxfordIITPetDemux.SEGMENTATIONS,
}.get(pathlib.Path(data[0]).parent.name)

def _filter_images(self, data: Tuple[str, Any]) -> bool:
Expand Down Expand Up @@ -135,7 +141,7 @@ def _make_datapipe(
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
return self._classify_anns(data) == 0
return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION

def _generate_categories(self, root: pathlib.Path) -> List[str]:
config = self.default_config
Expand Down