From f5c291be5137b99ca735f24d2fa47e764cb476ab Mon Sep 17 00:00:00 2001 From: frgfm Date: Sun, 19 Dec 2021 16:09:32 +0100 Subject: [PATCH 01/16] feat: Added EuroSAT dataset --- torchvision/datasets/__init__.py | 2 + torchvision/datasets/eurosat.py | 76 ++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 torchvision/datasets/eurosat.py diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 80859791004..36d46056e2b 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -4,6 +4,7 @@ from .cifar import CIFAR10, CIFAR100 from .cityscapes import Cityscapes from .coco import CocoCaptions, CocoDetection +from .eurosat import EuroSAT from .fakedata import FakeData from .flickr import Flickr8k, Flickr30k from .folder import ImageFolder, DatasetFolder @@ -77,4 +78,5 @@ "FlyingChairs", "FlyingThings3D", "HD1K", + "EuroSAT", ) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py new file mode 100644 index 00000000000..80154cecbfb --- /dev/null +++ b/torchvision/datasets/eurosat.py @@ -0,0 +1,76 @@ +import os +from typing import Any + +from .folder import ImageFolder +from .utils import download_and_extract_archive, check_integrity + + +class EuroSAT(ImageFolder): + """`EuroSAT `_ Dataset. + + Args: + root (string): Root directory of dataset where ``EuroSAT.zip`` exists. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip" + md5 = "c8fa014336c82ac7804f0398fcb19387" + filename = "EuroSAT.zip" + + _classes = [ + "Annual Crop", + "Forest", + "Herbaceous Vegetation", + "Highway", + "Industrial Buildings", + "Pasture", + "Permanent Crop", + "Residential Buildings", + "River", + "Sea & Lake", + ] + + def __init__( + self, + root: str, + download: bool = False, + **kwargs: Any, + ) -> None: + self.root = os.path.expanduser(root) + + # Download + if download: + self.download() + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + # ImageFolder + super().__init__(os.path.join(self.data_folder, "2750"), **kwargs) + self.classes = self._classes + self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} + + def __len__(self) -> int: + return len(self.data) + + @property + def data_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__) + + def _check_exists(self) -> bool: + return check_integrity(os.path.join(self.data_folder, self.filename)) + + def download(self) -> None: + """Download the EuroSAT data if it doesn't exist already.""" + + if self._check_exists(): + return + + os.makedirs(self.data_folder, exist_ok=True) + print(f"Downloading {self.url}") + download_and_extract_archive(self.url, download_root=self.data_folder, filename=self.filename, md5=self.md5) From 1619c6a3eb3cbbeead99e4a07d88d9d4cfc1bcfb Mon Sep 17 00:00:00 2001 From: frgfm Date: Sun, 19 Dec 2021 16:18:19 +0100 Subject: [PATCH 02/16] test: Added unittest --- test/test_datasets.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/test_datasets.py b/test/test_datasets.py index 761f11d77dc..856538afc28 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2168,5 +2168,23 @@ def inject_fake_data(self, tmpdir, config): return num_sequences * (num_examples_per_sequence - 1) +class EuroSATTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.EuroSAT + + def inject_fake_data(self, tmpdir, config): + tmpdir = pathlib.Path(tmpdir) + + wnid = "AnnualCrop" + num_examples = 3 + datasets_utils.create_image_folder( + root=tmpdir, + name=tmpdir / wnid / wnid, + file_name_fn=lambda image_idx: f"{wnid}_{image_idx}.JPEG", + num_examples=num_examples, + ) + + return num_examples + + if __name__ == "__main__": unittest.main() From cca559ac96a7ba483280f1990865d276fd54e650 Mon Sep 17 00:00:00 2001 From: frgfm Date: Sun, 19 Dec 2021 16:21:31 +0100 Subject: [PATCH 03/16] docs: Improved comments --- torchvision/datasets/eurosat.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index 80154cecbfb..651964f6c04 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -6,7 +6,7 @@ class EuroSAT(ImageFolder): - """`EuroSAT `_ Dataset. + """RGB version of the `EuroSAT `_ Dataset. Args: root (string): Root directory of dataset where ``EuroSAT.zip`` exists. @@ -44,9 +44,10 @@ def __init__( ) -> None: self.root = os.path.expanduser(root) - # Download + # Download the dataset if download: self.download() + if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") From ad3adcb711d80eaad19bdfedc0c7ee280924880f Mon Sep 17 00:00:00 2001 From: frgfm Date: Sun, 19 Dec 2021 16:24:10 +0100 Subject: [PATCH 04/16] docs: Updated the documentation --- docs/source/datasets.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 7f09ff245ca..d0ad7c5089d 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -39,6 +39,7 @@ You can also create your own datasets using the provided :ref:`base classes Date: Sun, 19 Dec 2021 17:27:38 +0100 Subject: [PATCH 05/16] docs: Removed unnecessary comments --- torchvision/datasets/eurosat.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index 651964f6c04..cd3b05dc60f 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -44,14 +44,12 @@ def __init__( ) -> None: self.root = os.path.expanduser(root) - # Download the dataset if download: self.download() if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - # ImageFolder super().__init__(os.path.join(self.data_folder, "2750"), **kwargs) self.classes = self._classes self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} @@ -67,7 +65,6 @@ def _check_exists(self) -> bool: return check_integrity(os.path.join(self.data_folder, self.filename)) def download(self) -> None: - """Download the EuroSAT data if it doesn't exist already.""" if self._check_exists(): return From 58a2423bf8354c32bcf43daebbc7a70a996af8c1 Mon Sep 17 00:00:00 2001 From: frgfm Date: Sun, 19 Dec 2021 17:27:51 +0100 Subject: [PATCH 06/16] fix: Fixed class implementation --- torchvision/datasets/eurosat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index cd3b05dc60f..b05464e9342 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -53,13 +53,14 @@ def __init__( super().__init__(os.path.join(self.data_folder, "2750"), **kwargs) self.classes = self._classes self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} + self.root = os.path.expanduser(root) def __len__(self) -> int: return len(self.data) @property def data_folder(self) -> str: - return os.path.join(self.root, self.__class__.__name__) + return os.path.join(self.root, self.__class__.__name__.lower()) def _check_exists(self) -> bool: return check_integrity(os.path.join(self.data_folder, self.filename)) From 4de1d31c476d6e81124bc577aa90fd4d081adda1 Mon Sep 17 00:00:00 2001 From: frgfm Date: Sun, 19 Dec 2021 17:27:59 +0100 Subject: [PATCH 07/16] test: Fixed unittest --- test/test_datasets.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 856538afc28..54af5246a17 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2174,12 +2174,12 @@ class EuroSATTestCase(datasets_utils.ImageDatasetTestCase): def inject_fake_data(self, tmpdir, config): tmpdir = pathlib.Path(tmpdir) - wnid = "AnnualCrop" + category = "AnnualCrop" num_examples = 3 datasets_utils.create_image_folder( root=tmpdir, - name=tmpdir / wnid / wnid, - file_name_fn=lambda image_idx: f"{wnid}_{image_idx}.JPEG", + name=tmpdir / "2750" / category, + file_name_fn=lambda image_idx: f"{category}_{image_idx}.JPEG", num_examples=num_examples, ) From 6993c36b0759f4ef53cb9b4983e27869ea0adb85 Mon Sep 17 00:00:00 2001 From: frgfm Date: Mon, 20 Dec 2021 13:21:19 +0100 Subject: [PATCH 08/16] fix: Fixed magic method len --- torchvision/datasets/eurosat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index b05464e9342..caf18b21b4c 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -56,7 +56,7 @@ def __init__( self.root = os.path.expanduser(root) def __len__(self) -> int: - return len(self.data) + return len(self.samples) @property def data_folder(self) -> str: From fe051cda1d3558c92c46e375f8834b9cc537d6f4 Mon Sep 17 00:00:00 2001 From: frgfm Date: Mon, 20 Dec 2021 13:21:29 +0100 Subject: [PATCH 09/16] test: Fixed unittest --- test/test_datasets.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 54af5246a17..eb9c7a99e2b 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2170,20 +2170,23 @@ def inject_fake_data(self, tmpdir, config): class EuroSATTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.EuroSAT + FEATURE_TYPES = (PIL.Image.Image, int) def inject_fake_data(self, tmpdir, config): - tmpdir = pathlib.Path(tmpdir) + img_folder = os.path.join(tmpdir, "eurosat", "2750") + os.makedirs(img_folder) - category = "AnnualCrop" - num_examples = 3 - datasets_utils.create_image_folder( - root=tmpdir, - name=tmpdir / "2750" / category, - file_name_fn=lambda image_idx: f"{category}_{image_idx}.JPEG", - num_examples=num_examples, - ) + num_examples_per_class = 3 + classes = ("AnnualCrop", "Forest") + for cls in classes: + datasets_utils.create_image_folder( + root=img_folder, + name=cls, + file_name_fn=lambda idx: f"{cls}_{idx}.jpg", + num_examples=num_examples_per_class, + ) - return num_examples + return len(classes) * num_examples_per_class if __name__ == "__main__": From 69d5b786f2648751b5f2ce1413b3730a6349a9a6 Mon Sep 17 00:00:00 2001 From: frgfm Date: Mon, 20 Dec 2021 13:29:27 +0100 Subject: [PATCH 10/16] refactor: Refactored EuroSAT --- torchvision/datasets/eurosat.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index caf18b21b4c..71f8cf9cf29 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -23,18 +23,18 @@ class EuroSAT(ImageFolder): md5 = "c8fa014336c82ac7804f0398fcb19387" filename = "EuroSAT.zip" - _classes = [ - "Annual Crop", - "Forest", - "Herbaceous Vegetation", - "Highway", - "Industrial Buildings", - "Pasture", - "Permanent Crop", - "Residential Buildings", - "River", - "Sea & Lake", - ] + _class_map = { + "AnnualCrop": "Annual Crop", + "Forest": "Forest", + "HerbaceousVegetation": "Herbaceous Vegetation", + "Highway": "Highway", + "Industrial": "Industrial Buildings", + "Pasture": "Pasture", + "PermanentCrop": "Permanent Crop", + "Residential": "Residential Buildings", + "River": "River", + "SeaLake": "Sea & Lake", + } def __init__( self, @@ -51,8 +51,7 @@ def __init__( raise RuntimeError("Dataset not found. You can use download=True to download it") super().__init__(os.path.join(self.data_folder, "2750"), **kwargs) - self.classes = self._classes - self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} + self.classes = [self._class_map[cls] for cls in self.classes] self.root = os.path.expanduser(root) def __len__(self) -> int: From cbf79e1242a2df71e8a9b0692ae370e85d918ecb Mon Sep 17 00:00:00 2001 From: frgfm Date: Mon, 20 Dec 2021 19:29:45 +0100 Subject: [PATCH 11/16] refactor: Applied modifications --- torchvision/datasets/eurosat.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index 71f8cf9cf29..93fd7908136 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -21,18 +21,13 @@ class EuroSAT(ImageFolder): url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip" md5 = "c8fa014336c82ac7804f0398fcb19387" - filename = "EuroSAT.zip" _class_map = { "AnnualCrop": "Annual Crop", - "Forest": "Forest", "HerbaceousVegetation": "Herbaceous Vegetation", - "Highway": "Highway", "Industrial": "Industrial Buildings", - "Pasture": "Pasture", "PermanentCrop": "Permanent Crop", "Residential": "Residential Buildings", - "River": "River", "SeaLake": "Sea & Lake", } @@ -50,25 +45,28 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - super().__init__(os.path.join(self.data_folder, "2750"), **kwargs) - self.classes = [self._class_map[cls] for cls in self.classes] + super().__init__(self._data_folder, **kwargs) + self.classes = [self._class_map.get(cls, cls) for cls in self.classes] self.root = os.path.expanduser(root) def __len__(self) -> int: return len(self.samples) @property - def data_folder(self) -> str: + def _base_folder(self) -> str: return os.path.join(self.root, self.__class__.__name__.lower()) + @property + def _data_folder(self) -> str: + return os.path.join(self._base_folder, "2750") + def _check_exists(self) -> bool: - return check_integrity(os.path.join(self.data_folder, self.filename)) + return os.path.exists(self._data_folder) def download(self) -> None: if self._check_exists(): return - os.makedirs(self.data_folder, exist_ok=True) - print(f"Downloading {self.url}") - download_and_extract_archive(self.url, download_root=self.data_folder, filename=self.filename, md5=self.md5) + os.makedirs(self._base_folder, exist_ok=True) + download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5) From c823d1e77b79e29197f4fa942b9d9becfa4bf355 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 21 Dec 2021 08:01:05 +0100 Subject: [PATCH 12/16] Apply suggestions from code review --- test/test_datasets.py | 6 +++--- torchvision/datasets/eurosat.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index eb9c7a99e2b..4e53f206c44 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2173,14 +2173,14 @@ class EuroSATTestCase(datasets_utils.ImageDatasetTestCase): FEATURE_TYPES = (PIL.Image.Image, int) def inject_fake_data(self, tmpdir, config): - img_folder = os.path.join(tmpdir, "eurosat", "2750") - os.makedirs(img_folder) + data_folder = os.path.join(tmpdir, "eurosat", "2750") + os.makedirs(data_folder) num_examples_per_class = 3 classes = ("AnnualCrop", "Forest") for cls in classes: datasets_utils.create_image_folder( - root=img_folder, + root=data_folder, name=cls, file_name_fn=lambda idx: f"{cls}_{idx}.jpg", num_examples=num_examples_per_class, diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index 93fd7908136..2785094da68 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -2,7 +2,7 @@ from typing import Any from .folder import ImageFolder -from .utils import download_and_extract_archive, check_integrity +from .utils import download_and_extract_archive class EuroSAT(ImageFolder): From e6f59fd2ee431d5db1941d25cbd4fe9aa5196d55 Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 6 Jan 2022 01:09:21 +0100 Subject: [PATCH 13/16] refactor: Applied request changes --- torchvision/datasets/eurosat.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index 2785094da68..959b8896d7d 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -6,13 +6,13 @@ class EuroSAT(ImageFolder): - """RGB version of the `EuroSAT `_ Dataset. + """RGB version of the `EuroSAT `_ Dataset. Args: - root (string): Root directory of dataset where ``EuroSAT.zip`` exists. + root (string): Root directory of dataset where ``root/eurosat`` exists. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not - downloaded again. + downloaded again. Default is False. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the @@ -45,6 +45,8 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") + self._base_folder = os.path.join(self.root, self.__class__.__name__.lower()) + self._data_folder = os.path.join(self._base_folder, "2750") super().__init__(self._data_folder, **kwargs) self.classes = [self._class_map.get(cls, cls) for cls in self.classes] self.root = os.path.expanduser(root) @@ -52,14 +54,6 @@ def __init__( def __len__(self) -> int: return len(self.samples) - @property - def _base_folder(self) -> str: - return os.path.join(self.root, self.__class__.__name__.lower()) - - @property - def _data_folder(self) -> str: - return os.path.join(self._base_folder, "2750") - def _check_exists(self) -> bool: return os.path.exists(self._data_folder) From f3669701df4473b320f9c485dab462bd35454542 Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 6 Jan 2022 13:18:58 +0100 Subject: [PATCH 14/16] refactor: Made var explicit --- torchvision/datasets/eurosat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index 959b8896d7d..fe0e303c279 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -45,7 +45,7 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - self._base_folder = os.path.join(self.root, self.__class__.__name__.lower()) + self._base_folder = os.path.join(self.root, "eurosat") self._data_folder = os.path.join(self._base_folder, "2750") super().__init__(self._data_folder, **kwargs) self.classes = [self._class_map.get(cls, cls) for cls in self.classes] From ae050946c9c0c7d01ac07c7da3e5a97d96e5077b Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 6 Jan 2022 14:48:48 +0100 Subject: [PATCH 15/16] fix: Fixed attribute initialization order --- torchvision/datasets/eurosat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index fe0e303c279..17ea030a252 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -38,6 +38,8 @@ def __init__( **kwargs: Any, ) -> None: self.root = os.path.expanduser(root) + self._base_folder = os.path.join(self.root, "eurosat") + self._data_folder = os.path.join(self._base_folder, "2750") if download: self.download() @@ -45,8 +47,6 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - self._base_folder = os.path.join(self.root, "eurosat") - self._data_folder = os.path.join(self._base_folder, "2750") super().__init__(self._data_folder, **kwargs) self.classes = [self._class_map.get(cls, cls) for cls in self.classes] self.root = os.path.expanduser(root) From fad24b80afbdc74d4d6ae70ddb25a3ee422bba6c Mon Sep 17 00:00:00 2001 From: frgfm Date: Mon, 17 Jan 2022 23:11:17 +0100 Subject: [PATCH 16/16] refactor: Removed name mapping --- torchvision/datasets/eurosat.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index 17ea030a252..d7876b7afd5 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -22,15 +22,6 @@ class EuroSAT(ImageFolder): url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip" md5 = "c8fa014336c82ac7804f0398fcb19387" - _class_map = { - "AnnualCrop": "Annual Crop", - "HerbaceousVegetation": "Herbaceous Vegetation", - "Industrial": "Industrial Buildings", - "PermanentCrop": "Permanent Crop", - "Residential": "Residential Buildings", - "SeaLake": "Sea & Lake", - } - def __init__( self, root: str, @@ -48,7 +39,6 @@ def __init__( raise RuntimeError("Dataset not found. You can use download=True to download it") super().__init__(self._data_folder, **kwargs) - self.classes = [self._class_map.get(cls, cls) for cls in self.classes] self.root = os.path.expanduser(root) def __len__(self) -> int: