Skip to content

Replace torchvision.datasets.utils with functionality from torchdata #6060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
from .caltech import Caltech101, Caltech256
from .celeba import CelebA

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only made changes to four datasets to cover all cases, i.e.

  • HTTP download (CIFAR)
  • GDrive download (Caltech)
  • Kaggle download (FER2013)
  • Manual download (ImageNet)

The others are commented out to be able to import torchvision.

# from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .clevr import CLEVR
from .coco import Coco
from .country211 import Country211
from .cub200 import CUB200
from .dtd import DTD
from .eurosat import EuroSAT

# from .clevr import CLEVR
# from .coco import Coco
# from .country211 import Country211
# from .cub200 import CUB200
# from .dtd import DTD
# from .eurosat import EuroSAT
from .fer2013 import FER2013
from .food101 import Food101
from .gtsrb import GTSRB

# from .food101 import Food101
# from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
from .sbd import SBD
from .semeion import SEMEION
from .stanford_cars import StanfordCars
from .svhn import SVHN
from .usps import USPS
from .voc import VOC

# from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
# from .oxford_iiit_pet import OxfordIIITPet
# from .pcam import PCAM
# from .sbd import SBD
# from .semeion import SEMEION
# from .stanford_cars import StanfordCars
# from .svhn import SVHN
# from .usps import USPS
# from .voc import VOC
8 changes: 4 additions & 4 deletions torchvision/prototype/datasets/_builtin/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Filter,
IterKeyZipper,
)
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils import Dataset, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
read_mat,
Expand Down Expand Up @@ -49,13 +49,13 @@ def __init__(
)

def _resources(self) -> List[OnlineResource]:
images = GDriveResource(
images = OnlineResource.from_gdrive(
"137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
file_name="101_ObjectCategories.tar.gz",
sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
preprocess="decompress",
)
anns = GDriveResource(
anns = OnlineResource.from_gdrive(
"175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
file_name="Annotations.tar",
sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8",
Expand Down Expand Up @@ -173,7 +173,7 @@ def __init__(

def _resources(self) -> List[OnlineResource]:
return [
GDriveResource(
OnlineResource.from_gdrive(
"1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
file_name="256_ObjectCategories.tar",
sha256="08ff01b03c65566014ae88eb0490dbe4419fc7ac4de726ee1163e39fd809543e",
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Filter,
Mapper,
)
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils import Dataset, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_shuffling,
path_comparator,
Expand Down Expand Up @@ -58,7 +58,7 @@ def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]:

def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
OnlineResource.from_http(
f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}",
sha256=self._SHA256,
)
Expand Down
10 changes: 3 additions & 7 deletions torchvision/prototype/datasets/_builtin/fer2013.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@

import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
from torchvision.prototype.datasets.utils import (
Dataset,
OnlineResource,
KaggleDownloadResource,
)
from torchvision.prototype.datasets.utils import Dataset, ManualDownloadResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
Expand Down Expand Up @@ -43,8 +39,8 @@ def __init__(
"test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
}

def _resources(self) -> List[OnlineResource]:
archive = KaggleDownloadResource(
def _resources(self) -> List[ManualDownloadResource]:
archive = ManualDownloadResource.from_kaggle(
"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
file_name=f"{self._split}.csv.zip",
sha256=self._CHECKSUMS[self._split],
Expand Down
31 changes: 14 additions & 17 deletions torchvision/prototype/datasets/_builtin/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@
TarArchiveLoader,
Enumerator,
)
from torchvision.prototype.datasets.utils import (
OnlineResource,
ManualDownloadResource,
Dataset,
)
from torchvision.prototype.datasets.utils import ManualDownloadResource, Dataset
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
getitem,
Expand All @@ -41,11 +37,6 @@ def _info() -> Dict[str, Any]:
return dict(categories=categories, wnids=wnids)


class ImageNetResource(ManualDownloadResource):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of subclassing just to parametrize I followed the approach in #6052.

def __init__(self, **kwargs: Any) -> None:
super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)


class ImageNetDemux(enum.IntEnum):
META = 0
LABEL = 1
Expand Down Expand Up @@ -80,16 +71,22 @@ def __init__(
"test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
}

def _resources(self) -> List[OnlineResource]:
name = "test_v10102019" if self._split == "test" else self._split
images = ImageNetResource(
file_name=f"ILSVRC2012_img_{name}.tar",
sha256=self._IMAGES_CHECKSUMS[name],
def _imagenet_resource(self, *, file_name: str, sha256: str) -> ManualDownloadResource:
return ManualDownloadResource(
"https://image-net.org/",
instructions="Register on https://image-net.org/ and follow the instructions there.",
file_name=file_name,
sha256=sha256,
)
resources: List[OnlineResource] = [images]

def _resources(self) -> List[ManualDownloadResource]:
name = "test_v10102019" if self._split == "test" else self._split
images = self._imagenet_resource(file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])

resources = [images]

if self._split == "val":
devkit = ImageNetResource(
devkit = self._imagenet_resource(
file_name="ILSVRC2012_devkit_t12.tar.gz",
sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
)
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import _internal # usort: skip
from ._dataset import Dataset
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
from ._resource import OnlineResource, ManualDownloadResource
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/utils/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
yield from self._dp

@abc.abstractmethod
def _resources(self) -> List[OnlineResource]:
def _resources(self) -> Sequence[OnlineResource]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gets rid of the annoying mypy warning when returning List[ManualDownloadResource] Instead of List[OnlineResource].

pass

@abc.abstractmethod
Expand Down
Loading