Skip to content

Commit 97385df

Browse files
Dbhasin1Dbhasin1NicolasHug
authored
add EuroSAT prototype dataset (#5452)
* add eurosat * revert formatting * port test and make style changes * add eurosat to __init__ * fix pathlib error * create dataset zipfile and revert pre commit changes * remove unecessary variable in resources * revert auto formatter changes and modify ufmt version * revert change to contributing guide Co-authored-by: Dbhasin1 <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 01f07ee commit 97385df

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

test/builtin_dataset_mocks.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,24 @@ def cub200(info, root, config):
13271327
return num_samples_map[config.split]
13281328

13291329

1330+
@register_mock
1331+
def eurosat(info, root, config):
1332+
data_folder = pathlib.Path(root, "eurosat", "2750")
1333+
data_folder.mkdir(parents=True)
1334+
1335+
num_examples_per_class = 3
1336+
classes = ("AnnualCrop", "Forest")
1337+
for cls in classes:
1338+
create_image_folder(
1339+
root=data_folder,
1340+
name=cls,
1341+
file_name_fn=lambda idx: f"{cls}_{idx}.jpg",
1342+
num_examples=num_examples_per_class,
1343+
)
1344+
make_zip(root, "EuroSAT.zip", data_folder)
1345+
return len(classes) * num_examples_per_class
1346+
1347+
13301348
@register_mock
13311349
def svhn(info, root, config):
13321350
import scipy.io as sio

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .country211 import Country211
77
from .cub200 import CUB200
88
from .dtd import DTD
9+
from .eurosat import EuroSAT
910
from .fer2013 import FER2013
1011
from .gtsrb import GTSRB
1112
from .imagenet import ImageNet
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pathlib
2+
from typing import Any, Dict, List, Tuple
3+
4+
from torchdata.datapipes.iter import IterDataPipe, Mapper
5+
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
6+
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
7+
from torchvision.prototype.features import EncodedImage, Label
8+
9+
10+
class EuroSAT(Dataset):
11+
def _make_info(self) -> DatasetInfo:
12+
return DatasetInfo(
13+
"eurosat",
14+
homepage="https://github.com/phelber/eurosat",
15+
categories=(
16+
"AnnualCrop",
17+
"Forest",
18+
"HerbaceousVegetation",
19+
"Highway",
20+
"Industrial," "Pasture",
21+
"PermanentCrop",
22+
"Residential",
23+
"River",
24+
"SeaLake",
25+
),
26+
)
27+
28+
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
29+
return [
30+
HttpResource(
31+
"https://madm.dfki.de/files/sentinel/EuroSAT.zip",
32+
sha256="8ebea626349354c5328b142b96d0430e647051f26efc2dc974c843f25ecf70bd",
33+
)
34+
]
35+
36+
def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
37+
path, buffer = data
38+
category = pathlib.Path(path).parent.name
39+
return dict(
40+
label=Label.from_category(category, categories=self.categories),
41+
path=path,
42+
image=EncodedImage.from_file(buffer),
43+
)
44+
45+
def _make_datapipe(
46+
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
47+
) -> IterDataPipe[Dict[str, Any]]:
48+
dp = resource_dps[0]
49+
dp = hint_sharding(dp)
50+
dp = hint_shuffling(dp)
51+
return Mapper(dp, self._prepare_sample)

0 commit comments

Comments
 (0)