Skip to content

Commit d69bc15

Browse files
datumbox, facebook-github-bot
authored and committed
[fbsync] Cleanups for FLAVA datasets (#5164)
Summary: * Change default of download for Food101 and DTD * Set download default to False and put it at the end * Keep stuff private * GTSRB: train -> split. Also use pathlib * mypy * Remove split and partition for SUN397 * mypy * mypy * move download param for SST2 * Use make_dataset in SST2 * Use a base URL for GTSRB * Let's make this code more complicated than it needs to be because why not Reviewed By: jdsgomes, prabhat00155 Differential Revision: D33739381 fbshipit-source-id: a2bcfcdc2296ffe62f8e75c8107ff1d0a87957f1
1 parent a24d670 commit d69bc15

14 files changed

+89
-125
lines changed

mypy.ini

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,7 @@ ignore_missing_imports = True
117117
[mypy-torchdata.*]
118118

119119
ignore_missing_imports = True
120+
121+
[mypy-h5py.*]
122+
123+
ignore_missing_imports = True

test/test_datasets.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2281,11 +2281,6 @@ def inject_fake_data(self, tmpdir: str, config):
22812281
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
22822282
DATASET_CLASS = datasets.SUN397
22832283

2284-
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
2285-
split=("train", "test"),
2286-
partition=(1, 10, None),
2287-
)
2288-
22892284
def inject_fake_data(self, tmpdir: str, config):
22902285
data_dir = pathlib.Path(tmpdir) / "SUN397"
22912286
data_dir.mkdir()
@@ -2308,18 +2303,7 @@ def inject_fake_data(self, tmpdir: str, config):
23082303
with open(data_dir / "ClassName.txt", "w") as file:
23092304
file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes))
23102305

