-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add Rendered sst2 dataset #5220
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
31fadbe
1e578b7
85e4429
4e3d900
615b612
a0bbece
ba966f4
6cdd49b
d4f1638
069bba4
409dcad
78f5e45
e6c95ad
39b2441
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -2665,5 +2665,27 @@ def inject_fake_data(self, tmpdir: str, config): | |||||
return num_images | ||||||
|
||||||
|
||||||
class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
    """Fake-data test case for ``datasets.RenderedSST2``, run once per split."""

    DATASET_CLASS = datasets.RenderedSST2
    FEATURE_TYPES = (PIL.Image.Image, int)
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "valid", "test"))

    def inject_fake_data(self, tmpdir: str, config):
        """Populate ``tmpdir`` with fake images for the split named in ``config``.

        Images are created through ``datasets_utils.create_image_folder`` below
        ``<tmpdir>/rendered-sst2/<split>``, one call per class.

        Args:
            tmpdir: Temporary directory that acts as the dataset root.
            config: Test configuration; only ``config["split"]`` is read here.

        Returns:
            The total number of images requested across all classes.
        """
        root_folder = pathlib.Path(tmpdir) / "rendered-sst2"
        image_folder = root_folder / config["split"]

        # Use a different number of images per class so the test does not rely on the
        # classes being balanced.
        num_images_per_class = {"positive": 5, "negative": 3}
        for cls, num_images in num_images_per_class.items():
            datasets_utils.create_image_folder(
                image_folder,
                cls,
                file_name_fn=lambda idx: f"{idx}.png",
                num_examples=num_images,
            )

        return sum(num_images_per_class.values())
|
||||||
|
||||||
if __name__ == "__main__":
    # Entry point: hand control to unittest's runner when the module is executed directly.
    unittest.main()
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,91 @@ | ||||||
from pathlib import Path | ||||||
from typing import Any, Tuple, Callable, Optional | ||||||
|
||||||
import PIL.Image | ||||||
|
||||||
from .utils import verify_str_arg, download_and_extract_archive | ||||||
from .vision import VisionDataset | ||||||
|
||||||
|
||||||
class RenderedSST2(VisionDataset):
    """`The Rendered SST2 Dataset <https://github.com/openai/CLIP/blob/main/data/rendered-sst2.md>`_.

    Rendered SST2 is an image classification dataset used to evaluate a model's capability on optical
    character recognition. This dataset was generated by rendering sentences in the Stanford Sentiment
    Treebank v2 dataset.

    This dataset contains two classes (positive and negative) and is divided in three splits: a train
    split containing 6920 images (3610 positive and 3310 negative), a validation split containing 872 images
    (444 positive and 428 negative), and a test split containing 1821 images (909 positive and 912 negative).

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
        download (bool, optional): If True (default), downloads and extracts the archive into ``root``
            unless the class folders of the requested split are already present.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    Raises:
        RuntimeError: If the class folders of the requested split cannot be found
            (after the optional download).
    """

    _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
    _MD5 = "2384d08e9dcfa4bd55b324e610496ee5"

    def __init__(
        self,
        root: str,
        split: str = "train",
        download: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
        self._base_folder = Path(self.root) / "rendered-sst2"
        self.classes = ["negative", "positive"]
        self.class_to_idx = {"negative": 0, "positive": 1}

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Sort the matches so the sample order does not depend on the order in which the
        # file system happens to enumerate the files.
        # NOTE(review): reviewers suggested the ``make_dataset`` folder helper could build
        # these lists instead -- confirm before switching.
        self._image_files = sorted((self._base_folder / self._split).glob("**/*.png"))
        # The label of an image is given by the name of the folder that contains it.
        self._labels = [self.class_to_idx[path.parent.name] for path in self._image_files]

    def __len__(self) -> int:
        """Return the number of images found for the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load the sample at position ``idx``.

        Returns:
            A tuple ``(image, label)``: the image converted to RGB and its class index,
            each passed through the corresponding transform when one was given.
        """
        image_file, label = self._image_files[idx], self._labels[idx]
        # Close the underlying file as soon as the pixel data has been converted.
        with PIL.Image.open(image_file) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        """Return the split description shown in the dataset's ``repr``."""
        return f"split={self._split}"

    def _check_exists(self) -> bool:
        """Return True if every class has a directory inside the selected split folder."""
        split_folder = self._base_folder / self._split
        # ``is_dir()`` is already False for paths that do not exist, so no separate
        # ``exists()`` check is needed.
        return all((split_folder / class_label).is_dir() for class_label in self.classes)

    def _download(self) -> None:
        """Download and extract the archive unless the split folders are already present."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: we don't need this line as it's the default of the
datasets_utils.ImageDatasetTestCase
class