Add DTD dataset #5115
@@ -0,0 +1,100 @@

import os
import pathlib
from typing import Optional, Callable

import PIL.Image

from .utils import verify_str_arg, download_and_extract_archive
from .vision import VisionDataset


class DTD(VisionDataset):
    """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.

            .. note::

                The partition only changes which split each image belongs to. Thus, regardless of the selected
                partition, combining all splits will result in all images.

        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """

    _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
    _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"

    def __init__(
        self,
        root: str,
        split: str = "train",
        partition: int = 1,
        download: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        if not isinstance(partition, int) or not (1 <= partition <= 10):
            raise ValueError(
                f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
                f"but got {partition} instead"
            )
        self._partition = partition

        super().__init__(root, transform=transform, target_transform=target_transform)
        self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
        self._data_folder = self._base_folder / "dtd"
        self._meta_folder = self._data_folder / "labels"
        self._images_folder = self._data_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Each line of e.g. labels/train1.txt has the form "<category>/<file name>".
        self._image_files = []
        classes = []
        with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
            for line in file:
                cls, name = line.strip().split("/")
                self._image_files.append(self._images_folder.joinpath(cls, name))
                classes.append(cls)

        self.classes = sorted(set(classes))
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
        self._labels = [self.class_to_idx[cls] for cls in classes]

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx):
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}, partition={self._partition}"

    def _check_exists(self) -> bool:
        return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)

    def _download(self) -> None:
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
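
For reference, a minimal usage sketch of the new class, assuming it is exported from torchvision.datasets and the archive can be downloaded; the root path is arbitrary:

from torchvision.datasets import DTD

# Downloads and extracts dtd-r1.0.1.tar.gz into <root>/dtd on first use.
dataset = DTD(root="data", split="train", partition=1, download=True)
image, label = dataset[0]       # PIL.Image.Image, int
print(len(dataset))             # 1880 (each DTD split holds 40 images per class)
print(dataset.classes[label])   # e.g. "banded"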
@@ -0,0 +1,47 @@
banded
blotchy
braided
bubbly
bumpy
chequered
cobwebbed
cracked
crosshatched
crystalline
dotted
fibrous
flecked
freckled
frilly
gauzy
grid
grooved
honeycombed
interlaced
knitted
lacelike
lined
marbled
matted
meshed
paisley
perforated
pitted
pleated
polka-dotted
porous
potholed
scaly
smeared
spiralled
sprinkled
stained
stratified
striped
studded
swirly
veined
waffled
woven
wrinkled
zigzagged
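
For illustration, a minimal sketch of turning such a categories file into a label mapping; the file name "dtd.categories" is an assumption based on the prototype datasets' one-file-per-dataset convention, and the prototype infrastructure performs this loading itself:

import pathlib

# Hypothetical on-disk location of the categories file shown above.
categories = pathlib.Path("dtd.categories").read_text().splitlines()
category_to_idx = {category: idx for idx, category in enumerate(categories)}
assert len(categories) == 47  # DTD defines 47 texture classes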
@@ -0,0 +1,130 @@
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torchdata.datapipes.iter import (
    IterDataPipe,
    Mapper,
    Shuffler,
    Filter,
    IterKeyZipper,
    Demultiplexer,
    LineReader,
    CSVParser,
)
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
    DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    hint_sharding,
    path_comparator,
    getitem,
)
from torchvision.prototype.features import Label


class DTD(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "dtd",
            type=DatasetType.IMAGE,
            homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
            valid_options=dict(
                split=("train", "test", "val"),
                fold=tuple(str(fold) for fold in range(1, 11)),
            ),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        archive = HttpResource(
            "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
            sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
            decompress=True,
        )
        return [archive]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        # Route each archive member to one of three Demultiplexer outputs:
        # 0 -> split files, 1 -> joint annotations, 2 -> images, None -> dropped.
        path = pathlib.Path(data[0])
        if path.parent.name == "labels":
            if path.name == "labels_joint_anno.txt":
                return 1
            return 0
        elif path.parents[1].name == "images":
            return 2
        else:
            return None
Comment on lines +56 to +61:

Reviewer: Instead of the hardcoded 0 1 2, would a small private enum be overkill here? Or maybe just hardcoded named constants? No strong opinion, but this might help readability.

Author: Yes, that could work. If that is ok with you, I'll put it in my backlog, because we should do that for all datasets that use a …
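
Not part of the PR, just a rough sketch of the reviewer's suggestion; the name _ArchiveMember and its members are hypothetical:

import enum

class _ArchiveMember(enum.IntEnum):
    # Named targets for the three Demultiplexer outputs.
    SPLIT_FILES = 0
    JOINT_CATEGORIES = 1
    IMAGES = 2

# _classify_archive could then return e.g. _ArchiveMember.IMAGES instead of a bare 2.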
    def _image_key_fn(self, data: Tuple[str, Any]) -> str:
        path = pathlib.Path(data[0])
        return str(path.relative_to(path.parents[1]))

    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[str, List[str]], Tuple[str, io.IOBase]],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        (_, joint_categories_data), image_data = data
        _, *joint_categories = joint_categories_data
        path, buffer = image_data

        category = pathlib.Path(path).parent.name

        return dict(
            joint_categories={category for category in joint_categories if category},
            label=Label(self.info.categories.index(category), category=category),
            path=path,
            image=decoder(buffer) if decoder else buffer,
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]

        splits_dp, joint_categories_dp, images_dp = Demultiplexer(
            archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
        )

        splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt"))
        splits_dp = LineReader(splits_dp, decode=True, return_path=False)
        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
        splits_dp = hint_sharding(splits_dp)

        joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ")

        # First match each split line ("<category>/<file name>") to its joint
        # annotations, then match the result to the actual image in the archive.
        dp = IterKeyZipper(
            splits_dp,
            joint_categories_dp,
            key_fn=getitem(),
            ref_key_fn=getitem(0),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        dp = IterKeyZipper(
            dp,
            images_dp,
            key_fn=getitem(0),
            ref_key_fn=self._image_key_fn,
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

    def _filter_images(self, data: Tuple[str, Any]) -> bool:
        return self._classify_archive(data) == 2

    def _generate_categories(self, root: pathlib.Path) -> List[str]:
        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
        dp = Filter(dp, self._filter_images)
        return sorted({pathlib.Path(path).parent.name for path, _ in dp})
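
To make the Demultiplexer routing above concrete, here is a tiny self-contained sketch that reuses the same classification rule, with made-up archive paths and torchdata's IterableWrapper standing in for the real archive datapipe:

import pathlib
from typing import Any, Optional, Tuple

from torchdata.datapipes.iter import IterableWrapper, Demultiplexer

def classify(data: Tuple[str, Any]) -> Optional[int]:
    # Same routing rule as DTD._classify_archive above.
    path = pathlib.Path(data[0])
    if path.parent.name == "labels":
        return 1 if path.name == "labels_joint_anno.txt" else 0
    elif path.parents[1].name == "images":
        return 2
    return None

entries = [
    ("dtd/labels/train1.txt", None),
    ("dtd/labels/labels_joint_anno.txt", None),
    ("dtd/images/banded/banded_0002.jpg", None),
]
splits, joint, images = Demultiplexer(IterableWrapper(entries), 3, classify, drop_none=True)
print(list(splits))  # the split file
print(list(joint))   # the joint-annotation file
print(list(images))  # the image file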