feats: add loader in classification task datasets. #8939

Merged (6 commits) on Feb 28, 2025

37 changes: 37 additions & 0 deletions test/test_datasets.py
@@ -24,6 +24,7 @@
import torch.nn.functional as F
from common_utils import combinations_grid
from torchvision import datasets
from torchvision.io import decode_image
from torchvision.transforms import v2


@@ -1175,6 +1176,8 @@ class SBUTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SBU
FEATURE_TYPES = (PIL.Image.Image, str)

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
num_images = 3

@@ -1413,6 +1416,8 @@ class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
_IMAGES_FOLDER = "images"
_ANNOTATIONS_FILE = "captions.html"

SUPPORT_TV_IMAGE_DECODE = True

def dataset_args(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)
root = tmpdir / self._IMAGES_FOLDER
@@ -1482,6 +1487,8 @@ class Flickr30kTestCase(Flickr8kTestCase):

_ANNOTATIONS_FILE = "captions.token"

SUPPORT_TV_IMAGE_DECODE = True

def _image_file_name(self, idx):
return f"{idx}.jpg"

@@ -1942,6 +1949,8 @@ class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
_IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"}
_file_id = {"10fold": "", "train": "DevTrain", "test": "DevTest"}

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) / "lfw-py"
os.makedirs(tmpdir, exist_ok=True)
@@ -1978,6 +1987,18 @@ def _create_random_id(self):
part2 = datasets_utils.create_random_string(random.randint(4, 7))
return f"{part1}_{part2}"

def test_tv_decode_image_support(self):
if not self.SUPPORT_TV_IMAGE_DECODE:
pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.")

with self.create_dataset(
config=dict(
loader=decode_image,
)
) as (dataset, _):
image = dataset[0][0]
assert isinstance(image, torch.Tensor)


class LFWPairsTestCase(LFWPeopleTestCase):
DATASET_CLASS = datasets.LFWPairs
@@ -2335,6 +2356,8 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):

ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir: str, config):
root_folder = pathlib.Path(tmpdir) / "food-101"
image_folder = root_folder / "images"
@@ -2371,6 +2394,7 @@ class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer")
)
SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir: str, config):
split = config["split"]
@@ -2420,6 +2444,8 @@ def inject_fake_data(self, tmpdir: str, config):
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SUN397

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir: str, config):
data_dir = pathlib.Path(tmpdir) / "SUN397"
data_dir.mkdir()
@@ -2451,6 +2477,8 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DTD
FEATURE_TYPES = (PIL.Image.Image, int)

SUPPORT_TV_IMAGE_DECODE = True

ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "test", "val"),
# There is no need to test the whole matrix here, since each fold is treated exactly the same
@@ -2611,6 +2639,7 @@ class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase):
FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))

ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0"
@@ -2708,6 +2737,8 @@ class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
REQUIRED_PACKAGES = ("scipy",)
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
import scipy.io as io
from numpy.core.records import fromarrays
@@ -2782,6 +2813,8 @@ class Flowers102TestCase(datasets_utils.ImageDatasetTestCase):
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
REQUIRED_PACKAGES = ("scipy",)

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir: str, config):
base_folder = pathlib.Path(tmpdir) / "flowers-102"

@@ -2840,6 +2873,8 @@ class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"}

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir: str, config):
root_folder = pathlib.Path(tmpdir) / "rendered-sst2"
image_folder = root_folder / self.SPLIT_TO_FOLDER[config["split"]]
@@ -3500,6 +3535,8 @@ class ImagenetteTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Imagenette
ADDITIONAL_CONFIGS = combinations_grid(split=["train", "val"], size=["full", "320px", "160px"])

SUPPORT_TV_IMAGE_DECODE = True

_WNIDS = [
"n01440764",
"n02102040",
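The new `test_tv_decode_image_support` above captures the user-facing contract: a dataset that sets `SUPPORT_TV_IMAGE_DECODE = True` accepts a `loader` callable, and passing `torchvision.io.decode_image` makes `__getitem__` return a tensor instead of a PIL image. A minimal usage sketch (not part of the diff; it assumes the DTD archive is already available under `data/`):

```python
# Sketch only: exercising the ``loader`` argument added by this PR.
# Assumes DTD has already been downloaded/extracted under ``data/``.
import torch
from torchvision import datasets
from torchvision.io import decode_image

dataset = datasets.DTD(root="data", split="train", loader=decode_image)
image, label = dataset[0]

assert isinstance(image, torch.Tensor)  # decode_image yields a uint8 (C, H, W) tensor
print(image.dtype, image.shape, label)
```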
13 changes: 9 additions & 4 deletions torchvision/datasets/clevr.py
@@ -3,7 +3,7 @@
from typing import Any, Callable, List, Optional, Tuple, Union
from urllib.parse import urlparse

from PIL import Image
from .folder import default_loader

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
@@ -18,11 +18,14 @@ class CLEVRClassification(VisionDataset):
root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
set to True.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in them target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If
dataset is already downloaded, it is not downloaded again.
loader (callable, optional): A function to load an image given its path.
By default, it uses PIL as its image loader, but users could also pass in
``torchvision.io.decode_image`` for decoding image data into tensors directly.
"""

_URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
@@ -35,9 +38,11 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, transform=transform, target_transform=target_transform)
self.loader = loader
self._base_folder = pathlib.Path(self.root) / "clevr"
self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem

@@ -65,7 +70,7 @@ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file = self._image_files[idx]
label = self._labels[idx]

image = Image.open(image_file).convert("RGB")
image = self.loader(image_file)

if self.transform:
image = self.transform(image)
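Because `transform` now receives whatever the loader returns, the `transforms.v2` API (which accepts both PIL images and tensors) pairs naturally with `decode_image`. A hedged sketch, assuming CLEVR is already downloaded under `data/`:

```python
# Sketch only: tensor-returning loader combined with transforms.v2,
# which accept both PIL images and torch.Tensor inputs.
import torch
from torchvision import datasets
from torchvision.io import decode_image
from torchvision.transforms import v2

transform = v2.Compose([
    v2.RandomCrop(224, pad_if_needed=True),
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
])

dataset = datasets.CLEVRClassification(
    root="data", split="train", transform=transform, loader=decode_image
)
image, label = dataset[0]  # image is a float32 CHW tensor cropped to 224x224
```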
4 changes: 2 additions & 2 deletions torchvision/datasets/country211.py
@@ -16,8 +16,8 @@ class Country211(ImageFolder):
Args:
root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and puts it into
``root/country211/``. If dataset is already downloaded, it is not downloaded again.
13 changes: 9 additions & 4 deletions torchvision/datasets/dtd.py
@@ -2,7 +2,7 @@
import pathlib
from typing import Any, Callable, Optional, Tuple, Union

import PIL.Image
from .folder import default_loader

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
@@ -21,12 +21,15 @@ class DTD(VisionDataset):
The partition only changes which split each image belongs to. Thus, regardless of the selected
partition, combining all splits will result in all images.

transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
loader (callable, optional): A function to load an image given its path.
By default, it uses PIL as its image loader, but users could also pass in
``torchvision.io.decode_image`` for decoding image data into tensors directly.
"""

_URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
@@ -40,6 +43,7 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
if not isinstance(partition, int) and not (1 <= partition <= 10):
@@ -72,13 +76,14 @@ def __init__(
self.classes = sorted(set(classes))
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
self._labels = [self.class_to_idx[cls] for cls in classes]
self.loader = loader

def __len__(self) -> int:
return len(self._image_files)

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
image = self.loader(image_file)

if self.transform:
image = self.transform(image)
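The loader does not have to be `decode_image`: per the annotation above, any `Callable[[Union[str, pathlib.Path]], Any]` is accepted, and `transform` simply receives whatever it returns. A sketch with a hypothetical grayscale PIL loader (not part of torchvision):

```python
# Sketch only: any path -> image callable satisfies the ``loader`` contract.
# ``grayscale_loader`` is a hypothetical helper, not part of torchvision.
import pathlib
from typing import Union

from PIL import Image
from torchvision import datasets


def grayscale_loader(path: Union[str, pathlib.Path]) -> Image.Image:
    with open(path, "rb") as f:
        return Image.open(f).convert("L")  # single-channel PIL image


dataset = datasets.DTD(root="data", split="train", loader=grayscale_loader)
image, _ = dataset[0]
print(image.mode)  # "L"
```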
2 changes: 1 addition & 1 deletion torchvision/datasets/eurosat.py
@@ -14,7 +14,7 @@ class EuroSAT(ImageFolder):

Args:
root (str or ``pathlib.Path``): Root directory of dataset where ``root/eurosat`` exists.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
11 changes: 8 additions & 3 deletions torchvision/datasets/fgvc_aircraft.py
@@ -4,7 +4,7 @@
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

import PIL.Image
from .folder import default_loader

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
@@ -29,13 +29,16 @@ class FGVCAircraft(VisionDataset):
``trainval`` and ``test``.
annotation_level (str, optional): The annotation level, supports ``variant``,
``family`` and ``manufacturer``.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
loader (callable, optional): A function to load an image given its path.
By default, it uses PIL as its image loader, but users could also pass in
``torchvision.io.decode_image`` for decoding image data into tensors directly.
"""

_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
@@ -48,6 +51,7 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
loader: Callable[[str], Any] = default_loader,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
@@ -87,13 +91,14 @@ def __init__(
image_name, label_name = line.strip().split(" ", 1)
self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
self._labels.append(self.class_to_idx[label_name])
self.loader = loader

def __len__(self) -> int:
return len(self._image_files)

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
image = self.loader(image_file)

if self.transform:
image = self.transform(image)
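`default_loader` preserves the previous behaviour (with the default image backend it opens the file with PIL and converts it to RGB, as the docstring notes), so existing pipelines are unchanged and the tensor path is strictly opt-in. A sketch comparing the two on FGVCAircraft, assuming the archive is already extracted under `data/`:

```python
# Sketch only: default PIL behaviour vs. the opt-in tensor path.
# Assumes fgvc-aircraft-2013b is already extracted under ``data/``.
import torch
from PIL import Image
from torchvision import datasets
from torchvision.io import decode_image

pil_ds = datasets.FGVCAircraft(root="data", split="train")  # uses default_loader
tensor_ds = datasets.FGVCAircraft(root="data", split="train", loader=decode_image)

pil_img, _ = pil_ds[0]
tensor_img, _ = tensor_ds[0]

assert isinstance(pil_img, Image.Image)      # RGB PIL image, same as before this PR
assert isinstance(tensor_img, torch.Tensor)  # uint8 tensor, shape (C, H, W)
```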