-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Replace torchvision.datasets.utils with functionality from torchdata #6060
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,26 @@ | ||
from .caltech import Caltech101, Caltech256 | ||
from .celeba import CelebA | ||
|
||
# from .celeba import CelebA | ||
from .cifar import Cifar10, Cifar100 | ||
from .clevr import CLEVR | ||
from .coco import Coco | ||
from .country211 import Country211 | ||
from .cub200 import CUB200 | ||
from .dtd import DTD | ||
from .eurosat import EuroSAT | ||
|
||
# from .clevr import CLEVR | ||
# from .coco import Coco | ||
# from .country211 import Country211 | ||
# from .cub200 import CUB200 | ||
# from .dtd import DTD | ||
# from .eurosat import EuroSAT | ||
from .fer2013 import FER2013 | ||
from .food101 import Food101 | ||
from .gtsrb import GTSRB | ||
|
||
# from .food101 import Food101 | ||
# from .gtsrb import GTSRB | ||
from .imagenet import ImageNet | ||
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST | ||
from .oxford_iiit_pet import OxfordIIITPet | ||
from .pcam import PCAM | ||
from .sbd import SBD | ||
from .semeion import SEMEION | ||
from .stanford_cars import StanfordCars | ||
from .svhn import SVHN | ||
from .usps import USPS | ||
from .voc import VOC | ||
|
||
# from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST | ||
# from .oxford_iiit_pet import OxfordIIITPet | ||
# from .pcam import PCAM | ||
# from .sbd import SBD | ||
# from .semeion import SEMEION | ||
# from .stanford_cars import StanfordCars | ||
# from .svhn import SVHN | ||
# from .usps import USPS | ||
# from .voc import VOC |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,11 +14,7 @@ | |
TarArchiveLoader, | ||
Enumerator, | ||
) | ||
from torchvision.prototype.datasets.utils import ( | ||
OnlineResource, | ||
ManualDownloadResource, | ||
Dataset, | ||
) | ||
from torchvision.prototype.datasets.utils import ManualDownloadResource, Dataset | ||
from torchvision.prototype.datasets.utils._internal import ( | ||
INFINITE_BUFFER_SIZE, | ||
getitem, | ||
|
@@ -41,11 +37,6 @@ def _info() -> Dict[str, Any]: | |
return dict(categories=categories, wnids=wnids) | ||
|
||
|
||
class ImageNetResource(ManualDownloadResource): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Instead of subclassing just to parametrize, I followed the approach in #6052. |
||
def __init__(self, **kwargs: Any) -> None: | ||
super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs) | ||
|
||
|
||
class ImageNetDemux(enum.IntEnum): | ||
META = 0 | ||
LABEL = 1 | ||
|
@@ -80,16 +71,22 @@ def __init__( | |
"test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4", | ||
} | ||
|
||
def _resources(self) -> List[OnlineResource]: | ||
name = "test_v10102019" if self._split == "test" else self._split | ||
images = ImageNetResource( | ||
file_name=f"ILSVRC2012_img_{name}.tar", | ||
sha256=self._IMAGES_CHECKSUMS[name], | ||
def _imagenet_resource(self, *, file_name: str, sha256: str) -> ManualDownloadResource: | ||
return ManualDownloadResource( | ||
"https://image-net.org/", | ||
instructions="Register on https://image-net.org/ and follow the instructions there.", | ||
file_name=file_name, | ||
sha256=sha256, | ||
) | ||
resources: List[OnlineResource] = [images] | ||
|
||
def _resources(self) -> List[ManualDownloadResource]: | ||
name = "test_v10102019" if self._split == "test" else self._split | ||
images = self._imagenet_resource(file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name]) | ||
|
||
resources = [images] | ||
|
||
if self._split == "val": | ||
devkit = ImageNetResource( | ||
devkit = self._imagenet_resource( | ||
file_name="ILSVRC2012_devkit_t12.tar.gz", | ||
sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953", | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from . import _internal # usort: skip | ||
from ._dataset import Dataset | ||
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource | ||
from ._resource import OnlineResource, ManualDownloadResource |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,7 +42,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: | |
yield from self._dp | ||
|
||
@abc.abstractmethod | ||
def _resources(self) -> List[OnlineResource]: | ||
def _resources(self) -> Sequence[OnlineResource]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This gets rid of the annoying |
||
pass | ||
|
||
@abc.abstractmethod | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I only made changes to four datasets to cover all cases, i.e.
The others are commented out to be able to import `torchvision`.