Skip to content

Commit 5c9c835

Browse files
pmeier and NicolasHug authored
Add DTD dataset (#5115)
* add DTD as prototype dataset
* add old style dataset
* add test for old dataset
* fix tests for windows
* add dataset to docs
* remove properties and use pathlib
* Apply suggestions from code review

Co-authored-by: Nicolas Hug <[email protected]>

* fold -> partition

Co-authored-by: Nicolas Hug <[email protected]>
1 parent df628c4 commit 5c9c835

File tree

7 files changed

+317
-0
lines changed

7 files changed

+317
-0
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
3838
Cityscapes
3939
CocoCaptions
4040
CocoDetection
41+
DTD
4142
EMNIST
4243
FakeData
4344
FashionMNIST

test/test_datasets.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,5 +2205,41 @@ def inject_fake_data(self, tmpdir: str, config):
22052205
return len(sampled_classes * n_samples_per_class)
22062206

22072207

2208+
class DTDTestCase(datasets_utils.ImageDatasetTestCase):
    """Exercise ``datasets.DTD`` against a small, locally generated directory tree."""

    DATASET_CLASS = datasets.DTD
    FEATURE_TYPES = (PIL.Image.Image, int)

    # Every partition goes through the same code path, so a handful of them is enough
    # and keeps the test matrix small.
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train", "test", "val"),
        partition=(1, 5, 10),
    )

    def inject_fake_data(self, tmpdir: str, config):
        """Create fake images plus the split file for ``config`` and return the sample count.

        The tree is created as ``<tmpdir>/dtd/dtd/{images,labels}``. The returned number is the
        number of image ids written to the split file.
        """
        root = pathlib.Path(tmpdir) / "dtd" / "dtd"
        images_root = root / "images"
        per_class = 3

        created = []
        for category in ("banded", "marbled", "zigzagged"):
            created += datasets_utils.create_image_folder(
                images_root,
                category,
                file_name_fn=lambda idx: f"{category}_{idx:04d}.jpg",
                num_examples=per_class,
            )

        labels_root = root / "labels"
        labels_root.mkdir()

        # Image ids are written as "<class>/<file>" with forward slashes, independent of the host OS.
        ids = [str(path.relative_to(path.parents[1])).replace(os.sep, "/") for path in created]
        # ``random.choices`` samples with replacement, so an id may be listed more than once.
        selected = random.choices(ids, k=len(created) // 2)

        split_file = labels_root / f"{config['split']}{config['partition']}.txt"
        with open(split_file, "w") as file:
            file.write("\n".join(selected) + "\n")

        return len(selected)
2242+
2243+
22082244
if __name__ == "__main__":
22092245
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .cifar import CIFAR10, CIFAR100
55
from .cityscapes import Cityscapes
66
from .coco import CocoCaptions, CocoDetection
7+
from .dtd import DTD
78
from .fakedata import FakeData
89
from .flickr import Flickr8k, Flickr30k
910
from .folder import ImageFolder, DatasetFolder
@@ -79,4 +80,5 @@
7980
"FlyingThings3D",
8081
"HD1K",
8182
"Food101",
83+
"DTD",
8284
)

