Skip to content

Add DTD dataset #5115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Cityscapes
CocoCaptions
CocoDetection
DTD
EMNIST
FakeData
FashionMNIST
Expand Down
36 changes: 36 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2205,5 +2205,41 @@ def inject_fake_data(self, tmpdir: str, config):
return len(sampled_classes * n_samples_per_class)


class DTDTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DTD
FEATURE_TYPES = (PIL.Image.Image, int)

ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train", "test", "val"),
# There is no need to test the whole matrix here, since each fold is treated exactly the same
partition=(1, 5, 10),
)

def inject_fake_data(self, tmpdir: str, config):
data_folder = pathlib.Path(tmpdir) / "dtd" / "dtd"

num_images_per_class = 3
image_folder = data_folder / "images"
image_files = []
for cls in ("banded", "marbled", "zigzagged"):
image_files.extend(
datasets_utils.create_image_folder(
image_folder,
cls,
file_name_fn=lambda idx: f"{cls}_{idx:04d}.jpg",
num_examples=num_images_per_class,
)
)

meta_folder = data_folder / "labels"
meta_folder.mkdir()
image_ids = [str(path.relative_to(path.parents[1])).replace(os.sep, "/") for path in image_files]
image_ids_in_config = random.choices(image_ids, k=len(image_files) // 2)
with open(meta_folder / f"{config['split']}{config['partition']}.txt", "w") as file:
file.write("\n".join(image_ids_in_config) + "\n")

return len(image_ids_in_config)


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .cifar import CIFAR10, CIFAR100
from .cityscapes import Cityscapes
from .coco import CocoCaptions, CocoDetection
from .dtd import DTD
from .fakedata import FakeData
from .flickr import Flickr8k, Flickr30k
from .folder import ImageFolder, DatasetFolder
Expand Down Expand Up @@ -79,4 +80,5 @@
"FlyingThings3D",
"HD1K",
"Food101",
"DTD",
)
100 changes: 100 additions & 0 deletions torchvision/datasets/dtd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
import pathlib
from typing import Optional, Callable

import PIL.Image

from .utils import verify_str_arg, download_and_extract_archive
from .vision import VisionDataset


class DTD(VisionDataset):
"""`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.

Args:
root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.

.. note::

The partition only changes which split each image belongs to. Thus, regardless of the selected
partition, combining all splits will result in all images.

download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
"""

_URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
_MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"

def __init__(
self,
root: str,
split: str = "train",
partition: int = 1,
download: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
if not isinstance(partition, int) and not (1 <= partition <= 10):
raise ValueError(
f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
f"but got {partition} instead"
)
self._partition = partition

super().__init__(root, transform=transform, target_transform=target_transform)
self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
self._data_folder = self._base_folder / "dtd"
self._meta_folder = self._data_folder / "labels"
self._images_folder = self._data_folder / "images"

if download:
self._download()

if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")

self._image_files = []
classes = []
with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
for line in file:
cls, name = line.strip().split("/")
self._image_files.append(self._images_folder.joinpath(cls, name))
classes.append(cls)

self.classes = sorted(set(classes))
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
self._labels = [self.class_to_idx[cls] for cls in classes]

def __len__(self) -> int:
return len(self._image_files)

def __getitem__(self, idx):
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")

if self.transform:
image = self.transform(image)

if self.target_transform:
label = self.target_transform(label)

return image, label

def extra_repr(self) -> str:
return f"split={self._split}, partition={self._partition}"

def _check_exists(self) -> bool:
return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)

def _download(self) -> None:
if self._check_exists():
return
download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .coco import Coco
from .dtd import DTD
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .sbd import SBD
Expand Down
47 changes: 47 additions & 0 deletions torchvision/prototype/datasets/_builtin/dtd.categories
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
banded
blotchy
braided
bubbly
bumpy
chequered
cobwebbed
cracked
crosshatched
crystalline
dotted
fibrous
flecked
freckled
frilly
gauzy
grid
grooved
honeycombed
interlaced
knitted
lacelike
lined
marbled
matted
meshed
paisley
perforated
pitted
pleated
polka-dotted
porous
potholed
scaly
smeared
spiralled
sprinkled
stained
stratified
striped
studded
swirly
veined
waffled
woven
wrinkled
zigzagged
130 changes: 130 additions & 0 deletions torchvision/prototype/datasets/_builtin/dtd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
Shuffler,
Filter,
IterKeyZipper,
Demultiplexer,
LineReader,
CSVParser,
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
hint_sharding,
path_comparator,
getitem,
)
from torchvision.prototype.features import Label


class DTD(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"dtd",
type=DatasetType.IMAGE,
homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
valid_options=dict(
split=("train", "test", "val"),
fold=tuple(str(fold) for fold in range(1, 11)),
),
)

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
archive = HttpResource(
"https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
decompress=True,
)
return [archive]

def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
if path.parent.name == "labels":
if path.name == "labels_joint_anno.txt":
return 1

return 0
elif path.parents[1].name == "images":
return 2
Comment on lines +56 to +61
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of hardcoded 0 1 2, would a small private enum be overkill here? or maybe just hardcoded named constants? No strong opinion but this might help readability

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that could work. If that is ok with you, I'll put it in my backlog, because we should do that for all datasets that use a Demultiplexer.

else:
return None

def _image_key_fn(self, data: Tuple[str, Any]) -> str:
path = pathlib.Path(data[0])
return str(path.relative_to(path.parents[1]))

def _collate_and_decode_sample(
self,
data: Tuple[Tuple[str, List[str]], Tuple[str, io.IOBase]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
(_, joint_categories_data), image_data = data
_, *joint_categories = joint_categories_data
path, buffer = image_data

category = pathlib.Path(path).parent.name

return dict(
joint_categories={category for category in joint_categories if category},
label=Label(self.info.categories.index(category), category=category),
path=path,
image=decoder(buffer) if decoder else buffer,
)

def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]

splits_dp, joint_categories_dp, images_dp = Demultiplexer(
archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)

splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt"))
splits_dp = LineReader(splits_dp, decode=True, return_path=False)
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
splits_dp = hint_sharding(splits_dp)

joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ")

dp = IterKeyZipper(
splits_dp,
joint_categories_dp,
key_fn=getitem(),
ref_key_fn=getitem(0),
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = IterKeyZipper(
dp,
images_dp,
key_fn=getitem(0),
ref_key_fn=self._image_key_fn,
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

def _filter_images(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == 2

def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
dp = Filter(dp, self._filter_images)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})