Skip to content

Add Rendered sst2 dataset #5220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
PCAM
PhotoTour
Places365
RenderedSST2
QMNIST
SBDataset
SBU
Expand Down
22 changes: 22 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2665,5 +2665,27 @@ def inject_fake_data(self, tmpdir: str, config):
return num_images


class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.RenderedSST2
FEATURE_TYPES = (PIL.Image.Image, int)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: we don't need this line as it's the default of the datasets_utils.ImageDatasetTestCase class

ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "valid", "test"))

def inject_fake_data(self, tmpdir: str, config):
root_folder = pathlib.Path(tmpdir) / "rendered-sst2"
image_folder = root_folder / config["split"]

num_images_per_class = 5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To slightly increase robustness:

Suggested change
num_images_per_class = 5
num_images_per_class = {"train": 5, "test": 6, "val": 7}

sampled_classes = ["positive", "negative"]
for cls in sampled_classes:
datasets_utils.create_image_folder(
image_folder,
cls,
file_name_fn=lambda idx: f"{idx}.png",
num_examples=num_images_per_class,
)

return len(sampled_classes) * num_images_per_class


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .pcam import PCAM
from .phototour import PhotoTour
from .places365 import Places365
from .rendered_sst2 import RenderedSST2
from .sbd import SBDataset
from .sbu import SBU
from .semeion import SEMEION
Expand Down Expand Up @@ -102,4 +103,5 @@
"Country211",
"FGVCAircraft",
"EuroSAT",
"RenderedSST2",
)
91 changes: 91 additions & 0 deletions torchvision/datasets/rendered_sst2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from pathlib import Path
from typing import Any, Tuple, Callable, Optional

import PIL.Image

from .utils import verify_str_arg, download_and_extract_archive
from .vision import VisionDataset


class RenderedSST2(VisionDataset):
"""`The Rendered SST2 Dataset <https://github.com/openai/CLIP/blob/main/data/rendered-sst2.md>`_.

Rendered SST2 is a image classification dataset used to evaluate the models capability on optical
character recognition. This dataset was generated bu rendering sentences in the Standford Sentiment
Treebank v2 dataset.

This dataset contains two classes (positive and negative) and is divided in three splits: a train
split containing 6920 images (3610 positive and 3310 negative), a validation split containing 872 images
(444 positive and 428 negative), and a test split containing 1821 images (909 positive and 912 negative).

Args:
root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), `"valid"` and ``"test"``.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
"""

_URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
_MD5 = "2384d08e9dcfa4bd55b324e610496ee5"

def __init__(
self,
root: str,
split: str = "train",
download: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
self._base_folder = Path(self.root) / "rendered-sst2"
self.classes = ["negative", "positive"]
self.class_to_idx = {"negative": 0, "positive": 1}

if download:
self._download()

if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")

self._labels = []
self._image_files = []

for p in (self._base_folder / self._split).glob("**/*.png"):
self._labels.append(self.class_to_idx[p.parent.name])
self._image_files.append(p)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is something that make_dataset() could be used for. But the code here is very simple so IMHO it's fine to keep as-is (perhaps @pmeier can share his thoughts).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, make_dataset could make this even shorter:

Either

self._image_files, self._labels = zip(*make_dataset(str(self._base_folder / self._split)))

or

self._samples = make_dataset(str(self._base_folder / self._split))

and do

image_file, label = self._samples[idx]

in __getitem__.

No strong opinion, but if we go for make_dataset, I would prefer the latter option.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took care of that in #5164 !

print(self._labels)
print(self._image_files)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops :p

Suggested change
print(self._labels)
print(self._image_files)


def __len__(self) -> int:
return len(self._image_files)

def __getitem__(self, idx) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")

if self.transform:
image = self.transform(image)

if self.target_transform:
label = self.target_transform(label)

return image, label

def extra_repr(self) -> str:
return f"split={self._split}"

def _check_exists(self) -> bool:
for class_label in set(self.classes):
if not (
(self._base_folder / self._split / class_label).exists()
and (self._base_folder / self._split / class_label).is_dir()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think that is_dir() properly returns False when the directory does not exist, so perhaps we can avoid using exists():

In [1]: from pathlib import Path

In [2]: Path("alfjnaljefeajlfbaeljnaljen").is_dir()
Out[2]: False

):
return False
return True

def _download(self) -> None:
if self._check_exists():
return
download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)