diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index d1346bb4bdd..794c4f3fa75 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -117,6 +117,7 @@ Stereo Matching
SintelStereo
InStereo2k
ETH3DStereo
+ Middlebury2014Stereo
Image pairs
~~~~~~~~~~~
diff --git a/test/test_datasets.py b/test/test_datasets.py
index b5ca24ab9c9..e16f2a1609a 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -3218,5 +3218,98 @@ def test_bad_input(self):
pass
+class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase):
+ DATASET_CLASS = datasets.Middlebury2014Stereo
+ ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+ split=("train", "additional"),
+ calibration=("perfect", "imperfect", "both"),
+ use_ambient_views=(True, False),
+ )
+ FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
+
+ @staticmethod
+ def _make_scene_folder(root_dir: str, scene_name: str, split: str) -> None:
+ calibrations = [None] if split == "test" else ["-perfect", "-imperfect"]
+ root_dir = pathlib.Path(root_dir)
+
+ for c in calibrations:
+ scene_dir = root_dir / f"{scene_name}{c}"
+ os.makedirs(scene_dir, exist_ok=True)
+ # make normal images first
+ datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(3, 100, 100))
+ datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(3, 100, 100))
+ datasets_utils.create_image_file(root=scene_dir, name="im1E.png", size=(3, 100, 100))
+ datasets_utils.create_image_file(root=scene_dir, name="im1L.png", size=(3, 100, 100))
+ # these are going to end up being gray scale images
+ datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=scene_dir / "disp0.pfm")
+ datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=scene_dir / "disp1.pfm")
+
+ def inject_fake_data(self, tmpdir, config):
+ split_scene_map = {
+ "train": ["Adirondack", "Jadeplant", "Motorcycle", "Piano"],
+ "additional": ["Backpack", "Bicycle1", "Cable", "Classroom1"],
+ "test": ["Plants", "Classroom2E", "Classroom2", "Australia"],
+ }
+
+ middlebury_dir = pathlib.Path(tmpdir, "Middlebury2014")
+ os.makedirs(middlebury_dir, exist_ok=True)
+
+ split_dir = middlebury_dir / config["split"]
+ os.makedirs(split_dir, exist_ok=True)
+
+ num_examples = {"train": 2, "additional": 3, "test": 4}.get(config["split"], 0)
+ for idx in range(num_examples):
+ scene_name = split_scene_map[config["split"]][idx]
+ self._make_scene_folder(root_dir=split_dir, scene_name=scene_name, split=config["split"])
+
+ if config["calibration"] == "both":
+ num_examples *= 2
+ return num_examples
+
+ def test_train_splits(self):
+ for split, calibration in itertools.product(["train", "additional"], ["perfect", "imperfect", "both"]):
+ with self.create_dataset(split=split, calibration=calibration) as (dataset, _):
+ for left, right, disparity, mask in dataset:
+ datasets_utils.shape_test_for_stereo(left, right, disparity, mask)
+
+ def test_test_split(self):
+ for split in ["test"]:
+ with self.create_dataset(split=split, calibration=None) as (dataset, _):
+ for left, right, disparity, mask in dataset:
+ datasets_utils.shape_test_for_stereo(left, right)
+
+ def test_augmented_view_usage(self):
+ with self.create_dataset(split="train", use_ambient_views=True) as (dataset, _):
+ for left, right, disparity, mask in dataset:
+ datasets_utils.shape_test_for_stereo(left, right, disparity, mask)
+
+ def test_value_err_train(self):
+ # train set invalid
+ split = "train"
+ calibration = None
+ with pytest.raises(
+ ValueError,
+ match=f"Split '{split}' has calibration settings, however None was provided as an argument."
+ f"\nSetting calibration to 'perfect' for split '{split}'. Available calibration settings are: 'perfect', 'imperfect', 'both'.",
+ ):
+ with self.create_dataset(split=split, calibration=calibration):
+ pass
+
+ def test_value_err_test(self):
+ # test set invalid
+ split = "test"
+ calibration = "perfect"
+ with pytest.raises(
+ ValueError, match="Split 'test' has only no calibration settings, please set `calibration=None`."
+ ):
+ with self.create_dataset(split=split, calibration=calibration):
+ pass
+
+ def test_bad_input(self):
+ with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
+ with self.create_dataset(split="bad"):
+ pass
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index e809bbf1695..e20ff1d85d2 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -7,6 +7,7 @@
InStereo2k,
Kitti2012Stereo,
Kitti2015Stereo,
+ Middlebury2014Stereo,
SceneFlowStereo,
SintelStereo,
)
@@ -119,6 +120,7 @@
"Kitti2012Stereo",
"Kitti2015Stereo",
"CarlaStereo",
+ "Middlebury2014Stereo",
"CREStereo",
"FallingThingsStereo",
"SceneFlowStereo",
diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py
index 14fe1b60f44..cd06cfe1cab 100644
--- a/torchvision/datasets/_stereo_matching.py
+++ b/torchvision/datasets/_stereo_matching.py
@@ -1,6 +1,8 @@
import functools
import json
import os
+import random
+import shutil
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
@@ -9,7 +11,7 @@
import numpy as np
from PIL import Image
-from .utils import _read_pfm, verify_str_arg
+from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
__all__ = ()
@@ -50,7 +52,7 @@ def __init__(self, root: str, transforms: Optional[Callable] = None):
self._images = [] # type: ignore
self._disparities = [] # type: ignore
- def _read_img(self, file_path: str) -> Image.Image:
+ def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
img = Image.open(file_path)
if img.mode != "RGB":
img = img.convert("RGB")
@@ -363,6 +365,263 @@ def __getitem__(self, index: int) -> Tuple:
return super().__getitem__(index)
+class Middlebury2014Stereo(StereoMatchingDataset):
+ """Publicly available scenes from the Middlebury dataset `2014 version `.
+
+ The dataset mostly follows the original format, without containing the ambient subdirectories. : ::
+
+ root
+ Middlebury2014
+ train
+ scene1-{perfect,imperfect}
+ calib.txt
+ im{0,1}.png
+ im1E.png
+ im1L.png
+ disp{0,1}.pfm
+ disp{0,1}-n.png
+ disp{0,1}-sd.pfm
+ disp{0,1}y.pfm
+ scene2-{perfect,imperfect}
+ calib.txt
+ im{0,1}.png
+ im1E.png
+ im1L.png
+ disp{0,1}.pfm
+ disp{0,1}-n.png
+ disp{0,1}-sd.pfm
+ disp{0,1}y.pfm
+ ...
+ additional
+ scene1-{perfect,imperfect}
+ calib.txt
+ im{0,1}.png
+ im1E.png
+ im1L.png
+ disp{0,1}.pfm
+ disp{0,1}-n.png
+ disp{0,1}-sd.pfm
+ disp{0,1}y.pfm
+ ...
+ test
+ scene1
+ calib.txt
+ im{0,1}.png
+ scene2
+ calib.txt
+ im{0,1}.png
+ ...
+
+ Args:
+ root (string): Root directory of the Middleburry 2014 Dataset.
+ split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
+ use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible.
+ The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
+ calibration (string, optional): Wether or not to use the calibrated (default) or uncalibrated scenes.
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+ download (boolean, optional): Wether or not to download the dataset in the ``root`` directory.
+ """
+
+ splits = {
+ "train": [
+ "Adirondack",
+ "Jadeplant",
+ "Motorcycle",
+ "Piano",
+ "Pipes",
+ "Playroom",
+ "Playtable",
+ "Recycle",
+ "Shelves",
+ "Vintage",
+ ],
+ "additional": [
+ "Backpack",
+ "Bicycle1",
+ "Cable",
+ "Classroom1",
+ "Couch",
+ "Flowers",
+ "Mask",
+ "Shopvac",
+ "Sticks",
+ "Storage",
+ "Sword1",
+ "Sword2",
+ "Umbrella",
+ ],
+ "test": [
+ "Plants",
+ "Classroom2E",
+ "Classroom2",
+ "Australia",
+ "DjembeL",
+ "CrusadeP",
+ "Crusade",
+ "Hoops",
+ "Bicycle2",
+ "Staircase",
+ "Newkuba",
+ "AustraliaP",
+ "Djembe",
+ "Livingroom",
+ "Computer",
+ ],
+ }
+
+ _has_built_in_disparity_mask = True
+
+ def __init__(
+ self,
+ root: str,
+ split: str = "train",
+ calibration: Optional[str] = "perfect",
+ use_ambient_views: bool = False,
+ transforms: Optional[Callable] = None,
+ download: bool = False,
+ ):
+ super().__init__(root, transforms)
+
+ verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
+ self.split = split
+
+ if calibration:
+ verify_str_arg(calibration, "calibration", valid_values=("perfect", "imperfect", "both", None)) # type: ignore
+ if split == "test":
+ raise ValueError("Split 'test' has only no calibration settings, please set `calibration=None`.")
+ else:
+ if split != "test":
+ raise ValueError(
+ f"Split '{split}' has calibration settings, however None was provided as an argument."
+ f"\nSetting calibration to 'perfect' for split '{split}'. Available calibration settings are: 'perfect', 'imperfect', 'both'.",
+ )
+
+ if download:
+ self._download_dataset(root)
+
+ root = Path(root) / "Middlebury2014"
+
+ if not os.path.exists(root / split):
+ raise FileNotFoundError(f"The {split} directory was not found in the provided root directory")
+
+ split_scenes = self.splits[split]
+ # check that the provided root folder contains the scene splits
+ if not any(
+ # using startswith to account for perfect / imperfect calibrartion
+ scene.startswith(s)
+ for scene in os.listdir(root / split)
+ for s in split_scenes
+ ):
+ raise FileNotFoundError(f"Provided root folder does not contain any scenes from the {split} split.")
+
+ calibrartion_suffixes = {
+ None: [""],
+ "perfect": ["-perfect"],
+ "imperfect": ["-imperfect"],
+ "both": ["-perfect", "-imperfect"],
+ }[calibration]
+
+ for calibration_suffix in calibrartion_suffixes:
+ scene_pattern = "*" + calibration_suffix
+ left_img_pattern = str(root / split / scene_pattern / "im0.png")
+ right_img_pattern = str(root / split / scene_pattern / "im1.png")
+ self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+ if split == "test":
+ self._disparities = list((None, None) for _ in self._images)
+ else:
+ left_dispartity_pattern = str(root / split / scene_pattern / "disp0.pfm")
+ right_dispartity_pattern = str(root / split / scene_pattern / "disp1.pfm")
+ self._disparities += self._scan_pairs(left_dispartity_pattern, right_dispartity_pattern)
+
+ self.use_ambient_views = use_ambient_views
+
+ def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
+ """
+ Function that reads either the original right image or an augmented view when ``use_ambient_views`` is True.
+ When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]``
+ as the right image.
+ """
+ ambient_file_paths: List[Union[str, Path]] # make mypy happy
+
+ if not isinstance(file_path, Path):
+ file_path = Path(file_path)
+
+ if file_path.name == "im1.png" and self.use_ambient_views:
+ base_path = file_path.parent
+ # initialize sampleable container
+ ambient_file_paths = list(base_path / view_name for view_name in ["im1E.png", "im1L.png"])
+ # double check that we're not going to try to read from an invalid file path
+ ambient_file_paths = list(filter(lambda p: os.path.exists(p), ambient_file_paths))
+ # keep the original image as an option as well for uniform sampling between base views
+ ambient_file_paths.append(file_path)
+ file_path = random.choice(ambient_file_paths) # type: ignore
+ return super()._read_img(file_path)
+
+ def _read_disparity(self, file_path: str) -> Tuple:
+ # test split has not disparity maps
+ if file_path is None:
+ return None, None
+
+ disparity_map = _read_pfm_file(file_path)
+ disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
+ disparity_map[disparity_map == np.inf] = 0 # remove infinite disparities
+ valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
+ return disparity_map, valid_mask
+
+ def _download_dataset(self, root: str):
+ base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
+ # train and additional splits have 2 different calibration settings
+ root = Path(root) / "Middlebury2014"
+ split_name = self.split
+
+ if split_name != "test":
+ for split_scene in self.splits[split_name]:
+ split_root = root / split_name
+ for calibration in ["perfect", "imperfect"]:
+ scene_name = f"{split_scene}-{calibration}"
+ scene_url = f"{base_url}/{scene_name}.zip"
+ print(f"Downloading {scene_url}")
+ # download the scene only if it doesn't exist
+ if not (split_root / scene_name).exists():
+ download_and_extract_archive(
+ url=scene_url,
+ filename=f"{scene_name}.zip",
+ download_root=str(split_root),
+ remove_finished=True,
+ )
+ else:
+ os.makedirs(root / "test")
+ if any(s not in os.listdir(root / "test") for s in self.splits["test"]):
+ # test split is downloaded from a different location
+ test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip"
+ # the unzip is going to produce a directory MiddEval3 with two subdirectories trainingF and testF
+ # we want to move the contents from testF into the directory
+ download_and_extract_archive(url=test_set_url, download_root=str(root), remove_finished=True)
+ for scene_dir, scene_names, _ in os.walk(str(root / "MiddEval3/testF")):
+ for scene in scene_names:
+ scene_dst_dir = root / "test"
+ scene_src_dir = Path(scene_dir) / scene
+ os.makedirs(scene_dst_dir, exist_ok=True)
+ shutil.move(str(scene_src_dir), str(scene_dst_dir))
+
+ # cleanup MiddEval3 directory
+ shutil.rmtree(str(root / "MiddEval3"))
+
+ def __getitem__(self, index: int) -> Tuple:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+ ``valid_mask`` is implicitly ``None`` for `split=test`.
+ """
+ return super().__getitem__(index)
+
+
class CREStereo(StereoMatchingDataset):
"""Synthetic dataset used in training the `CREStereo `_ architecture.
Dataset details on the official paper `repo `_.
@@ -432,7 +691,7 @@ def __init__(
def _read_disparity(self, file_path: str) -> Tuple:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze the disparity map into (C, H, W) format
- disparity_map = disparity_map[None, :, :] / 256.0
+ disparity_map = disparity_map[None, :, :] / 32.0
valid_mask = None
return disparity_map, valid_mask