diff --git a/test/datasets_utils.py b/test/datasets_utils.py
index 43b4103646a..6a552a96923 100644
--- a/test/datasets_utils.py
+++ b/test/datasets_utils.py
@@ -611,6 +611,8 @@ class ImageDatasetTestCase(DatasetTestCase):
     """
 
     FEATURE_TYPES = (PIL.Image.Image, int)
+    # Opt-in flag: subclasses set this to True to enable test_tv_decode_image_support.
+    SUPPORT_TV_IMAGE_DECODE: bool = False
 
     @contextlib.contextmanager
     def create_dataset(
@@ -632,22 +634,51 @@ def create_dataset(
             # This problem only occurs during testing since some tests, e.g. DatasetTestCase.test_feature_types open an
             # image, but never use the underlying data. During normal operation it is reasonable to assume that the
             # user wants to work with the image he just opened rather than deleting the underlying file.
-            with self._force_load_images():
+            with self._force_load_images(loader=(config or {}).get("loader", None)):
                 yield dataset, info
 
     @contextlib.contextmanager
-    def _force_load_images(self):
-        open = PIL.Image.open
+    def _force_load_images(self, loader: Optional[Callable[[str], Any]] = None):
+        """Wrap the image opener so that PIL images are fully read while their file still exists.
+
+        Args:
+            loader: Callable used to open images. Falls back to ``PIL.Image.open`` when ``None``.
+        """
+        opener = loader or PIL.Image.open
 
         def new(fp, *args, **kwargs):
-            image = open(fp, *args, **kwargs)
-            if isinstance(fp, (str, pathlib.Path)):
+            image = opener(fp, *args, **kwargs)
+            # Custom loaders may return non-PIL objects (e.g. tensors), which have nothing to force-load.
+            if isinstance(fp, (str, pathlib.Path)) and isinstance(image, PIL.Image.Image):
                 image.load()
             return image
 
-        with unittest.mock.patch("PIL.Image.open", new=new):
+        # ``unittest.mock.patch`` needs an importable dotted path. Lambdas, nested functions ("<lambda>",
+        # "<locals>" in the qualname) and objects without ``__qualname__`` such as ``functools.partial`` cannot
+        # be addressed that way, so they are left unpatched instead of failing the test with an AttributeError.
+        module_name = getattr(opener, "__module__", None)
+        qualname = getattr(opener, "__qualname__", None)
+        if module_name and qualname and "<" not in qualname:
+            patcher = unittest.mock.patch(f"{module_name}.{qualname}", new=new)
+        else:
+            patcher = contextlib.nullcontext()
+
+        with patcher:
             yield
 
+    def test_tv_decode_image_support(self):
+        """Datasets that opt in must return tensors when decoding with ``torchvision.io.decode_image``."""
+        if not self.SUPPORT_TV_IMAGE_DECODE:
+            pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.")
+
+        with self.create_dataset(
+            config=dict(
+                loader=torchvision.io.decode_image,
+            )
+        ) as (dataset, _):
+            image = dataset[0][0]
+            assert isinstance(image, torch.Tensor)
+
 
 class VideoDatasetTestCase(DatasetTestCase):
     """Abstract base class for video dataset testcases.
diff --git a/test/test_datasets.py b/test/test_datasets.py
index 1c1d05ac42a..f98a18372a5 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -405,6 +405,8 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
     REQUIRED_PACKAGES = ("scipy",)
     ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"))
 
+    SUPPORT_TV_IMAGE_DECODE = True
+
     def inject_fake_data(self, tmpdir, config):
         tmpdir = pathlib.Path(tmpdir)
 
@@ -2308,6 +2310,7 @@ def inject_fake_data(self, tmpdir, config):
 class EuroSATTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.EuroSAT
     FEATURE_TYPES = (PIL.Image.Image, int)
+    SUPPORT_TV_IMAGE_DECODE = True
 
     def inject_fake_data(self, tmpdir, config):
         data_folder = os.path.join(tmpdir, "eurosat", "2750")
@@ -2749,6 +2752,8 @@ class Country211TestCase(datasets_utils.ImageDatasetTestCase):
 
     ADDITIONAL_CONFIGS = combinations_grid(split=("train", "valid", "test"))
 
+    SUPPORT_TV_IMAGE_DECODE = True
+
     def inject_fake_data(self, tmpdir: str, config):
         split_folder = pathlib.Path(tmpdir) / "country211" / config["split"]
         split_folder.mkdir(parents=True, exist_ok=True)
diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py
index a0f82ee1226..26b49552771 100644
--- a/torchvision/datasets/country211.py
+++ b/torchvision/datasets/country211.py
@@ -1,7 +1,7 @@
 from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, Union
 
-from .folder import ImageFolder
+from .folder import default_loader, ImageFolder
 from .utils import download_and_extract_archive, verify_str_arg
 
 
@@ -21,6 +21,9 @@ class Country211(ImageFolder):
         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
         download (bool, optional): If True, downloads the dataset from the internet and puts it into
             ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
     """
 
     _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
@@ -33,6 +36,7 @@ def __init__(
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
     ) -> None:
         self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
 
@@ -46,7 +50,12 @@ def __init__(
         if not self._check_exists():
             raise RuntimeError("Dataset not found. You can use download=True to download it")
 
-        super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
+        super().__init__(
+            str(self._base_folder / self._split),
+            transform=transform,
+            target_transform=target_transform,
+            loader=loader,
+        )
         self.root = str(root)
 
     def _check_exists(self) -> bool:
diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py
index c6571d2abab..5b96b067fba 100644
--- a/torchvision/datasets/eurosat.py
+++ b/torchvision/datasets/eurosat.py
@@ -1,8 +1,8 @@
 import os
 from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, Union
 
-from .folder import ImageFolder
+from .folder import default_loader, ImageFolder
 from .utils import download_and_extract_archive
 
 
@@ -21,6 +21,9 @@ class EuroSAT(ImageFolder):
         download (bool, optional): If True, downloads the dataset from the internet and
             puts it in root directory. If dataset is already downloaded, it is not
             downloaded again. Default is False.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
     """
 
     def __init__(
@@ -29,6 +32,7 @@ def __init__(
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
     ) -> None:
         self.root = os.path.expanduser(root)
         self._base_folder = os.path.join(self.root, "eurosat")
@@ -40,7 +44,12 @@ def __init__(
         if not self._check_exists():
             raise RuntimeError("Dataset not found. You can use download=True to download it")
 
-        super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
+        super().__init__(
+            self._data_folder,
+            transform=transform,
+            target_transform=target_transform,
+            loader=loader,
+        )
         self.root = os.path.expanduser(root)
 
     def __len__(self) -> int:
diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py
index d7caf328d2b..2d7e1e2f4d7 100644
--- a/torchvision/datasets/imagenet.py
+++ b/torchvision/datasets/imagenet.py
@@ -36,6 +36,8 @@ class ImageNet(ImageFolder):
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
         loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
 
      Attributes:
         classes (list): List of the class name tuples.