2311-
if config["partition"] is not None:
2312-
num_samples = max(len(im_paths) // (2 if config["split"] == "train" else 3), 1)
2313-
2314-
with open(data_dir / f"{config['split'].title()}ing_{config['partition']:02d}.txt", "w") as file:
2315-
file.writelines(
2316-
"\n".join(
2317-
f"/{f_path.relative_to(data_dir).as_posix()}"
2318-
for f_path in random.choices(im_paths, k=num_samples)
2319-
)
2320-
)
2321-
else:
2322-
num_samples = len(im_paths)
2306+
num_samples = len(im_paths)
23232307

23242308
return num_samples
23252309

@@ -2397,17 +2381,17 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
23972381
DATASET_CLASS = datasets.GTSRB
23982382
FEATURE_TYPES = (PIL.Image.Image, int)
23992383

2400-
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
2384+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
24012385

24022386
def inject_fake_data(self, tmpdir: str, config):
2403-
root_folder = os.path.join(tmpdir, "GTSRB")
2387+
root_folder = os.path.join(tmpdir, "gtsrb")
24042388
os.makedirs(root_folder, exist_ok=True)
24052389

24062390
# Train data
2407-
train_folder = os.path.join(root_folder, "Training")
2391+
train_folder = os.path.join(root_folder, "GTSRB", "Training")
24082392
os.makedirs(train_folder, exist_ok=True)
24092393

2410-
num_examples = 3
2394+
num_examples = 3 if config["split"] == "train" else 4
24112395
classes = ("00000", "00042", "00012")
24122396
for class_idx in classes:
24132397
datasets_utils.create_image_folder(
@@ -2419,7 +2403,7 @@ def inject_fake_data(self, tmpdir: str, config):
24192403

24202404
total_number_of_examples = num_examples * len(classes)
24212405
# Test data
2422-
test_folder = os.path.join(root_folder, "Final_Test", "Images")
2406+
test_folder = os.path.join(root_folder, "GTSRB", "Final_Test", "Images")
24232407
os.makedirs(test_folder, exist_ok=True)
24242408

24252409
with open(os.path.join(root_folder, "GT-final_test.csv"), "w") as csv_file:

torchvision/datasets/clevr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
split: str = "train",
3535
transform: Optional[Callable] = None,
3636
target_transform: Optional[Callable] = None,
37-
download: bool = True,
37+
download: bool = False,
3838
) -> None:
3939
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
4040
super().__init__(root, transform=transform, target_transform=target_transform)

torchvision/datasets/country211.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(
3232
split: str = "train",
3333
transform: Optional[Callable] = None,
3434
target_transform: Optional[Callable] = None,
35-
download: bool = True,
35+
download: bool = False,
3636
) -> None:
3737
self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
3838

torchvision/datasets/dtd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ class DTD(VisionDataset):
2121
The partition only changes which split each image belongs to. Thus, regardless of the selected
2222
partition, combining all splits will result in all images.
2323
24-
download (bool, optional): If True, downloads the dataset from the internet and
25-
puts it in root directory. If dataset is already downloaded, it is not
26-
downloaded again.
2724
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
2825
version. E.g, ``transforms.RandomCrop``.
2926
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
27+
download (bool, optional): If True, downloads the dataset from the internet and
28+
puts it in root directory. If dataset is already downloaded, it is not
29+
downloaded again. Default is False.
3030
"""
3131

3232
_URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
@@ -37,9 +37,9 @@ def __init__(
3737
root: str,
3838
split: str = "train",
3939
partition: int = 1,
40-
download: bool = True,
4140
transform: Optional[Callable] = None,
4241
target_transform: Optional[Callable] = None,
42+
download: bool = False,
4343
) -> None:
4444
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
4545
if not isinstance(partition, int) and not (1 <= partition <= 10):

torchvision/datasets/eurosat.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Any
2+
from typing import Callable, Optional
33

44
from .folder import ImageFolder
55
from .utils import download_and_extract_archive
@@ -10,23 +10,21 @@ class EuroSAT(ImageFolder):
1010
1111
Args:
1212
root (string): Root directory of dataset where ``root/eurosat`` exists.
13-
download (bool, optional): If True, downloads the dataset from the internet and
14-
puts it in root directory. If dataset is already downloaded, it is not
15-
downloaded again. Default is False.
1613
transform (callable, optional): A function/transform that takes in an PIL image
1714
and returns a transformed version. E.g, ``transforms.RandomCrop``
1815
target_transform (callable, optional): A function/transform that takes in the
1916
target and transforms it.
17+
download (bool, optional): If True, downloads the dataset from the internet and
18+
puts it in root directory. If dataset is already downloaded, it is not
19+
downloaded again. Default is False.
2020
"""
2121

22-
url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip"
23-
md5 = "c8fa014336c82ac7804f0398fcb19387"
24-
2522
def __init__(
2623
self,
2724
root: str,
25+
transform: Optional[Callable] = None,
26+
target_transform: Optional[Callable] = None,
2827
download: bool = False,
29-
**kwargs: Any,
3028
) -> None:
3129
self.root = os.path.expanduser(root)
3230
self._base_folder = os.path.join(self.root, "eurosat")
@@ -38,7 +36,7 @@ def __init__(
3836
if not self._check_exists():
3937
raise RuntimeError("Dataset not found. You can use download=True to download it")
4038

41-
super().__init__(self._data_folder, **kwargs)
39+
super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
4240
self.root = os.path.expanduser(root)
4341

4442
def __len__(self) -> int:
@@ -53,4 +51,8 @@ def download(self) -> None:
5351
return
5452

5553
os.makedirs(self._base_folder, exist_ok=True)
56-
download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5)
54+
download_and_extract_archive(
55+
"https://madm.dfki.de/files/sentinel/EuroSAT.zip",
56+
download_root=self._base_folder,
57+
md5="c8fa014336c82ac7804f0398fcb19387",
58+
)

torchvision/datasets/fgvc_aircraft.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ class FGVCAircraft(VisionDataset):
2626
root (string): Root directory of the FGVC Aircraft dataset.
2727
split (string, optional): The dataset split, supports ``train``, ``val``,
2828
``trainval`` and ``test``.
29-
download (bool, optional): If True, downloads the dataset from the internet and
30-
puts it in root directory. If dataset is already downloaded, it is not
31-
downloaded again.
3229
annotation_level (str, optional): The annotation level, supports ``variant``,
3330
``family`` and ``manufacturer``.
3431
transform (callable, optional): A function/transform that takes in an PIL image
3532
and returns a transformed version. E.g, ``transforms.RandomCrop``
3633
target_transform (callable, optional): A function/transform that takes in the
3734
target and transforms it.
35+
download (bool, optional): If True, downloads the dataset from the internet and
36+
puts it in root directory. If dataset is already downloaded, it is not
37+
downloaded again.
3838
"""
3939

4040
_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
@@ -43,10 +43,10 @@ def __init__(
4343
self,
4444
root: str,
4545
split: str = "trainval",
46-
download: bool = False,
4746
annotation_level: str = "variant",
4847
transform: Optional[Callable] = None,
4948
target_transform: Optional[Callable] = None,
49+
download: bool = False,
5050
) -> None:
5151
super().__init__(root, transform=transform, target_transform=target_transform)
5252
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))

torchvision/datasets/flowers102.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ class Flowers102(VisionDataset):
2424
Args:
2525
root (string): Root directory of the dataset.
2626
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
27-
download (bool, optional): If true, downloads the dataset from the internet and
28-
puts it in root directory. If dataset is already downloaded, it is not
29-
downloaded again.
3027
transform (callable, optional): A function/transform that takes in an PIL image and returns a
3128
transformed version. E.g, ``transforms.RandomCrop``.
3229
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
30+
download (bool, optional): If true, downloads the dataset from the internet and
31+
puts it in root directory. If dataset is already downloaded, it is not
32+
downloaded again.
3333
"""
3434

3535
_download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
@@ -44,9 +44,9 @@ def __init__(
4444
self,
4545
root: str,
4646
split: str = "train",
47-
download: bool = True,
4847
transform: Optional[Callable] = None,
4948
target_transform: Optional[Callable] = None,
49+
download: bool = False,
5050
) -> None:
5151
super().__init__(root, transform=transform, target_transform=target_transform)
5252
self._split = verify_str_arg(split, "split", ("train", "val", "test"))

torchvision/datasets/food101.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ class Food101(VisionDataset):
2424
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
2525
version. E.g, ``transforms.RandomCrop``.
2626
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
27+
download (bool, optional): If True, downloads the dataset from the internet and
28+
puts it in root directory. If dataset is already downloaded, it is not
29+
downloaded again. Default is False.
2730
"""
2831

2932
_URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
@@ -33,9 +36,9 @@ def __init__(
3336
self,
3437
root: str,
3538
split: str = "train",
36-
download: bool = True,
3739
transform: Optional[Callable] = None,
3840
target_transform: Optional[Callable] = None,
41+
download: bool = False,
3942
) -> None:
4043
super().__init__(root, transform=transform, target_transform=target_transform)
4144
self._split = verify_str_arg(split, "split", ("train", "test"))

torchvision/datasets/gtsrb.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import csv
2-
import os
2+
import pathlib
33
from typing import Any, Callable, Optional, Tuple
44

55
import PIL
66

77
from .folder import make_dataset
8-
from .utils import download_and_extract_archive
8+
from .utils import download_and_extract_archive, verify_str_arg
99
from .vision import VisionDataset
1010

1111

@@ -14,8 +14,7 @@ class GTSRB(VisionDataset):
1414
1515
Args:
1616
root (string): Root directory of the dataset.
17-
train (bool, optional): If True, creates dataset from training set, otherwise
18-
creates from test set.
17+
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
1918
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
2019
version. E.g, ``transforms.RandomCrop``.
2120
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
@@ -24,49 +23,35 @@ class GTSRB(VisionDataset):
2423
downloaded again.
2524
"""
2625

27-
# Ground Truth for the test set
28-
_gt_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip"
29-
_gt_csv = "GT-final_test.csv"
30-
_gt_md5 = "fe31e9c9270bbcd7b84b7f21a9d9d9e5"
31-
32-
# URLs for the test and train set
33-
_urls = (
34-
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip",
35-
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip",
36-
)
37-
38-
_md5s = ("c7e4e6327067d32654124b0fe9e82185", "513f3c79a4c5141765e10e952eaa2478")
39-
4026
def __init__(
4127
self,
4228
root: str,
43-
train: bool = True,
29+
split: str = "train",
4430
transform: Optional[Callable] = None,
4531
target_transform: Optional[Callable] = None,
4632
download: bool = False,
4733
) -> None:
4834

4935
super().__init__(root, transform=transform, target_transform=target_transform)
5036

51-
self.root = os.path.expanduser(root)
52-
53-
self.train = train
54-
55-
self._base_folder = os.path.join(self.root, type(self).__name__)
56-
self._target_folder = os.path.join(self._base_folder, "Training" if self.train else "Final_Test/Images")
37+
self._split = verify_str_arg(split, "split", ("train", "test"))
38+
self._base_folder = pathlib.Path(root) / "gtsrb"
39+
self._target_folder = (
40+
self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
41+
)
5742

5843
if download:
5944
self.download()
6045

6146
if not self._check_exists():
6247
raise RuntimeError("Dataset not found. You can use download=True to download it")
6348

64-
if train:
65-
samples = make_dataset(self._target_folder, extensions=(".ppm",))
49+
if self._split == "train":
50+
samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
6651
else:
67-
with open(os.path.join(self._base_folder, self._gt_csv)) as csv_file:
52+
with open(self._base_folder / "GT-final_test.csv") as csv_file:
6853
samples = [
69-
(os.path.join(self._target_folder, row["Filename"]), int(row["ClassId"]))
54+
(str(self._target_folder / row["Filename"]), int(row["ClassId"]))
7055
for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
7156
]
7257

@@ -91,16 +76,28 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
9176
return sample, target
9277

9378
def _check_exists(self) -> bool:
94-
return os.path.exists(self._target_folder) and os.path.isdir(self._target_folder)
79+
return self._target_folder.is_dir()
9580

9681
def download(self) -> None:
9782
if self._check_exists():
9883
return
9984

100-
download_and_extract_archive(self._urls[self.train], download_root=self.root, md5=self._md5s[self.train])
85+
base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
10186

102-
if not self.train:
103-
# Download Ground Truth for the test set
87+
if self._split == "train":
88+
download_and_extract_archive(
89+
f"{base_url}GTSRB-Training_fixed.zip",
90+
download_root=str(self._base_folder),
91+
md5="513f3c79a4c5141765e10e952eaa2478",
92+
)
93+
else:
94+
download_and_extract_archive(
95+
f"{base_url}GTSRB_Final_Test_Images.zip",
96+
download_root=str(self._base_folder),
97+
md5="c7e4e6327067d32654124b0fe9e82185",
98+
)
10499
download_and_extract_archive(
105-
self._gt_url, download_root=self.root, extract_root=self._base_folder, md5=self._gt_md5
100+
f"{base_url}GTSRB_Final_Test_GT.zip",
101+
download_root=str(self._base_folder),
102+
md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
106103
)

torchvision/datasets/oxford_iiit_pet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(
4545
transforms: Optional[Callable] = None,
4646
transform: Optional[Callable] = None,
4747
target_transform: Optional[Callable] = None,
48-
download: bool = True,
48+
download: bool = False,
4949
):
5050
self._split = verify_str_arg(split, "split", ("trainval", "test"))
5151
if isinstance(target_types, str):

torchvision/datasets/pcam.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ def __init__(
7272
split: str = "train",
7373
transform: Optional[Callable] = None,
7474
target_transform: Optional[Callable] = None,
75-
download: bool = True,
75+
download: bool = False,
7676
):
7777
try:
78-
import h5py # type: ignore[import]
78+
import h5py
7979

8080
self.h5py = h5py
8181
except ImportError:

0 commit comments

Comments
 (0)