From 336524d30ff92d2acddb4b74a41755b245b512bd Mon Sep 17 00:00:00 2001 From: Ponku Date: Tue, 2 Aug 2022 16:01:40 +0100 Subject: [PATCH 1/3] Added Middlebury2014 dataset --- docs/source/datasets.rst | 1 + test/test_datasets.py | 93 ++++++++ torchvision/datasets/__init__.py | 3 +- torchvision/datasets/_stereo_matching.py | 257 ++++++++++++++++++++++- 4 files changed, 352 insertions(+), 2 deletions(-) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index af7ac072e31..40d5ca576e3 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -111,6 +111,7 @@ Stereo Matching CarlaStereo Kitti2012Stereo Kitti2015Stereo + Middlebury2014Stereo Image pairs ~~~~~~~~~~~ diff --git a/test/test_datasets.py b/test/test_datasets.py index 54696b0d6a8..337493c2208 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2841,5 +2841,98 @@ def test_train_splits(self): datasets_utils.shape_test_for_stereo(left, right, disparity) +class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Middlebury2014Stereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + split=("train", "additional"), + calibration=("perfect", "imperfect", "both"), + use_ambient_views=(True, False), + ) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) + + @staticmethod + def _make_scene_folder(root_dir: str, scene_name: str, split: str) -> None: + calibrations = [None] if split == "test" else ["-perfect", "-imperfect"] + root_dir = pathlib.Path(root_dir) + + for c in calibrations: + scene_dir = root_dir / f"{scene_name}{c}" + os.makedirs(scene_dir, exist_ok=True) + # make normal images first + datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(3, 100, 100)) + datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(3, 100, 100)) + datasets_utils.create_image_file(root=scene_dir, name="im1E.png", size=(3, 100, 100)) + datasets_utils.create_image_file(root=scene_dir, name="im1L.png", size=(3, 100, 100)) + # these are going to end up being gray scale images + datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=os.path.join(scene_dir, "disp0.pfm")) + datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=os.path.join(scene_dir, "disp1.pfm")) + + def inject_fake_data(self, tmpdir, config): + split_scene_map = { + "train": ["Adirondack", "Jadeplant", "Motorcycle", "Piano"], + "additional": ["Backpack", "Bicycle1", "Cable", "Classroom1"], + "test": ["Plants", "Classroom2E", "Classroom2", "Australia"], + } + + middlebury_dir = pathlib.Path(tmpdir, "Middlebury2014") + os.makedirs(middlebury_dir, exist_ok=True) + + split_dir = middlebury_dir / config["split"] + os.makedirs(split_dir, exist_ok=True) + + num_examples = {"train": 2, "additional": 3, "test": 4}.get(config["split"], 0) + for idx in range(num_examples): + scene_name = split_scene_map[config["split"]][idx] + self._make_scene_folder(root_dir=split_dir, scene_name=scene_name, split=config["split"]) + + if config["calibration"] == "both": + num_examples *= 2 + return num_examples + + def test_train_splits(self): + for split, calibration in itertools.product(["train", "additional"], ["perfect", "imperfect", "both"]): + with self.create_dataset(split=split, calibration=calibration) as (dataset, _): + for left, right, disparity in dataset: + datasets_utils.shape_test_for_stereo(left, right, disparity) + + def test_test_split(self): + for split in ["test"]: + with self.create_dataset(split=split, calibration=None) as (dataset, _): + for left, right, disparity in dataset: + datasets_utils.shape_test_for_stereo(left, right, disparity) + + def test_augmented_view_usage(self): + with self.create_dataset(split="train", use_ambient_views=True) as (dataset, _): + for left, right, disparity in dataset: + datasets_utils.shape_test_for_stereo(left, right, disparity) + + def test_value_err_train(self): + # train set invalid + split = "train" + calibration = None + with pytest.raises( + ValueError, + match=f"Split '{split}' has calibration settings, however None was provided as an argument." + f"\nSetting calibration to 'perfect' for split '{split}'. Available calibration settings are: 'perfect', 'imperfect', 'both'.", + ): + with self.create_dataset(split=split, calibration=calibration): + pass + + def test_value_err_test(self): + # test set invalid + split = "test" + calibration = "perfect" + with pytest.raises( + ValueError, match="Split 'test' has only no calibration settings, please set `calibration=None`." + ): + with self.create_dataset(split=split, calibration=calibration): + pass + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index d8b6293fb42..15b4d06cea4 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,5 +1,5 @@ from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel -from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo +from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo, Middlebury2014Stereo from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -109,4 +109,5 @@ "Kitti2012Stereo", "Kitti2015Stereo", "CarlaStereo", + "Middlebury2014Stereo", ) diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index de213fc0368..41c19c436f2 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -1,4 +1,7 @@ import functools +import os +import random +import shutil from abc import ABC, abstractmethod from glob import glob from pathlib import Path @@ -7,7 +10,7 @@ import numpy as np from PIL import Image -from .utils import _read_pfm, verify_str_arg +from .utils import _read_pfm, download_and_extract_archive, verify_str_arg from .vision import VisionDataset __all__ = () @@ -359,3 +362,255 @@ def __getitem__(self, index: int) -> Tuple: Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. """ return super().__getitem__(index) + + +class Middlebury2014Stereo(StereoMatchingDataset): + """Publicly available scenes from the Middlebury dataset `2014 version `. + + The dataset mostly follows the original format, without containing the ambient subdirectories. : :: + + root + Middlebury2014 + train + scene1-{perfect,imperfect} + calib.txt + im{0,1}.png + im1E.png + im1L.png + disp{0,1}.pfm + disp{0,1}-n.png + disp{0,1}-sd.pfm + disp{0,1}y.pfm + scene2-{perfect,imperfect} + calib.txt + im{0,1}.png + im1E.png + im1L.png + disp{0,1}.pfm + disp{0,1}-n.png + disp{0,1}-sd.pfm + disp{0,1}y.pfm + ... + additional + scene1-{perfect,imperfect} + calib.txt + im{0,1}.png + im1E.png + im1L.png + disp{0,1}.pfm + disp{0,1}-n.png + disp{0,1}-sd.pfm + disp{0,1}y.pfm + ... + test + scene1 + calib.txt + im{0,1}.png + scene2 + calib.txt + im{0,1}.png + ... + + Args: + root (string): Root directory of the Middleburry 2014 Dataset. + split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional" + use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible. + The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``. + calibration (string, optional): Wether or not to use the calibrated (default) or uncalibrated scenes. + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + download (boolean, optional): Wether or not to download the dataset in the ``root`` directory. + """ + + splits = { + "train": [ + "Adirondack", + "Jadeplant", + "Motorcycle", + "Piano", + "Pipes", + "Playroom", + "Playtable", + "Recycle", + "Shelves", + "Vintage", + ], + "additional": [ + "Backpack", + "Bicycle1", + "Cable", + "Classroom1", + "Couch", + "Flowers", + "Mask", + "Shopvac", + "Sticks", + "Storage", + "Sword1", + "Sword2", + "Umbrella", + ], + "test": [ + "Plants", + "Classroom2E", + "Classroom2", + "Australia", + "DjembeL", + "CrusadeP", + "Crusade", + "Hoops", + "Bicycle2", + "Staircase", + "Newkuba", + "AustraliaP", + "Djembe", + "Livingroom", + "Computer", + ], + } + + def __init__( + self, + root: str, + split: str = "train", + calibration: Optional[str] = "perfect", + use_ambient_views: bool = False, + transforms: Optional[Callable] = None, + download: bool = False, + ): + super().__init__(root, transforms) + + verify_str_arg(split, "split", valid_values=("train", "test", "additional")) + self.split = split + + if calibration: + verify_str_arg(calibration, "calibration", valid_values=("perfect", "imperfect", "both", None)) # type: ignore + if split == "test": + raise ValueError("Split 'test' has only no calibration settings, please set `calibration=None`.") + else: + if split != "test": + raise ValueError( + f"Split '{split}' has calibration settings, however None was provided as an argument." + f"\nSetting calibration to 'perfect' for split '{split}'. Available calibration settings are: 'perfect', 'imperfect', 'both'.", + ) + + if download: + self._download_dataset(root) + + root = Path(root) / "Middlebury2014" + + if not os.path.exists(root / split): + raise FileNotFoundError(f"The {split} directory was not found in the provided root directory") + + split_scenes = self.splits[split] + # check that the provided root folder contains the scene splits + if not any( + # using startswith to account for perfect / imperfect calibrartion + scene.startswith(s) + for scene in os.listdir(root / split) + for s in split_scenes + ): + raise FileNotFoundError(f"Provided root folder does not contain any scenes from the {split} split.") + + calibrartion_suffixes = { + None: [""], + "perfect": ["-perfect"], + "imperfect": ["-imperfect"], + "both": ["-perfect", "-imperfect"], + }[calibration] + + for calibration_suffix in calibrartion_suffixes: + scene_pattern = "*" + calibration_suffix + left_img_pattern = str(root / split / scene_pattern / "im0.png") + right_img_pattern = str(root / split / scene_pattern / "im1.png") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "test": + self._disparities = list((None, None) for _ in self._images) + else: + left_dispartity_pattern = str(root / split / scene_pattern / "disp0.pfm") + right_dispartity_pattern = str(root / split / scene_pattern / "disp1.pfm") + self._disparities += self._scan_pairs(left_dispartity_pattern, right_dispartity_pattern) + + self.use_ambient_views = use_ambient_views + + def _read_img(self, file_path: str) -> Image.Image: + """ + Function that reads either the original right image or an augmented view when ``use_ambient_views`` is True. + When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]`` + as the right image. + """ + ambient_file_paths: List[Union[str, Path]] # make mypy happy + + if os.path.basename(file_path) == "im1.png" and self.use_ambient_views: + base_path = Path(file_path).parent + # initialize sampleable container + ambient_file_paths = list(base_path / view_name for view_name in ["im1E.png", "im1L.png"]) + # double check that we're not going to try to read from an invalid file path + ambient_file_paths = list(filter(lambda p: os.path.exists(p), ambient_file_paths)) + # keep the original image as an option as well for uniform sampling between base views + ambient_file_paths.append(file_path) + file_path = random.choice(ambient_file_paths) # type: ignore + return super()._read_img(file_path) + + def _read_disparity(self, file_path: str) -> Tuple: + # test split has not disparity maps + if file_path is None: + return None, None + + disparity_map = _read_pfm_file(file_path) + disparity_map = np.abs(disparity_map) # ensure that the disparity is positive + valid_mask = None + return disparity_map, valid_mask + + def _download_dataset(self, root: str): + base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip" + # train and additional splits have 2 different calibration settings + root = Path(root) / "Middlebury2014" + split_name = self.split + + if split_name != "test": + for split_scene in self.splits[split_name]: + split_root = root / split_name + for calibration in ["perfect", "imperfect"]: + scene_name = f"{split_scene}-{calibration}" + scene_url = f"{base_url}/{scene_name}.zip" + print(f"Downloading {scene_url}") + # download the scene only if it doesn't exist + if not os.path.exists(split_root / scene_name): + download_and_extract_archive( + url=scene_url, + filename=f"{scene_name}.zip", + download_root=str(split_root), + remove_finished=True, + ) + else: + os.makedirs(root / "test") + if any(s not in os.listdir(root / "test") for s in self.splits["test"]): + # test split is downloaded from a different location + test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip" + # the unzip is going to produce a directory MiddEval3 with two subdirectories trainingF and testF + # we want to move the contents from testF into the directory + download_and_extract_archive(url=test_set_url, download_root=str(root), remove_finished=True) + for scene_dir, scene_names, _ in os.walk(str(root / "MiddEval3/testF")): + for scene in scene_names: + scene_dst_dir = root / "test" + scene_src_dir = Path(scene_dir) / scene + os.makedirs(scene_dst_dir, exist_ok=True) + shutil.move(str(scene_src_dir), str(scene_dst_dir)) + + # cleanup MiddEval3 directory + shutil.rmtree(str(root / "MiddEval3")) + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return super().__getitem__(index) From c2f2214e3166f014f77c53d5bd7d4a712f0b60d8 Mon Sep 17 00:00:00 2001 From: Ponku Date: Fri, 9 Sep 2022 00:33:46 +0100 Subject: [PATCH 2/3] adressed nits --- test/test_datasets.py | 4 ++-- torchvision/datasets/_stereo_matching.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 6453e619c9a..e16f2a1609a 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -3241,8 +3241,8 @@ def _make_scene_folder(root_dir: str, scene_name: str, split: str) -> None: datasets_utils.create_image_file(root=scene_dir, name="im1E.png", size=(3, 100, 100)) datasets_utils.create_image_file(root=scene_dir, name="im1L.png", size=(3, 100, 100)) # these are going to end up being gray scale images - datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=os.path.join(scene_dir, "disp0.pfm")) - datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=os.path.join(scene_dir, "disp1.pfm")) + datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=scene_dir / "disp0.pfm") + datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=scene_dir / "disp1.pfm") def inject_fake_data(self, tmpdir, config): split_scene_map = { diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index 34f086426e8..b5fc59dcfd3 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -544,8 +544,11 @@ def _read_img(self, file_path: str) -> Image.Image: """ ambient_file_paths: List[Union[str, Path]] # make mypy happy - if os.path.basename(file_path) == "im1.png" and self.use_ambient_views: - base_path = Path(file_path).parent + if not isinstance(file_path, Path): + file_path = Path(file_path) + + if file_path.name == "im1.png" and self.use_ambient_views: + base_path = file_path.parent # initialize sampleable container ambient_file_paths = list(base_path / view_name for view_name in ["im1E.png", "im1L.png"]) # double check that we're not going to try to read from an invalid file path @@ -580,7 +583,7 @@ def _download_dataset(self, root: str): scene_url = f"{base_url}/{scene_name}.zip" print(f"Downloading {scene_url}") # download the scene only if it doesn't exist - if not os.path.exists(split_root / scene_name): + if not (split_root / scene_name).exists(): download_and_extract_archive( url=scene_url, filename=f"{scene_name}.zip", From fa7dcc7cfed7d04436bda6a66ed9689b0f720ce4 Mon Sep 17 00:00:00 2001 From: Ponku Date: Fri, 9 Sep 2022 01:05:19 +0100 Subject: [PATCH 3/3] mypy fix --- torchvision/datasets/_stereo_matching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index b5fc59dcfd3..cd06cfe1cab 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -52,7 +52,7 @@ def __init__(self, root: str, transforms: Optional[Callable] = None): self._images = [] # type: ignore self._disparities = [] # type: ignore - def _read_img(self, file_path: str) -> Image.Image: + def _read_img(self, file_path: Union[str, Path]) -> Image.Image: img = Image.open(file_path) if img.mode != "RGB": img = img.convert("RGB") @@ -536,7 +536,7 @@ def __init__( self.use_ambient_views = use_ambient_views - def _read_img(self, file_path: str) -> Image.Image: + def _read_img(self, file_path: Union[str, Path]) -> Image.Image: """ Function that reads either the original right image or an augmented view when ``use_ambient_views`` is True. When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]``