diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index af7ac072e31..6b22bc9c7ce 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -111,6 +111,7 @@ Stereo Matching CarlaStereo Kitti2012Stereo Kitti2015Stereo + SceneFlowStereo Image pairs ~~~~~~~~~~~ diff --git a/test/test_datasets.py b/test/test_datasets.py index 54696b0d6a8..dcf698387bc 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -13,7 +13,7 @@ import unittest import xml.etree.ElementTree as ET import zipfile -from typing import Union +from typing import Callable, Tuple, Union import datasets_utils import numpy as np @@ -2841,5 +2841,80 @@ def test_train_splits(self): datasets_utils.shape_test_for_stereo(left, right, disparity) +class SceneFlowStereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.SceneFlowStereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + variant=("FlyingThings3D", "Driving", "Monkaa"), pass_name=("clean", "final", "both") + ) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) + + @staticmethod + def _create_pfm_folder( + root: str, name: str, file_name_fn: Callable[..., str], num_examples: int, size: Tuple[int, int] + ) -> None: + root = pathlib.Path(root) / name + os.makedirs(root, exist_ok=True) + + for i in range(num_examples): + datasets_utils.make_fake_pfm_file(size[0], size[1], root / file_name_fn(i)) + + def inject_fake_data(self, tmpdir, config): + scene_flow_dir = pathlib.Path(tmpdir) / "SceneFlow" + os.makedirs(scene_flow_dir, exist_ok=True) + + variant_dir = scene_flow_dir / config["variant"] + os.makedirs(variant_dir, exist_ok=True) + + num_examples = {"FlyingThings3D": 4, "Driving": 6, "Monkaa": 5}.get(config["variant"], 0) + + passes = { + "clean": ["frames_cleanpass"], + "final": ["frames_finalpass"], + "both": ["frames_cleanpass", "frames_finalpass"], + }.get(config["pass_name"], []) + + for pass_dir_name in passes: + # create pass directories + pass_dir = variant_dir / pass_dir_name + disp_dir = variant_dir / "disparity" + os.makedirs(pass_dir, exist_ok=True) + os.makedirs(disp_dir, exist_ok=True) + + for direction in ["left", "right"]: + for scene_idx in range(num_examples): + os.makedirs(pass_dir / f"scene_{scene_idx:06d}", exist_ok=True) + datasets_utils.create_image_folder( + root=pass_dir / f"scene_{scene_idx:06d}", + name=direction, + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=1, + size=(3, 200, 100), + ) + + os.makedirs(disp_dir / f"scene_{scene_idx:06d}", exist_ok=True) + self._create_pfm_folder( + root=disp_dir / f"scene_{scene_idx:06d}", + name=direction, + file_name_fn=lambda i: f"{i:06d}.pfm", + num_examples=1, + size=(100, 200), + ) + + if config["pass_name"] == "both": + num_examples *= 2 + return num_examples + + def test_splits(self): + for variant_name, pass_name in itertools.product(["FlyingThings3D", "Driving", "Monkaa"], ["clean", "final"]): + with self.create_dataset(variant=variant_name, pass_name=pass_name) as (dataset, _): + for left, right, disparity in dataset: + datasets_utils.shape_test_for_stereo(left, right, disparity) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument variant"): + with self.create_dataset(variant="bad"): + pass + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index d8b6293fb42..a2a29f9cb35 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,5 +1,5 @@ from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel -from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo +from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo, SceneFlowStereo from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -109,4 +109,5 @@ "Kitti2012Stereo", "Kitti2015Stereo", "CarlaStereo", + "SceneFlowStereo", ) diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index de213fc0368..53ef4913ec5 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -359,3 +359,109 @@ def __getitem__(self, index: int) -> Tuple: Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. """ return super().__getitem__(index) + + +class SceneFlowStereo(StereoMatchingDataset): + """Dataset interface for `Scene Flow `_ datasets. + This interface provides access to the `FlyingThings3D, `Monkaa` and `Driving` datasets. + + The dataset is expected to have the following structre: :: + + root + SceneFlow + Monkaa + frames_cleanpass + scene1 + left + img1.png + img2.png + right + img1.png + img2.png + scene2 + left + img1.png + img2.png + right + img1.png + img2.png + frames_finalpass + scene1 + left + img1.png + img2.png + right + img1.png + img2.png + ... + ... + disparity + scene1 + left + img1.pfm + img2.pfm + right + img1.pfm + img2.pfm + FlyingThings3D + ... + ... + + Args: + root (string): Root directory where SceneFlow is located. + variant (string): Which dataset variant to user, "FlyingThings3D" (default), "Monkaa" or "Driving". + pass_name (string): Which pass to use, "clean" (default), "final" or "both". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + + """ + + def __init__( + self, + root: str, + variant: str = "FlyingThings3D", + pass_name: str = "clean", + transforms: Optional[Callable] = None, + ): + super().__init__(root, transforms) + + root = Path(root) / "SceneFlow" + + verify_str_arg(variant, "variant", valid_values=("FlyingThings3D", "Driving", "Monkaa")) + verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both")) + + passes = { + "clean": ["frames_cleanpass"], + "final": ["frames_finalpass"], + "both": ["frames_cleanpass", "frames_finalpass"], + }[pass_name] + + root = root / variant + + for p in passes: + left_image_pattern = str(root / p / "*" / "left" / "*.png") + right_image_pattern = str(root / p / "*" / "right" / "*.png") + self._images += self._scan_pairs(left_image_pattern, right_image_pattern) + + left_disparity_pattern = str(root / "disparity" / "*" / "left" / "*.pfm") + right_disparity_pattern = str(root / "disparity" / "*" / "right" / "*.pfm") + self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + + def _read_disparity(self, file_path: str) -> Tuple: + disparity_map = _read_pfm_file(file_path) + disparity_map = np.abs(disparity_map) # ensure that the disparity is positive + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return super().__getitem__(index)