From 2f710187864833d88f5f4bd74966c613225d78a6 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 26 Feb 2025 01:56:14 +0800 Subject: [PATCH 1/5] feats: add loader in classification task datasets. --- test/test_datasets.py | 37 +++++++++++++++++++++++++++ torchvision/datasets/clevr.py | 13 +++++++--- torchvision/datasets/country211.py | 4 +-- torchvision/datasets/dtd.py | 13 +++++++--- torchvision/datasets/eurosat.py | 2 +- torchvision/datasets/fgvc_aircraft.py | 11 +++++--- torchvision/datasets/flickr.py | 25 ++++++++++++------ torchvision/datasets/flowers102.py | 14 +++++++--- torchvision/datasets/folder.py | 2 +- torchvision/datasets/food101.py | 13 +++++++--- torchvision/datasets/imagenet.py | 2 +- torchvision/datasets/imagenette.py | 15 ++++++----- torchvision/datasets/lfw.py | 27 ++++++++++--------- torchvision/datasets/rendered_sst2.py | 15 ++++++----- torchvision/datasets/sbu.py | 11 +++++--- torchvision/datasets/stanford_cars.py | 15 +++++++---- torchvision/datasets/sun397.py | 13 +++++++--- 17 files changed, 164 insertions(+), 68 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index f98a18372a5..1413d2c312d 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from common_utils import combinations_grid from torchvision import datasets +from torchvision.io import decode_image from torchvision.transforms import v2 @@ -1175,6 +1176,8 @@ class SBUTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.SBU FEATURE_TYPES = (PIL.Image.Image, str) + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir, config): num_images = 3 @@ -1413,6 +1416,8 @@ class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase): _IMAGES_FOLDER = "images" _ANNOTATIONS_FILE = "captions.html" + SUPPORT_TV_IMAGE_DECODE = True + def dataset_args(self, tmpdir, config): tmpdir = pathlib.Path(tmpdir) root = tmpdir / self._IMAGES_FOLDER @@ -1482,6 +1487,8 @@ class 
Flickr30kTestCase(Flickr8kTestCase): _ANNOTATIONS_FILE = "captions.token" + SUPPORT_TV_IMAGE_DECODE = True + def _image_file_name(self, idx): return f"{idx}.jpg" @@ -1942,6 +1949,8 @@ class LFWPeopleTestCase(datasets_utils.DatasetTestCase): _IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"} _file_id = {"10fold": "", "train": "DevTrain", "test": "DevTest"} + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir, config): tmpdir = pathlib.Path(tmpdir) / "lfw-py" os.makedirs(tmpdir, exist_ok=True) @@ -1978,6 +1987,18 @@ def _create_random_id(self): part2 = datasets_utils.create_random_string(random.randint(4, 7)) return f"{part1}_{part2}" + def test_tv_decode_image_support(self): + if not self.SUPPORT_TV_IMAGE_DECODE: + pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.") + + with self.create_dataset( + config=dict( + loader=decode_image, + ) + ) as (dataset, _): + image = dataset[0][0] + assert isinstance(image, torch.Tensor) + class LFWPairsTestCase(LFWPeopleTestCase): DATASET_CLASS = datasets.LFWPairs @@ -2335,6 +2356,8 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase): ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test")) + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir: str, config): root_folder = pathlib.Path(tmpdir) / "food-101" image_folder = root_folder / "images" @@ -2371,6 +2394,7 @@ class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase): ADDITIONAL_CONFIGS = combinations_grid( split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer") ) + SUPPORT_TV_IMAGE_DECODE = True def inject_fake_data(self, tmpdir: str, config): split = config["split"] @@ -2420,6 +2444,8 @@ def inject_fake_data(self, tmpdir: str, config): class SUN397TestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.SUN397 + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir: 
str, config): data_dir = pathlib.Path(tmpdir) / "SUN397" data_dir.mkdir() @@ -2451,6 +2477,8 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.DTD FEATURE_TYPES = (PIL.Image.Image, int) + SUPPORT_TV_IMAGE_DECODE = True + ADDITIONAL_CONFIGS = combinations_grid( split=("train", "test", "val"), # There is no need to test the whole matrix here, since each fold is treated exactly the same @@ -2611,6 +2639,7 @@ class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase): FEATURE_TYPES = (PIL.Image.Image, (int, type(None))) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test")) + SUPPORT_TV_IMAGE_DECODE = True def inject_fake_data(self, tmpdir, config): data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0" @@ -2708,6 +2737,8 @@ class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase): REQUIRED_PACKAGES = ("scipy",) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test")) + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir, config): import scipy.io as io from numpy.core.records import fromarrays @@ -2782,6 +2813,8 @@ class Flowers102TestCase(datasets_utils.ImageDatasetTestCase): ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test")) REQUIRED_PACKAGES = ("scipy",) + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir: str, config): base_folder = pathlib.Path(tmpdir) / "flowers-102" @@ -2840,6 +2873,8 @@ class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase): ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test")) SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"} + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir: str, config): root_folder = pathlib.Path(tmpdir) / "rendered-sst2" image_folder = root_folder / self.SPLIT_TO_FOLDER[config["split"]] @@ -3500,6 +3535,8 @@ class ImagenetteTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Imagenette ADDITIONAL_CONFIGS = 
combinations_grid(split=["train", "val"], size=["full", "320px", "160px"]) + SUPPORT_TV_IMAGE_DECODE = True + _WNIDS = [ "n01440764", "n02102040", diff --git a/torchvision/datasets/clevr.py b/torchvision/datasets/clevr.py index 328eb7d79da..26cee162d32 100644 --- a/torchvision/datasets/clevr.py +++ b/torchvision/datasets/clevr.py @@ -3,7 +3,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union from urllib.parse import urlparse -from PIL import Image +from .folder import default_loader from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset @@ -18,11 +18,14 @@ class CLEVRClassification(VisionDataset): root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is set to True. split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. - transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed - version. E.g, ``transforms.RandomCrop`` + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, + and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in them target and transforms it. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. 
""" _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip" @@ -35,9 +38,11 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: self._split = verify_str_arg(split, "split", ("train", "val", "test")) super().__init__(root, transform=transform, target_transform=target_transform) + self.loader = loader self._base_folder = pathlib.Path(self.root) / "clevr" self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem @@ -65,7 +70,7 @@ def __getitem__(self, idx: int) -> Tuple[Any, Any]: image_file = self._image_files[idx] label = self._labels[idx] - image = Image.open(image_file).convert("RGB") + image = self.loader(image_file) if self.transform: image = self.transform(image) diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index 26b49552771..50d49db00a7 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -16,8 +16,8 @@ class Country211(ImageFolder): Args: root (str or ``pathlib.Path``): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``. - transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed - version. E.g, ``transforms.RandomCrop``. + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, + and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/country211/``. If dataset is already downloaded, it is not downloaded again. 
diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index 71c556bd201..edf96c1ae57 100644 --- a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -2,7 +2,7 @@ import pathlib from typing import Any, Callable, Optional, Tuple, Union -import PIL.Image +from .folder import default_loader from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset @@ -21,12 +21,15 @@ class DTD(VisionDataset): The partition only changes which split each image belongs to. Thus, regardless of the selected partition, combining all splits will result in all images. - transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed - version. E.g, ``transforms.RandomCrop``. + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader, + and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Default is False. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly.
""" _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz" @@ -40,6 +43,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: self._split = verify_str_arg(split, "split", ("train", "val", "test")) if not isinstance(partition, int) and not (1 <= partition <= 10): @@ -72,13 +76,14 @@ def __init__( self.classes = sorted(set(classes)) self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) self._labels = [self.class_to_idx[cls] for cls in classes] + self.loader = loader def __len__(self) -> int: return len(self._image_files) def __getitem__(self, idx: int) -> Tuple[Any, Any]: image_file, label = self._image_files[idx], self._labels[idx] - image = PIL.Image.open(image_file).convert("RGB") + image = self.loader(image_file) if self.transform: image = self.transform(image) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index 5b96b067fba..4efec57029f 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -14,7 +14,7 @@ class EuroSAT(ImageFolder): Args: root (str or ``pathlib.Path``): Root directory of dataset where ``root/eurosat`` exists. - transform (callable, optional): A function/transform that takes in a PIL image + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
diff --git a/torchvision/datasets/fgvc_aircraft.py b/torchvision/datasets/fgvc_aircraft.py index bbf4e970a78..c0f2e147614 100644 --- a/torchvision/datasets/fgvc_aircraft.py +++ b/torchvision/datasets/fgvc_aircraft.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Any, Callable, Optional, Tuple, Union -import PIL.Image +from .folder import default_loader from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset @@ -29,13 +29,16 @@ class FGVCAircraft(VisionDataset): ``trainval`` and ``test``. annotation_level (str, optional): The annotation level, supports ``variant``, ``family`` and ``manufacturer``. - transform (callable, optional): A function/transform that takes in a PIL image + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly.
""" _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz" @@ -48,6 +51,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test")) @@ -87,13 +91,14 @@ def __init__( image_name, label_name = line.strip().split(" ", 1) self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg")) self._labels.append(self.class_to_idx[label_name]) + self.loader = loader def __len__(self) -> int: return len(self._image_files) def __getitem__(self, idx: int) -> Tuple[Any, Any]: image_file, label = self._image_files[idx], self._labels[idx] - image = PIL.Image.open(image_file).convert("RGB") + image = self.loader(image_file) if self.transform: image = self.transform(image) diff --git a/torchvision/datasets/flickr.py b/torchvision/datasets/flickr.py index 1021309db05..c80978368dd 100644 --- a/torchvision/datasets/flickr.py +++ b/torchvision/datasets/flickr.py @@ -5,8 +5,7 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from PIL import Image - +from .folder import default_loader from .vision import VisionDataset @@ -59,10 +58,13 @@ class Flickr8k(VisionDataset): Args: root (str or ``pathlib.Path``): Root directory where images are downloaded to. ann_file (string): Path to annotation file. - transform (callable, optional): A function/transform that takes in a PIL image - and returns a transformed version. E.g, ``transforms.PILToTensor`` + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, + and returns a transformed version. 
E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ def __init__( @@ -71,6 +73,7 @@ def __init__( ann_file: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self.ann_file = os.path.expanduser(ann_file) @@ -82,6 +85,7 @@ def __init__( self.annotations = parser.annotations self.ids = list(sorted(self.annotations.keys())) + self.loader = loader def __getitem__(self, index: int) -> Tuple[Any, Any]: """ @@ -94,7 +98,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: img_id = self.ids[index] # Image - img = Image.open(img_id).convert("RGB") + img = self.loader(img_id) if self.transform is not None: img = self.transform(img) @@ -115,10 +119,13 @@ class Flickr30k(VisionDataset): Args: root (str or ``pathlib.Path``): Root directory where images are downloaded to. ann_file (string): Path to annotation file. - transform (callable, optional): A function/transform that takes in a PIL image - and returns a transformed version. E.g, ``transforms.PILToTensor`` + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, + and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. 
""" def __init__( @@ -127,6 +134,7 @@ def __init__( ann_file: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self.ann_file = os.path.expanduser(ann_file) @@ -139,6 +147,7 @@ def __init__( self.annotations[img_id[:-2]].append(caption) self.ids = list(sorted(self.annotations.keys())) + self.loader = loader def __getitem__(self, index: int) -> Tuple[Any, Any]: """ @@ -152,7 +161,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: # Image filename = os.path.join(self.root, img_id) - img = Image.open(filename).convert("RGB") + img = self.loader(filename) if self.transform is not None: img = self.transform(img) diff --git a/torchvision/datasets/flowers102.py b/torchvision/datasets/flowers102.py index 707a7687858..c5ca5680037 100644 --- a/torchvision/datasets/flowers102.py +++ b/torchvision/datasets/flowers102.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Any, Callable, Optional, Tuple, Union -import PIL.Image +from .folder import default_loader from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg from .vision import VisionDataset @@ -24,12 +24,15 @@ class Flowers102(VisionDataset): Args: root (str or ``pathlib.Path``): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. - transform (callable, optional): A function/transform that takes in a PIL image and returns a - transformed version. E.g, ``transforms.RandomCrop``. + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, + and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/" @@ -47,6 +50,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "val", "test")) @@ -73,12 +77,14 @@ def __init__( self._labels.append(image_id_to_label[image_id]) self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg") + self.loader = loader + def __len__(self) -> int: return len(self._image_files) def __getitem__(self, idx: int) -> Tuple[Any, Any]: image_file, label = self._image_files[idx], self._labels[idx] - image = PIL.Image.open(image_file).convert("RGB") + image = self.loader(image_file) if self.transform: image = self.transform(image) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 8f2f65c7b61..7fa4999a775 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -300,7 +300,7 @@ class ImageFolder(DatasetFolder): Args: root (str or ``pathlib.Path``): Root directory path. - transform (callable, optional): A function/transform that takes in a PIL image + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
diff --git a/torchvision/datasets/food101.py b/torchvision/datasets/food101.py index f734787c1bf..d49822b7fde 100644 --- a/torchvision/datasets/food101.py +++ b/torchvision/datasets/food101.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any, Callable, Optional, Tuple, Union -import PIL.Image +from .folder import default_loader from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset @@ -21,12 +21,15 @@ class Food101(VisionDataset): Args: root (str or ``pathlib.Path``): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. - transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed - version. E.g, ``transforms.RandomCrop``. + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader, + and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Default is False. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly.
""" _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz" @@ -39,6 +42,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "test")) @@ -65,13 +69,14 @@ def __init__( self._image_files += [ self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths ] + self.loader = loader def __len__(self) -> int: return len(self._image_files) def __getitem__(self, idx: int) -> Tuple[Any, Any]: image_file, label = self._image_files[idx], self._labels[idx] - image = PIL.Image.open(image_file).convert("RGB") + image = self.loader(image_file) if self.transform: image = self.transform(image) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 2d7e1e2f4d7..89492eec635 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -31,7 +31,7 @@ class ImageNet(ImageFolder): Args: root (str or ``pathlib.Path``): Root directory of the ImageNet Dataset. split (string, optional): The dataset split, supports ``train``, or ``val``. - transform (callable, optional): A function/transform that takes in a PIL image + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
diff --git a/torchvision/datasets/imagenette.py b/torchvision/datasets/imagenette.py index 0b27f3b25e5..7dac9243e08 100644 --- a/torchvision/datasets/imagenette.py +++ b/torchvision/datasets/imagenette.py @@ -1,9 +1,7 @@ from pathlib import Path from typing import Any, Callable, Optional, Tuple, Union -from PIL import Image - -from .folder import find_classes, make_dataset +from .folder import default_loader, find_classes, make_dataset from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset @@ -17,9 +15,12 @@ class Imagenette(VisionDataset): size (string, optional): The image size. Supports ``"full"`` (default), ``"320px"``, and ``"160px"``. download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already downloaded archives are not downloaded again. - transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed - version, e.g. ``transforms.RandomCrop``. + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, + and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. Attributes: classes (list): List of the class name tuples. 
@@ -54,6 +55,7 @@ def __init__( download=False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) @@ -75,6 +77,7 @@ def __init__( class_name: idx for wnid, idx in self.wnid_to_idx.items() for class_name in self._WNID_TO_CLASS[wnid] } self._samples = make_dataset(self._image_root, self.wnid_to_idx, extensions=".jpeg") + self.loader = loader def _check_exists(self) -> bool: return self._size_root.exists() @@ -87,7 +90,7 @@ def _download(self): def __getitem__(self, idx: int) -> Tuple[Any, Any]: path, label = self._samples[idx] - image = Image.open(path).convert("RGB") + image = self.loader(path) if self.transform is not None: image = self.transform(image) diff --git a/torchvision/datasets/lfw.py b/torchvision/datasets/lfw.py index 18374fc3c9b..efc5ee354d1 100644 --- a/torchvision/datasets/lfw.py +++ b/torchvision/datasets/lfw.py @@ -2,8 +2,7 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from PIL import Image - +from .folder import default_loader from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg from .vision import VisionDataset @@ -39,6 +38,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform) @@ -57,11 +57,7 @@ def __init__( raise RuntimeError("Dataset not found or corrupted. 
You can use download=True to download it") self.images_dir = os.path.join(self.root, images_dir) - - def _loader(self, path: str) -> Image.Image: - with open(path, "rb") as f: - img = Image.open(f) - return img.convert("RGB") + self._loader = loader def _check_integrity(self) -> bool: st1 = check_integrity(os.path.join(self.root, self.filename), self.md5) @@ -101,14 +97,16 @@ class LFWPeople(_LFW): ``10fold`` (default). image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or ``deepfunneled``. Defaults to ``funneled``. - transform (callable, optional): A function/transform that takes in a PIL image - and returns a transformed version. E.g, ``transforms.RandomRotation`` + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, + and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. - + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. 
""" def __init__( @@ -119,8 +117,9 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: - super().__init__(root, split, image_set, "people", transform, target_transform, download) + super().__init__(root, split, image_set, "people", transform, target_transform, download, loader=loader) self.class_to_idx = self._get_classes() self.data, self.targets = self._get_people() @@ -190,6 +189,9 @@ class LFWPairs(_LFW): download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ @@ -201,8 +203,9 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: - super().__init__(root, split, image_set, "pairs", transform, target_transform, download) + super().__init__(root, split, image_set, "pairs", transform, target_transform, download, loader=loader) self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir) diff --git a/torchvision/datasets/rendered_sst2.py b/torchvision/datasets/rendered_sst2.py index 48b0ddfc4fb..e8543c1e8a3 100644 --- a/torchvision/datasets/rendered_sst2.py +++ b/torchvision/datasets/rendered_sst2.py @@ -1,9 +1,7 @@ from pathlib import Path from typing import Any, Callable, Optional, Tuple, Union -import PIL.Image - -from .folder import make_dataset +from .folder import default_loader, make_dataset from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset @@ -22,12 +20,15 @@ class RenderedSST2(VisionDataset): 
Args: root (str or ``pathlib.Path``): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default), `"val"` and ``"test"``. - transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed - version. E.g, ``transforms.RandomCrop``. + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, + and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Default is False. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz" @@ -40,6 +41,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "val", "test")) @@ -55,13 +57,14 @@ def __init__( raise RuntimeError("Dataset not found. 
You can use download=True to download it") self._samples = make_dataset(str(self._base_folder / self._split_to_folder[self._split]), extensions=("png",)) + self.loader = loader def __len__(self) -> int: return len(self._samples) def __getitem__(self, idx: int) -> Tuple[Any, Any]: image_file, label = self._samples[idx] - image = PIL.Image.open(image_file).convert("RGB") + image = self.loader(image_file) if self.transform: image = self.transform(image) diff --git a/torchvision/datasets/sbu.py b/torchvision/datasets/sbu.py index b5f46101e07..fb82cccd380 100644 --- a/torchvision/datasets/sbu.py +++ b/torchvision/datasets/sbu.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any, Callable, Optional, Tuple, Union -from PIL import Image +from .folder import default_loader from .utils import check_integrity, download_and_extract_archive, download_url from .vision import VisionDataset @@ -14,13 +14,16 @@ class SBU(VisionDataset): Args: root (str or ``pathlib.Path``): Root directory of dataset where tarball ``SBUCaptionedPhotoDataset.tar.gz`` exists. - transform (callable, optional): A function/transform that takes in a PIL image + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. 
""" url = "https://www.cs.rice.edu/~vo9/sbucaptions/SBUCaptionedPhotoDataset.tar.gz" @@ -33,8 +36,10 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = True, + loader: Callable[[str], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) + self.loader = loader if download: self.download() @@ -67,7 +72,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: tuple: (image, target) where target is a caption for the photo. """ filename = os.path.join(self.root, "dataset", self.photos[index]) - img = Image.open(filename).convert("RGB") + img = self.loader(filename) if self.transform is not None: img = self.transform(img) diff --git a/torchvision/datasets/stanford_cars.py b/torchvision/datasets/stanford_cars.py index 6264de82eb7..844555fee84 100644 --- a/torchvision/datasets/stanford_cars.py +++ b/torchvision/datasets/stanford_cars.py @@ -1,7 +1,7 @@ import pathlib from typing import Any, Callable, Optional, Tuple, Union -from PIL import Image +from .folder import default_loader from .utils import verify_str_arg from .vision import VisionDataset @@ -24,7 +24,7 @@ class StanfordCars(VisionDataset): Args: root (str or ``pathlib.Path``): Root directory of dataset split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. - transform (callable, optional): A function/transform that takes in a PIL image + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. @@ -36,6 +36,9 @@ class StanfordCars(VisionDataset): `dataset on Kaggle `_. In both cases, first download and configure the dataset locally, and use the dataset with ``"download=False"``. 
+ loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ def __init__( @@ -45,6 +48,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: try: @@ -84,6 +88,7 @@ def __init__( self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} + self.loader = loader def __len__(self) -> int: return len(self._samples) @@ -91,13 +96,13 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> Tuple[Any, Any]: """Returns pil_image and class_id for given index""" image_path, target = self._samples[idx] - pil_image = Image.open(image_path).convert("RGB") + image = self.loader(image_path) if self.transform is not None: - pil_image = self.transform(pil_image) + image = self.transform(image) if self.target_transform is not None: target = self.target_transform(target) - return pil_image, target + return image, target def _check_exists(self) -> bool: if not (self._base_folder / "devkit").is_dir(): diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index 4db0a3cf237..cfc9068f1eb 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Any, Callable, Optional, Tuple, Union -import PIL.Image +from .folder import default_loader from .utils import download_and_extract_archive from .vision import VisionDataset @@ -15,12 +15,15 @@ class SUN397(VisionDataset): Args: root (str or ``pathlib.Path``): Root directory of the dataset. - transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed - version. 
E.g, ``transforms.RandomCrop``. + transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader, + and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz" @@ -32,6 +35,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._data_dir = Path(self.root) / "SUN397" @@ -51,13 +55,14 @@ def __init__( self._labels = [ self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files ] + self.loader = loader def __len__(self) -> int: return len(self._image_files) def __getitem__(self, idx: int) -> Tuple[Any, Any]: image_file, label = self._image_files[idx], self._labels[idx] - image = PIL.Image.open(image_file).convert("RGB") + image = self.loader(image_file) if self.transform: image = self.transform(image) From 4196062ba47cb4543595acf8bf03004744d8b5bd Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 27 Feb 2025 20:50:30 +0800 Subject: [PATCH 2/5] add `Path` type in `pil_loader`, `accimage_loader` and `default_loader` acceptable input type. 
--- torchvision/datasets/folder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 7fa4999a775..3f946a3ff55 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -257,7 +257,7 @@ def __len__(self) -> int: IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") -def pil_loader(path: str) -> Image.Image: +def pil_loader(path: str | Path) -> Image.Image: # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) with open(path, "rb") as f: img = Image.open(f) @@ -265,7 +265,7 @@ def pil_loader(path: str) -> Image.Image: # TODO: specify the return type -def accimage_loader(path: str) -> Any: +def accimage_loader(path: str | Path) -> Any: import accimage try: @@ -275,7 +275,7 @@ def accimage_loader(path: str) -> Any: return pil_loader(path) -def default_loader(path: str) -> Any: +def default_loader(path: str | Path) -> Any: from torchvision import get_image_backend if get_image_backend() == "accimage": From c6720c7e07755da58d3bf6de4f0ad773b7f29693 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 27 Feb 2025 21:48:18 +0800 Subject: [PATCH 3/5] fix: fix type annotation. 
--- torchvision/datasets/clevr.py | 2 +- torchvision/datasets/dtd.py | 2 +- torchvision/datasets/flowers102.py | 2 +- torchvision/datasets/food101.py | 2 +- torchvision/datasets/sun397.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/datasets/clevr.py b/torchvision/datasets/clevr.py index 26cee162d32..bf053f6dda8 100644 --- a/torchvision/datasets/clevr.py +++ b/torchvision/datasets/clevr.py @@ -38,7 +38,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[str], Any] = default_loader, + loader: Callable[[str | pathlib.Path], Any] = default_loader, ) -> None: self._split = verify_str_arg(split, "split", ("train", "val", "test")) super().__init__(root, transform=transform, target_transform=target_transform) diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index edf96c1ae57..f61e68bc473 100644 --- a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -43,7 +43,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[str], Any] = default_loader, + loader: Callable[[str | pathlib.Path], Any] = default_loader, ) -> None: self._split = verify_str_arg(split, "split", ("train", "val", "test")) if not isinstance(partition, int) and not (1 <= partition <= 10): diff --git a/torchvision/datasets/flowers102.py b/torchvision/datasets/flowers102.py index c5ca5680037..81717e5aedf 100644 --- a/torchvision/datasets/flowers102.py +++ b/torchvision/datasets/flowers102.py @@ -50,7 +50,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[str], Any] = default_loader, + loader: Callable[[str | Path], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, 
"split", ("train", "val", "test")) diff --git a/torchvision/datasets/food101.py b/torchvision/datasets/food101.py index d49822b7fde..308c8ca0b60 100644 --- a/torchvision/datasets/food101.py +++ b/torchvision/datasets/food101.py @@ -42,7 +42,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[str], Any] = default_loader, + loader: Callable[[str | Path], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "test")) diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index cfc9068f1eb..d613a3a19a0 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -35,7 +35,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[str], Any] = default_loader, + loader: Callable[[str | Path], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._data_dir = Path(self.root) / "SUN397" From 6dabc65bfe23d9974c2281986f4fa6a3597ccbe2 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 27 Feb 2025 22:00:37 +0800 Subject: [PATCH 4/5] fix: use `Union[str, pathlib.Path]` instead of `|` annotation. 
--- torchvision/datasets/clevr.py | 2 +- torchvision/datasets/dtd.py | 2 +- torchvision/datasets/flowers102.py | 2 +- torchvision/datasets/food101.py | 2 +- torchvision/datasets/sun397.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/datasets/clevr.py b/torchvision/datasets/clevr.py index bf053f6dda8..6ce73fcb184 100644 --- a/torchvision/datasets/clevr.py +++ b/torchvision/datasets/clevr.py @@ -38,7 +38,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[str | pathlib.Path], Any] = default_loader, + loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader, ) -> None: self._split = verify_str_arg(split, "split", ("train", "val", "test")) super().__init__(root, transform=transform, target_transform=target_transform) diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index f61e68bc473..17a8bdb694e 100644 --- a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -43,7 +43,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[str | pathlib.Path], Any] = default_loader, + loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader, ) -> None: self._split = verify_str_arg(split, "split", ("train", "val", "test")) if not isinstance(partition, int) and not (1 <= partition <= 10): diff --git a/torchvision/datasets/flowers102.py b/torchvision/datasets/flowers102.py index 81717e5aedf..7b02270d967 100644 --- a/torchvision/datasets/flowers102.py +++ b/torchvision/datasets/flowers102.py @@ -50,7 +50,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[str | Path], Any] = default_loader, + loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, 
target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "val", "test")) diff --git a/torchvision/datasets/food101.py b/torchvision/datasets/food101.py index 308c8ca0b60..a6a654f6ab1 100644 --- a/torchvision/datasets/food101.py +++ b/torchvision/datasets/food101.py @@ -42,7 +42,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[str | Path], Any] = default_loader, + loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "test")) diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index d613a3a19a0..7040416da32 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -35,7 +35,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[str | Path], Any] = default_loader, + loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._data_dir = Path(self.root) / "SUN397" From d1be85bfeea0a436f752399f4070220ea6d2018a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 27 Feb 2025 22:10:40 +0800 Subject: [PATCH 5/5] fix: fix type annotation again. 
--- torchvision/datasets/flowers102.py | 2 +- torchvision/datasets/folder.py | 6 +++--- torchvision/datasets/food101.py | 2 +- torchvision/datasets/sun397.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchvision/datasets/flowers102.py b/torchvision/datasets/flowers102.py index 7b02270d967..b8cf01dd01b 100644 --- a/torchvision/datasets/flowers102.py +++ b/torchvision/datasets/flowers102.py @@ -50,7 +50,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader, + loader: Callable[[Union[str, Path]], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "val", "test")) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 3f946a3ff55..65e168791e4 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -257,7 +257,7 @@ def __len__(self) -> int: IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") -def pil_loader(path: str | Path) -> Image.Image: +def pil_loader(path: Union[str, Path]) -> Image.Image: # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) with open(path, "rb") as f: img = Image.open(f) @@ -265,7 +265,7 @@ def pil_loader(path: str | Path) -> Image.Image: # TODO: specify the return type -def accimage_loader(path: str | Path) -> Any: +def accimage_loader(path: Union[str, Path]) -> Any: import accimage try: @@ -275,7 +275,7 @@ def accimage_loader(path: str | Path) -> Any: return pil_loader(path) -def default_loader(path: str | Path) -> Any: +def default_loader(path: Union[str, Path]) -> Any: from torchvision import get_image_backend if get_image_backend() == "accimage": diff --git a/torchvision/datasets/food101.py b/torchvision/datasets/food101.py 
index a6a654f6ab1..107e60e7d33 100644 --- a/torchvision/datasets/food101.py +++ b/torchvision/datasets/food101.py @@ -42,7 +42,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader, + loader: Callable[[Union[str, Path]], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "test")) diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index 7040416da32..60115d93d4c 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -35,7 +35,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader, + loader: Callable[[Union[str, Path]], Any] = default_loader, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._data_dir = Path(self.root) / "SUN397"