torchvision/datasets/dtd.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import os
2+
import pathlib
3+
from typing import Optional, Callable
4+
5+
import PIL.Image
6+
7+
from .utils import verify_str_arg, download_and_extract_archive
8+
from .vision import VisionDataset
9+
10+
11+
class DTD(VisionDataset):
    """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.

            .. note::

                The partition only changes which split each image belongs to. Thus, regardless of the selected
                partition, combining all splits will result in all images.

        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    Raises:
        ValueError: If ``partition`` is not an integer with ``1 <= partition <= 10``.
        RuntimeError: If the dataset folder does not exist after the optional download step.
    """

    _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
    _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"

    def __init__(
        self,
        root: str,
        split: str = "train",
        partition: int = 1,
        download: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        # Reject the value if EITHER check fails. The type check comes first so that the range
        # comparison is short-circuited away for non-integers instead of raising a TypeError.
        if not isinstance(partition, int) or not (1 <= partition <= 10):
            raise ValueError(
                f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
                f"but got {partition} instead"
            )
        self._partition = partition

        super().__init__(root, transform=transform, target_transform=target_transform)
        # Layout on disk: <root>/dtd/dtd/{labels,images}
        self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
        self._data_folder = self._base_folder / "dtd"
        self._meta_folder = self._data_folder / "labels"
        self._images_folder = self._data_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Each line of the split file has the form "<class>/<file name>".
        self._image_files = []
        classes = []
        with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
            for line in file:
                cls, name = line.strip().split("/")
                self._image_files.append(self._images_folder.joinpath(cls, name))
                classes.append(cls)

        # Class names are the sorted, unique names that occur in the selected split file.
        self.classes = sorted(set(classes))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self._labels = [self.class_to_idx[cls] for cls in classes]

    def __len__(self) -> int:
        """Return the number of samples listed in the selected split file."""
        return len(self._image_files)

    def __getitem__(self, idx):
        """Load the sample at ``idx`` and return it as ``(image, label)``.

        The image is opened with PIL and converted to RGB before ``transform`` is applied;
        the label is the integer class index, passed through ``target_transform`` if given.
        """
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        """Return the split and partition as a string."""
        return f"split={self._split}, partition={self._partition}"

    def _check_exists(self) -> bool:
        """Return whether the extracted data folder is present as a directory."""
        # ``isdir`` is already False for paths that do not exist, so no separate existence check is needed.
        return os.path.isdir(self._data_folder)

    def _download(self) -> None:
        """Download and extract the archive into the base folder unless the data folder already exists."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .celeba import CelebA
33
from .cifar import Cifar10, Cifar100
44
from .coco import Coco
5+
from .dtd import DTD
56
from .imagenet import ImageNet
67
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
78
from .sbd import SBD
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
banded
2+
blotchy
3+
braided
4+
bubbly
5+
bumpy
6+
chequered
7+
cobwebbed
8+
cracked
9+
crosshatched
10+
crystalline
11+
dotted
12+
fibrous
13+
flecked
14+
freckled
15+
frilly
16+
gauzy
17+
grid
18+
grooved
19+
honeycombed
20+
interlaced
21+
knitted
22+
lacelike
23+
lined
24+
marbled
25+
matted
26+
meshed
27+
paisley
28+
perforated
29+
pitted
30+
pleated
31+
polka-dotted
32+
porous
33+
potholed
34+
scaly
35+
smeared
36+
spiralled
37+
sprinkled
38+
stained
39+
stratified
40+
striped
41+
studded
42+
swirly
43+
veined
44+
waffled
45+
woven
46+
wrinkled
47+
zigzagged
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import io
2+
import pathlib
3+
from typing import Any, Callable, Dict, List, Optional, Tuple
4+
5+
import torch
6+
from torchdata.datapipes.iter import (
7+
IterDataPipe,
8+
Mapper,
9+
Shuffler,
10+
Filter,
11+
IterKeyZipper,
12+
Demultiplexer,
13+
LineReader,
14+
CSVParser,
15+
)
16+
from torchvision.prototype.datasets.utils import (
17+
Dataset,
18+
DatasetConfig,
19+
DatasetInfo,
20+
HttpResource,
21+
OnlineResource,
22+
DatasetType,
23+
)
24+
from torchvision.prototype.datasets.utils._internal import (
25+
INFINITE_BUFFER_SIZE,
26+
hint_sharding,
27+
path_comparator,
28+
getitem,
29+
)
30+
from torchvision.prototype.features import Label
31+
32+
33+
class DTD(Dataset):
    """Prototype datapipe-based dataset for the Describable Textures Dataset (DTD)."""

    def _make_info(self) -> DatasetInfo:
        """Describe the dataset: name, type, homepage and the valid ``split`` / ``fold`` options."""
        return DatasetInfo(
            "dtd",
            type=DatasetType.IMAGE,
            homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
            valid_options=dict(
                split=("train", "test", "val"),
                # Folds are exposed as the strings "1" through "10".
                fold=tuple(str(fold) for fold in range(1, 11)),
            ),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the single archive resource; it is the same for every configuration."""
        archive = HttpResource(
            "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
            sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
            decompress=True,
        )
        return [archive]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        """Map an archive entry to its demultiplexer branch.

        Returns:
            ``0`` for a file directly inside ``labels`` other than the joint annotations,
            ``1`` for ``labels/labels_joint_anno.txt``,
            ``2`` for a file located at ``images/<folder>/<file>``,
            ``None`` for everything else.
        """
        path = pathlib.Path(data[0])
        if path.parent.name == "labels":
            if path.name == "labels_joint_anno.txt":
                return 1

            return 0
        # ``parent.parent`` equals ``parents[1]`` whenever the latter exists, but does not raise
        # an IndexError for paths that have fewer than two parent components.
        elif path.parent.parent.name == "images":
            return 2
        else:
            return None

    def _image_key_fn(self, data: Tuple[str, Any]) -> str:
        """Return the ``<folder>/<file>`` key of an image entry for joining with the split lines."""
        path = pathlib.Path(data[0])
        # The key is compared against "/"-separated ids, so it has to use forward slashes on
        # every platform; ``str(path)`` would produce backslashes on Windows.
        return path.relative_to(path.parents[1]).as_posix()

    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[str, List[str]], Tuple[str, io.IOBase]],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        """Turn one joined ``((image id, annotation row), (path, buffer))`` item into a sample dict.

        The primary category is the name of the folder that contains the image. The image is
        passed through ``decoder`` if one is given, otherwise the raw buffer is returned.
        """
        (_, joint_categories_data), image_data = data
        # The first column of the annotation row is the image id; the rest are category names.
        _, *joint_categories = joint_categories_data
        path, buffer = image_data

        category = pathlib.Path(path).parent.name

        return dict(
            # Empty strings are dropped from the row before building the set.
            joint_categories={category for category in joint_categories if category},
            label=Label(self.info.categories.index(category), category=category),
            path=path,
            image=decoder(buffer) if decoder else buffer,
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Build the sample datapipe for ``config`` from the archive datapipe."""
        archive_dp = resource_dps[0]

        # The branch order follows the integers returned by ``_classify_archive``.
        splits_dp, joint_categories_dp, images_dp = Demultiplexer(
            archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
        )

        # Keep only the split file for the requested split and fold and read it line by line.
        splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt"))
        splits_dp = LineReader(splits_dp, decode=True, return_path=False)
        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
        splits_dp = hint_sharding(splits_dp)

        joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ")

        # First join: split line (image id) with the annotation row whose first column is that id.
        dp = IterKeyZipper(
            splits_dp,
            joint_categories_dp,
            key_fn=getitem(),
            ref_key_fn=getitem(0),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        # Second join: attach the image entry whose key matches the image id.
        dp = IterKeyZipper(
            dp,
            images_dp,
            key_fn=getitem(0),
            ref_key_fn=self._image_key_fn,
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

    def _filter_images(self, data: Tuple[str, Any]) -> bool:
        """Return whether the archive entry is classified as an image."""
        return self._classify_archive(data) == 2

    def _generate_categories(self, root: pathlib.Path) -> List[str]:
        """Return the sorted names of all folders that contain images in the archive."""
        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
        dp = Filter(dp, self._filter_images)
        return sorted({pathlib.Path(path).parent.name for path, _ in dp})

0 commit comments

Comments
 (0)