diff --git a/test/test_datasets.py b/test/test_datasets.py index 1413d2c312d..8d4eca688a2 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2038,6 +2038,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase): FLOW_H, FLOW_W = 3, 4 + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir, config): root = pathlib.Path(tmpdir) / "Sintel" @@ -2104,6 +2106,8 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase): ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test")) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir, config): root = pathlib.Path(tmpdir) / "KittiFlow" @@ -2223,6 +2227,8 @@ class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase): FLOW_H, FLOW_W = 3, 4 + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir, config): root = pathlib.Path(tmpdir) / "FlyingThings3D" @@ -2289,6 +2295,8 @@ def test_bad_input(self): class HD1KTestCase(KittiFlowTestCase): DATASET_CLASS = datasets.HD1K + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir, config): root = pathlib.Path(tmpdir) / "hd1k" diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index e8d6247f03f..9ee4c4df52f 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -3,13 +3,14 @@ from abc import ABC, abstractmethod from glob import glob from pathlib import Path -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import torch from PIL import Image from ..io.image import decode_png, read_file +from .folder import default_loader from .utils import _read_pfm, verify_str_arg from .vision import VisionDataset @@ -32,19 +33,22 @@ class FlowDataset(ABC, VisionDataset): # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be. _has_builtin_flow_mask = False - def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None: + def __init__( + self, + root: Union[str, Path], + transforms: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, + ) -> None: super().__init__(root=root) self.transforms = transforms self._flow_list: List[str] = [] self._image_list: List[List[str]] = [] + self._loader = loader - def _read_img(self, file_name: str) -> Image.Image: - img = Image.open(file_name) - if img.mode != "RGB": - img = img.convert("RGB") # type: ignore[assignment] - return img + def _read_img(self, file_name: str) -> Union[Image.Image, torch.Tensor]: + return self._loader(file_name) @abstractmethod def _read_flow(self, file_name: str): @@ -70,9 +74,9 @@ def __getitem__(self, index: int) -> Union[T1, T2]: if self._has_builtin_flow_mask or valid_flow_mask is not None: # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform - return img1, img2, flow, valid_flow_mask + return img1, img2, flow, valid_flow_mask # type: ignore[return-value] else: - return img1, img2, flow + return img1, img2, flow # type: ignore[return-value] def __len__(self) -> int: return len(self._image_list) @@ -120,6 +124,9 @@ class Sintel(FlowDataset): ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. ``valid_flow_mask`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ def __init__( @@ -128,8 +135,9 @@ def __init__( split: str = "train", pass_name: str = "clean", transforms: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, ) -> None: - super().__init__(root=root, transforms=transforms) + super().__init__(root=root, transforms=transforms, loader=loader) verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both")) @@ -186,12 +194,21 @@ class KittiFlow(FlowDataset): split (string, optional): The dataset split, either "train" (default) or "test" transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ _has_builtin_flow_mask = True - def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None: - super().__init__(root=root, transforms=transforms) + def __init__( + self, + root: Union[str, Path], + split: str = "train", + transforms: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, + ) -> None: + super().__init__(root=root, transforms=transforms, loader=loader) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -324,6 +341,9 @@ class FlyingThings3D(FlowDataset): ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. ``valid_flow_mask`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ def __init__( @@ -333,8 +353,9 @@ def __init__( pass_name: str = "clean", camera: str = "left", transforms: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, ) -> None: - super().__init__(root=root, transforms=transforms) + super().__init__(root=root, transforms=transforms, loader=loader) verify_str_arg(split, "split", valid_values=("train", "test")) split = split.upper() @@ -414,12 +435,21 @@ class HD1K(FlowDataset): split (string, optional): The dataset split, either "train" (default) or "test" transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ _has_builtin_flow_mask = True - def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None: - super().__init__(root=root, transforms=transforms) + def __init__( + self, + root: Union[str, Path], + split: str = "train", + transforms: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, + ) -> None: + super().__init__(root=root, transforms=transforms, loader=loader) verify_str_arg(split, "split", valid_values=("train", "test"))