diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 31d0aadc64d..94305952568 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1327,6 +1327,24 @@ def cub200(info, root, config): return num_samples_map[config.split] +@register_mock +def eurosat(info, root, config): + data_folder = pathlib.Path(root, "eurosat", "2750") + data_folder.mkdir(parents=True) + + num_examples_per_class = 3 + classes = ("AnnualCrop", "Forest") + for cls in classes: + create_image_folder( + root=data_folder, + name=cls, + file_name_fn=lambda idx: f"{cls}_{idx}.jpg", + num_examples=num_examples_per_class, + ) + make_zip(root, "EuroSAT.zip", data_folder) + return len(classes) * num_examples_per_class + + @register_mock def svhn(info, root, config): import scipy.io as sio diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index d57b5555727..1567ef29811 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -6,6 +6,7 @@ from .country211 import Country211 from .cub200 import CUB200 from .dtd import DTD +from .eurosat import EuroSAT from .fer2013 import FER2013 from .gtsrb import GTSRB from .imagenet import ImageNet diff --git a/torchvision/prototype/datasets/_builtin/eurosat.py b/torchvision/prototype/datasets/_builtin/eurosat.py new file mode 100644 index 00000000000..fdbba077669 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/eurosat.py @@ -0,0 +1,51 @@ +import pathlib +from typing import Any, Dict, List, Tuple + +from torchdata.datapipes.iter import IterDataPipe, Mapper +from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.features import EncodedImage, Label + + +class EuroSAT(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "eurosat", + homepage="https://github.com/phelber/eurosat", + categories=( + "AnnualCrop", + "Forest", + "HerbaceousVegetation", + "Highway", + "Industrial," "Pasture", + "PermanentCrop", + "Residential", + "River", + "SeaLake", + ), + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [ + HttpResource( + "https://madm.dfki.de/files/sentinel/EuroSAT.zip", + sha256="8ebea626349354c5328b142b96d0430e647051f26efc2dc974c843f25ecf70bd", + ) + ] + + def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: + path, buffer = data + category = pathlib.Path(path).parent.name + return dict( + label=Label.from_category(category, categories=self.categories), + path=path, + image=EncodedImage.from_file(buffer), + ) + + def _make_datapipe( + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig + ) -> IterDataPipe[Dict[str, Any]]: + dp = resource_dps[0] + dp = hint_sharding(dp) + dp = hint_shuffling(dp) + return Mapper(dp, self._prepare_sample)