From e326295a02cd18017264e14dcb513bd1a04d56b8 Mon Sep 17 00:00:00 2001 From: Ponku Date: Mon, 25 Jul 2022 16:58:59 +0100 Subject: [PATCH 1/5] Broken down PR(#6269). Added an additional dataset --- test/datasets_utils.py | 46 ++++ test/test_datasets.py | 168 ++++++++++++ torchvision/datasets/__init__.py | 4 + torchvision/datasets/_stereo_matching.py | 328 +++++++++++++++++++++++ 4 files changed, 546 insertions(+) create mode 100644 torchvision/datasets/_stereo_matching.py diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 2043caae0a2..eafff3a4371 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -16,6 +16,8 @@ from collections import defaultdict from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union +import numpy as np + import PIL import PIL.Image import pytest @@ -23,6 +25,7 @@ import torchvision.datasets import torchvision.io from common_utils import disable_console_output, get_tmp_dir +from torchvision.transforms.functional import get_dimensions __all__ = [ @@ -748,6 +751,49 @@ def size(idx: int) -> Tuple[int, int, int]: ] +def shape_test_for_stereo_gt_w_mask( + left: PIL.Image.Image, right: PIL.Image.Image, disparity: np.ndarray, valid_mask: np.ndarray +): + left_dims = get_dimensions(left) + right_dims = get_dimensions(right) + c, h, w = left_dims + # check that left and right are the same size + assert left_dims == right_dims + # check general shapes + assert c == 3 + assert disparity.ndim == 3 + assert disparity.shape == (1, h, w) + # check that valid mask is the same size as the disparity + + _, dh, dw = disparity.shape + mh, mw = valid_mask.shape + assert dh == mh + assert dw == mw + + +def shape_test_for_stereo_gt_no_mask(left: PIL.Image.Image, right: PIL.Image.Image, disparity: np.ndarray): + left_dims = get_dimensions(left) + right_dims = get_dimensions(right) + c, h, w = left_dims + # check that left and right are the same size + assert left_dims == right_dims + # check general shapes + assert c == 3 + assert disparity.ndim == 3 + assert disparity.shape == (1, h, w) + + +def shape_test_for_stereo_no_gt(left: PIL.Image.Image, right: PIL.Image.Image, disparity: None): + left_dims = get_dimensions(left) + right_dims = get_dimensions(right) + c, _, _ = left_dims + # check that left and right are the same size + assert left_dims == right_dims + # check general shapes + assert c == 3 + assert disparity is None + + @requires_lazy_imports("av") def create_video_file( root: Union[pathlib.Path, str], diff --git a/test/test_datasets.py b/test/test_datasets.py index a108479aee3..e68634dddc1 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -13,6 +13,7 @@ import unittest import xml.etree.ElementTree as ET import zipfile +from typing import Union import datasets_utils import numpy as np @@ -2671,5 +2672,172 @@ def inject_fake_data(self, tmpdir: str, config): return len(sampled_classes) * num_images_per_class[config["split"]] +class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Kitti2012Stereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + def inject_fake_data(self, tmpdir, config): + kitti_dir = pathlib.Path(tmpdir) / "Kitti2012" + os.makedirs(kitti_dir, exist_ok=True) + + split_dir = kitti_dir / (config["split"] + "ing") + os.makedirs(split_dir, exist_ok=True) + + num_examples = {"train": 4, "test": 3}.get(config["split"], 
0) + + datasets_utils.create_image_folder( + root=split_dir, + name="colored_0", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + datasets_utils.create_image_folder( + root=split_dir, + name="colored_1", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + + if config["split"] == "train": + datasets_utils.create_image_folder( + root=split_dir, + name="disp_noc", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=num_examples, + # Kitti2012 uses a single channel image for disparities + size=(1, 100, 200), + ) + + return num_examples + + def test_train_splits(self): + for split in ["train"]: + with self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + + def test_test_split(self): + for split in ["test"]: + with self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + +class Kitti2015StereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Kitti2015Stereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + def inject_fake_data(self, tmpdir, config): + kitti_dir = pathlib.Path(tmpdir) / "Kitti2015" + os.makedirs(kitti_dir, exist_ok=True) + + split_dir = kitti_dir / (config["split"] + "ing") + os.makedirs(split_dir, exist_ok=True) + + num_examples = {"train": 4, "test": 6}.get(config["split"], 0) + + datasets_utils.create_image_folder( + root=split_dir, + name="image_2", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + datasets_utils.create_image_folder( + root=split_dir, + name="image_3", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + + if config["split"] == "train": + datasets_utils.create_image_folder( + root=split_dir, + name="disp_occ_0", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=num_examples, + # Kitti2015 uses a single channel image for disparities + size=(1, 100, 200), + ) + + datasets_utils.create_image_folder( + root=split_dir, + name="disp_occ_1", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=num_examples, + # Kitti2015 uses a single channel image for disparities + size=(1, 100, 200), + ) + + return num_examples + + def test_train_splits(self): + for split in ["train"]: + with self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + + def test_test_split(self): + for split in ["test"]: + with self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + +class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = 
datasets.CarlaStereo + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, None)) + + @staticmethod + def _create_scene_folders(num_examples: int, root_dir: Union[str, pathlib.Path]): + # make the root_dir if it does not exits + os.makedirs(root_dir, exist_ok=True) + + for i in range(num_examples): + scene_dir = pathlib.Path(root_dir) / f"scene_{i}" + os.makedirs(scene_dir, exist_ok=True) + # populate with left right images + datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(100, 100)) + datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(100, 100)) + datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp0GT.pfm")) + datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp1GT.pfm")) + + def inject_fake_data(self, tmpdir, config): + carla_dir = pathlib.Path(tmpdir) / "carla-highres" + os.makedirs(carla_dir, exist_ok=True) + + split_dir = pathlib.Path(carla_dir) / "trainingF" + os.makedirs(split_dir, exist_ok=True) + + num_examples = 6 + self._create_scene_folders(num_examples=num_examples, root_dir=split_dir) + + return num_examples + + def test_train_splits(self): + with self.create_dataset() as (dataset, _): + for left, right, disparity in dataset: + datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 099d10da35d..d8b6293fb42 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,4 +1,5 @@ from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel +from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -105,4 +106,7 @@ "FGVCAircraft", "EuroSAT", "RenderedSST2", + "Kitti2012Stereo", + "Kitti2015Stereo", + "CarlaStereo", ) diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py new file mode 100644 index 00000000000..f4edb98b7a6 --- /dev/null +++ b/torchvision/datasets/_stereo_matching.py @@ -0,0 +1,328 @@ +import functools +import os +from abc import ABC, abstractmethod +from glob import glob +from pathlib import Path +from typing import Callable, List, Optional, Tuple + +import numpy as np +from PIL import Image + +from .utils import _read_pfm, verify_str_arg +from .vision import VisionDataset + +__all__ = () + +_read_pfm_file = functools.partial(_read_pfm, slice_channels=1) + + +class StereoMatchingDataset(ABC, VisionDataset): + """Base interface for Stereo matching datasets""" + + _has_built_in_disparity_mask = False + + def __init__(self, root: str, transforms: Optional[Callable] = None): + """ + Args: + root(str): Root directory of the dataset. + transforms(callable, optional): A function/transform that takes in Tuples of + (images, disparities, valid_masks) and returns a transformed version of each of them. + images is a Tuple of (``PIL.Image``, ``PIL.Image``) + disparities is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (1, H, W) + valid_masks is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (H, W) + In some cases, when a dataset does not provide disparities, the ``disparities`` and + ``valid_masks`` can be Tuples containing None values. 
+ For training splits generally the datasets provide a minimal guarantee of + images: (``PIL.Image``, ``PIL.Image``) + disparities: (``np.ndarray``, ``None``) with shape (1, H, W) + Optionally, based on the dataset, it can return a ``mask`` as well: + valid_masks: (``np.ndarray | None``, ``None``) with shape (H, W) + For some test splits, the datasets provides outputs that look like: + imgaes: (``PIL.Image``, ``PIL.Image``) + disparities: (``None``, ``None``) + Optionally, based on the dataset, it can return a ``mask`` as well: + valid_masks: (``None``, ``None``) + """ + super().__init__(root=root) + self.transforms = transforms + + self._images: List[Tuple[str, str]] = [] + self._disparities: List[Tuple[str, str]] = [] + + def _read_img(self, file_path: str) -> Image.Image: + img = Image.open(file_path) + if img.mode != "RGB": + img = img.convert("RGB") + return img + + def _scan_pairs( + self, paths_left_pattern: str, paths_right_pattern: str, fill_empty: bool = False + ) -> List[Tuple[str, str]]: + left_paths: List[str] = sorted(glob(paths_left_pattern)) + right_paths: List[str] = sorted(glob(paths_right_pattern)) + + # used when dealing with inexistent disparity for the right image + if fill_empty: + right_paths = list("" for _ in left_paths) + + if not left_paths: + raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}") + + if not right_paths: + raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}") + + if len(left_paths) != len(right_paths): + raise ValueError( + f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n " + f"left pattern: {paths_left_pattern}\n" + f"right pattern: {paths_right_pattern}\n" + ) + + images = list((left, right) for left, right in zip(left_paths, right_paths)) + return images + + @abstractmethod + def _read_disparity(self, file_path: str) -> Tuple: + # function that returns a disparity map and an occlusion map + pass + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + Args: + index(int): The index of the example to retrieve + Returns: + tuple: A 3 or 4-tuple with ``(img_left, img_right, disparity, Optional[valid_mask])`` where ``valid_mask`` + can be a numpy boolean mask of shape (H, W) if the dataset provides a file + indicating which disparity pixels are valid. The disparity is a numpy array of + shape (1, H, W) and the images are PIL images. ``disparity`` is None for + datasets on which for ``split="test"`` the authors did not provide annotations. + """ + img_left = self._read_img(self._images[index][0]) + img_right = self._read_img(self._images[index][1]) + + dsp_map_left, valid_mask_left = self._read_disparity(self._disparities[index][0]) + dsp_map_right, valid_mask_right = self._read_disparity(self._disparities[index][1]) + + imgs = (img_left, img_right) + dsp_maps = (dsp_map_left, dsp_map_right) + valid_masks = (valid_mask_left, valid_mask_right) + + if self.transforms is not None: + ( + imgs, + dsp_maps, + valid_masks, + ) = self.transforms(imgs, dsp_maps, valid_masks) + + if self._has_built_in_disparity_mask or valid_masks[0] is not None: + return imgs[0], imgs[1], dsp_maps[0], valid_masks[0] + else: + return imgs[0], imgs[1], dsp_maps[0] + + def __len__(self) -> int: + return len(self._images) + + +class CarlaStereo(StereoMatchingDataset): + """ + Carla simulator data linked in the `CREStereo github repo `_. 
+ + The dataset is expected to have the following structure: :: + root + carla-highres + trainingF + scene1 + img0.png + img1.png + disp0GT.pfm + disp1GT.pfm + calib.txt + scene2 + img0.png + img1.png + disp0GT.pfm + disp1GT.pfm + calib.txt + ... + Args: + root (string): Root directory where `carla-highres` is located. + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + """ + + def __init__(self, root: str, transforms: Optional[Callable] = None): + super().__init__(root, transforms) + self._has_built_in_disparity_mask = False + + root = Path(root) / "carla-highres" + + left_image_pattern = str(root / "trainingF" / "*" / "im0.png") + right_image_pattern = str(root / "trainingF" / "*" / "im1.png") + imgs = self._scan_pairs(left_image_pattern, right_image_pattern) + self._images += imgs + + left_disparity_pattern = str(root / "trainingF" / "*" / "disp0GT.pfm") + right_disparity_pattern = str(root / "trainingF" / "*" / "disp1GT.pfm") + disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + self._disparities += disparities + + def _read_disparity(self, file_path: str) -> Tuple: + disparity_map = _read_pfm_file(file_path) + disparity_map = np.abs(disparity_map) # ensure that the disparity is positive + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + Args: + index(int): The index of the example to retrieve + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return super().__getitem__(index) + + +class Kitti2012Stereo(StereoMatchingDataset): + """Kitti dataset from the `2012 `_ stereo evaluation benchmark. + Uses the RGB images for consistency with Kitti 2015. + The dataset is expected to have the following structure: :: + root + Kitti2012 + testing + colored_0 + colored_1 + training + colored_0 + colored_1 + disp_noc + calib + Args: + root (string): Root directory where Kitti2012 is located. + split (string, optional): The dataset split of scenes, either "train" (default), test, or "additional" + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. 
+ """ + + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + super().__init__(root, transforms) + self._has_built_in_disparity_mask = True + + verify_str_arg(split, "split", valid_values=("train", "test")) + + root = Path(root) / "Kitti2012" / (split + "ing") + + left_img_pattern = str(root / "colored_0" / "*_10.png") + right_img_pattern = str(root / "colored_1" / "*_10.png") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "train": + disparity_pattern = str(root / "disp_noc" / "*.png") + self._disparities += self._scan_pairs(disparity_pattern, "", fill_empty=True) + else: + self._disparities = list(("", "") for _ in self._images) + + def _read_disparity(self, file_path: str) -> Tuple: + # test split has no disparity maps + if not os.path.exists(file_path): + return None, None + + disparity_map = np.asarray(Image.open(file_path)) / 256.0 + # unsqueeze the disparity map into (C, H, W) format + disparity_map = disparity_map[None, :, :] + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + Args: + index(int): The index of the example to retrieve + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. + """ + return super().__getitem__(index) + + +class Kitti2015Stereo(StereoMatchingDataset): + """Kitti dataset from the `2015 `_ stereo evaluation benchmark. + The dataset is expected to have the following structure: :: + root + Kitti2015 + testing + image_2 + img1.png + img2.png + ... + image_3 + img1.png + img2.png + ... + training + image_2 + img1.png + img2.png + ... + image_3 + img1.png + img2.png + ... + disp_occ_0 + img1.png + img2.png + ... + disp_occ_1 + img1.png + img2.png + ... + calib + Args: + root (string): Root directory where Kitti2015 is located. + split (string, optional): The dataset split of scenes, either "train" (default) or test. + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. 
+ """ + + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + super().__init__(root, transforms) + self._has_built_in_disparity_mask = True + + verify_str_arg(split, "split", valid_values=("train", "test")) + + root = Path(root) / "Kitti2015" / (split + "ing") + left_img_pattern = str(root / "image_2" / "*.png") + right_img_pattern = str(root / "image_3" / "*.png") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "train": + left_disparity_pattern = str(root / "disp_occ_0" / "*.png") + right_disparity_pattern = str(root / "disp_occ_1" / "*.png") + self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + else: + self._disparities = list(("", "") for _ in self._images) + + def _read_disparity(self, file_path: str) -> Tuple: + # test split has no disparity maps + if not os.path.exists(file_path): + return None, None + + disparity_map = np.asarray(Image.open(file_path)) / 256.0 + # unsqueeze the disparity map into (C, H, W) format + disparity_map = disparity_map[None, :, :] + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + Args: + index(int): The index of the example to retrieve + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. + """ + return super().__getitem__(index) From 3f8d3304ed766a7cc71b66f0c5bdfa79d2521029 Mon Sep 17 00:00:00 2001 From: Ponku Date: Thu, 28 Jul 2022 13:28:13 +0100 Subject: [PATCH 2/5] Removed some types. Store None instead of "". Merged test util functions. 
--- test/datasets_utils.py | 50 ++++++++-------------- test/test_datasets.py | 12 +++--- torchvision/datasets/_stereo_matching.py | 53 ++++++++++++------------ 3 files changed, 50 insertions(+), 65 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index eafff3a4371..c232e7132b4 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -751,47 +751,31 @@ def size(idx: int) -> Tuple[int, int, int]: ] -def shape_test_for_stereo_gt_w_mask( - left: PIL.Image.Image, right: PIL.Image.Image, disparity: np.ndarray, valid_mask: np.ndarray +def shape_test_for_stereo( + left: PIL.Image.Image, + right: PIL.Image.Image, + disparity: Optional[np.ndarray] = None, + valid_mask: Optional[np.ndarray] = None, ): left_dims = get_dimensions(left) right_dims = get_dimensions(right) c, h, w = left_dims # check that left and right are the same size assert left_dims == right_dims - # check general shapes assert c == 3 - assert disparity.ndim == 3 - assert disparity.shape == (1, h, w) - # check that valid mask is the same size as the disparity - _, dh, dw = disparity.shape - mh, mw = valid_mask.shape - assert dh == mh - assert dw == mw - - -def shape_test_for_stereo_gt_no_mask(left: PIL.Image.Image, right: PIL.Image.Image, disparity: np.ndarray): - left_dims = get_dimensions(left) - right_dims = get_dimensions(right) - c, h, w = left_dims - # check that left and right are the same size - assert left_dims == right_dims - # check general shapes - assert c == 3 - assert disparity.ndim == 3 - assert disparity.shape == (1, h, w) - - -def shape_test_for_stereo_no_gt(left: PIL.Image.Image, right: PIL.Image.Image, disparity: None): - left_dims = get_dimensions(left) - right_dims = get_dimensions(right) - c, _, _ = left_dims - # check that left and right are the same size - assert left_dims == right_dims - # check general shapes - assert c == 3 - assert disparity is None + # check that the disparity has the same spatial dimensions + # as the input + if disparity is not None: + assert disparity.ndim == 3 + assert disparity.shape == (1, h, w) + + if valid_mask is not None: + # check that valid mask is the same size as the disparity + _, dh, dw = disparity.shape + mh, mw = valid_mask.shape + assert dh == mh + assert dw == mw @requires_lazy_imports("av") diff --git a/test/test_datasets.py b/test/test_datasets.py index e68634dddc1..54696b0d6a8 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2718,14 +2718,15 @@ def test_train_splits(self): with self.create_dataset(split=split) as (dataset, _): for left, right, disparity, mask in dataset: assert mask is None - datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + datasets_utils.shape_test_for_stereo(left, right, disparity) def test_test_split(self): for split in ["test"]: with self.create_dataset(split=split) as (dataset, _): for left, right, disparity, mask in dataset: assert mask is None - datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity) + assert disparity is None + datasets_utils.shape_test_for_stereo(left, right) def test_bad_input(self): with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): @@ -2788,14 +2789,15 @@ def test_train_splits(self): with self.create_dataset(split=split) as (dataset, _): for left, right, disparity, mask in dataset: assert mask is None - datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + datasets_utils.shape_test_for_stereo(left, right, disparity) def test_test_split(self): for split in ["test"]: with 
self.create_dataset(split=split) as (dataset, _): for left, right, disparity, mask in dataset: assert mask is None - datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity) + assert disparity is None + datasets_utils.shape_test_for_stereo(left, right) def test_bad_input(self): with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): @@ -2836,7 +2838,7 @@ def inject_fake_data(self, tmpdir, config): def test_train_splits(self): with self.create_dataset() as (dataset, _): for left, right, disparity in dataset: - datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + datasets_utils.shape_test_for_stereo(left, right, disparity) if __name__ == "__main__": diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index f4edb98b7a6..25d919ea7e5 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -1,9 +1,8 @@ import functools -import os from abc import ABC, abstractmethod from glob import glob from pathlib import Path -from typing import Callable, List, Optional, Tuple +from typing import Callable, Optional, Tuple import numpy as np from PIL import Image @@ -46,8 +45,8 @@ def __init__(self, root: str, transforms: Optional[Callable] = None): super().__init__(root=root) self.transforms = transforms - self._images: List[Tuple[str, str]] = [] - self._disparities: List[Tuple[str, str]] = [] + self._images = [] # type: ignore + self._disparities = [] # type: ignore def _read_img(self, file_path: str) -> Image.Image: img = Image.open(file_path) @@ -55,15 +54,14 @@ def _read_img(self, file_path: str) -> Image.Image: img = img.convert("RGB") return img - def _scan_pairs( - self, paths_left_pattern: str, paths_right_pattern: str, fill_empty: bool = False - ) -> List[Tuple[str, str]]: - left_paths: List[str] = sorted(glob(paths_left_pattern)) - right_paths: List[str] = sorted(glob(paths_right_pattern)) + def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str] = None): - # used when dealing with inexistent disparity for the right image - if fill_empty: - right_paths = list("" for _ in left_paths) + left_paths = list(sorted(glob(paths_left_pattern))) # type: ignore + + if paths_right_pattern: + right_paths = list(sorted(glob(paths_right_pattern))) # type: ignore + else: + right_paths = list(None for _ in left_paths) # type: ignore if not left_paths: raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}") @@ -78,8 +76,8 @@ def _scan_pairs( f"right pattern: {paths_right_pattern}\n" ) - images = list((left, right) for left, right in zip(left_paths, right_paths)) - return images + paths = list((left, right) for left, right in zip(left_paths, right_paths)) + return paths @abstractmethod def _read_disparity(self, file_path: str) -> Tuple: @@ -151,19 +149,18 @@ class CarlaStereo(StereoMatchingDataset): def __init__(self, root: str, transforms: Optional[Callable] = None): super().__init__(root, transforms) - self._has_built_in_disparity_mask = False root = Path(root) / "carla-highres" left_image_pattern = str(root / "trainingF" / "*" / "im0.png") right_image_pattern = str(root / "trainingF" / "*" / "im1.png") imgs = self._scan_pairs(left_image_pattern, right_image_pattern) - self._images += imgs + self._images = imgs left_disparity_pattern = str(root / "trainingF" / "*" / "disp0GT.pfm") right_disparity_pattern = str(root / "trainingF" / "*" / "disp1GT.pfm") disparities = self._scan_pairs(left_disparity_pattern, 
right_disparity_pattern) - self._disparities += disparities + self._disparities = disparities def _read_disparity(self, file_path: str) -> Tuple: disparity_map = _read_pfm_file(file_path) @@ -204,9 +201,10 @@ class Kitti2012Stereo(StereoMatchingDataset): transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. """ + _has_built_in_disparity_mask = True + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): super().__init__(root, transforms) - self._has_built_in_disparity_mask = True verify_str_arg(split, "split", valid_values=("train", "test")) @@ -214,17 +212,17 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl left_img_pattern = str(root / "colored_0" / "*_10.png") right_img_pattern = str(root / "colored_1" / "*_10.png") - self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + self._images = self._scan_pairs(left_img_pattern, right_img_pattern) if split == "train": disparity_pattern = str(root / "disp_noc" / "*.png") - self._disparities += self._scan_pairs(disparity_pattern, "", fill_empty=True) + self._disparities = self._scan_pairs(disparity_pattern, None) else: - self._disparities = list(("", "") for _ in self._images) + self._disparities = list((None, None) for _ in self._images) def _read_disparity(self, file_path: str) -> Tuple: # test split has no disparity maps - if not os.path.exists(file_path): + if file_path is None: return None, None disparity_map = np.asarray(Image.open(file_path)) / 256.0 @@ -285,27 +283,28 @@ class Kitti2015Stereo(StereoMatchingDataset): transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. """ + _has_built_in_disparity_mask = True + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): super().__init__(root, transforms) - self._has_built_in_disparity_mask = True verify_str_arg(split, "split", valid_values=("train", "test")) root = Path(root) / "Kitti2015" / (split + "ing") left_img_pattern = str(root / "image_2" / "*.png") right_img_pattern = str(root / "image_3" / "*.png") - self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + self._images = self._scan_pairs(left_img_pattern, right_img_pattern) if split == "train": left_disparity_pattern = str(root / "disp_occ_0" / "*.png") right_disparity_pattern = str(root / "disp_occ_1" / "*.png") - self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) else: - self._disparities = list(("", "") for _ in self._images) + self._disparities = list((None, None) for _ in self._images) def _read_disparity(self, file_path: str) -> Tuple: # test split has no disparity maps - if not os.path.exists(file_path): + if file_path is None: return None, None disparity_map = np.asarray(Image.open(file_path)) / 256.0 From 16ee59439e91fc6f7948c8b4c122cae610d1ba44 Mon Sep 17 00:00:00 2001 From: Ponku Date: Fri, 29 Jul 2022 10:42:25 +0100 Subject: [PATCH 3/5] minor mypy fixes. 
minor doc fixes --- docs/source/datasets.rst | 11 ++++++ torchvision/datasets/_stereo_matching.py | 50 ++++++++++++++++++------ 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index bc0864083e0..af7ac072e31 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -101,6 +101,17 @@ Optical Flow KittiFlow Sintel +Stereo Matching +~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: generated/ + :template: class_dataset.rst + + CarlaStereo + Kitti2012Stereo + Kitti2015Stereo + Image pairs ~~~~~~~~~~~ diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index 25d919ea7e5..8a064826adc 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from glob import glob from pathlib import Path -from typing import Callable, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import numpy as np from PIL import Image @@ -56,12 +56,13 @@ def _read_img(self, file_path: str) -> Image.Image: def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str] = None): - left_paths = list(sorted(glob(paths_left_pattern))) # type: ignore + left_paths = list(sorted(glob(paths_left_pattern))) + right_paths: List[Union[None, str]] if paths_right_pattern: - right_paths = list(sorted(glob(paths_right_pattern))) # type: ignore + right_paths = list(sorted(glob(paths_right_pattern))) else: - right_paths = list(None for _ in left_paths) # type: ignore + right_paths = list(None for _ in left_paths) if not left_paths: raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}") @@ -126,6 +127,7 @@ class CarlaStereo(StereoMatchingDataset): Carla simulator data linked in the `CREStereo github repo `_. The dataset is expected to have the following structure: :: + root carla-highres trainingF @@ -141,7 +143,8 @@ class CarlaStereo(StereoMatchingDataset): disp0GT.pfm disp1GT.pfm calib.txt - ... + ... + Args: root (string): Root directory where `carla-highres` is located. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. @@ -182,22 +185,41 @@ def __getitem__(self, index: int) -> Tuple: class Kitti2012Stereo(StereoMatchingDataset): - """Kitti dataset from the `2012 `_ stereo evaluation benchmark. - Uses the RGB images for consistency with Kitti 2015. + """ + KITTI dataset from the `2012 stereo evaluation benchmark `_. + Uses the RGB images for consistency with KITTI 2015. + The dataset is expected to have the following structure: :: + root Kitti2012 testing colored_0 + 1_10.png + 2_10.png + ... colored_1 + 1_10.png + 2_10.png + ... training colored_0 + 1_10.png + 2_10.png + ... colored_1 + 1_10.png + 2_10.png + ... disp_noc + 1.png + 2.png + ... calib + Args: - root (string): Root directory where Kitti2012 is located. - split (string, optional): The dataset split of scenes, either "train" (default), test, or "additional" + root (string): Root directory where `Kitti2012` is located. + split (string, optional): The dataset split of scenes, either "train" (default) or "test". transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. """ @@ -246,8 +268,11 @@ def __getitem__(self, index: int) -> Tuple: class Kitti2015Stereo(StereoMatchingDataset): - """Kitti dataset from the `2015 `_ stereo evaluation benchmark. 
+ """ + KITTI dataset from the `2015 stereo evaluation benchmark `_. + The dataset is expected to have the following structure: :: + root Kitti2015 testing @@ -277,9 +302,10 @@ class Kitti2015Stereo(StereoMatchingDataset): img2.png ... calib + Args: - root (string): Root directory where Kitti2015 is located. - split (string, optional): The dataset split of scenes, either "train" (default) or test. + root (string): Root directory where `Kitti2015` is located. + split (string, optional): The dataset split of scenes, either "train" (default) or "test". transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. """ From ba90c701a6db4187690734a51ac1752bdfb6b880 Mon Sep 17 00:00:00 2001 From: Ponku Date: Mon, 1 Aug 2022 15:48:02 +0100 Subject: [PATCH 4/5] reformated docstring --- torchvision/datasets/_stereo_matching.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index 8a064826adc..ab18fc080ee 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -89,6 +89,7 @@ def __getitem__(self, index: int) -> Tuple: """Return example at given index. Args: index(int): The index of the example to retrieve + Returns: tuple: A 3 or 4-tuple with ``(img_left, img_right, disparity, Optional[valid_mask])`` where ``valid_mask`` can be a numpy boolean mask of shape (H, W) if the dataset provides a file @@ -175,6 +176,7 @@ def __getitem__(self, index: int) -> Tuple: """Return example at given index. Args: index(int): The index of the example to retrieve + Returns: tuple: A 3-tuple with ``(img_left, img_right, disparity)``. The disparity is a numpy array of shape (1, H, W) and the images are PIL images. @@ -257,6 +259,7 @@ def __getitem__(self, index: int) -> Tuple: """Return example at given index. Args: index(int): The index of the example to retrieve + Returns: tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. The disparity is a numpy array of shape (1, H, W) and the images are PIL images. @@ -343,6 +346,7 @@ def __getitem__(self, index: int) -> Tuple: """Return example at given index. Args: index(int): The index of the example to retrieve + Returns: tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. The disparity is a numpy array of shape (1, H, W) and the images are PIL images. From e7ff0e994be42b43f829762920b5780a72d14a77 Mon Sep 17 00:00:00 2001 From: Ponku Date: Mon, 1 Aug 2022 18:23:39 +0100 Subject: [PATCH 5/5] Added additional line-skips --- torchvision/datasets/_stereo_matching.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index ab18fc080ee..de213fc0368 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -87,6 +87,7 @@ def _read_disparity(self, file_path: str) -> Tuple: def __getitem__(self, index: int) -> Tuple: """Return example at given index. + Args: index(int): The index of the example to retrieve @@ -174,6 +175,7 @@ def _read_disparity(self, file_path: str) -> Tuple: def __getitem__(self, index: int) -> Tuple: """Return example at given index. + Args: index(int): The index of the example to retrieve @@ -257,6 +259,7 @@ def _read_disparity(self, file_path: str) -> Tuple: def __getitem__(self, index: int) -> Tuple: """Return example at given index. 
+ Args: index(int): The index of the example to retrieve @@ -344,6 +347,7 @@ def _read_disparity(self, file_path: str) -> Tuple: def __getitem__(self, index: int) -> Tuple: """Return example at given index. + Args: index(int): The index of the example to retrieve
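
Usage sketch for the datasets added in this series. It assumes the KITTI and CARLA archives have already been downloaded and extracted under ``./data`` following the directory layouts documented in the class docstrings (none of the new classes download data themselves); the ``./data`` path and the assertions are illustrative only.

    import numpy as np

    from torchvision.datasets import CarlaStereo, Kitti2015Stereo

    # Kitti2015Stereo always returns a 4-tuple; disparity and valid_mask
    # are both None for split="test".
    kitti = Kitti2015Stereo(root="./data", split="train")
    img_left, img_right, disparity, valid_mask = kitti[0]
    assert disparity.shape == (1, img_left.height, img_left.width)

    # CarlaStereo has no built-in validity mask, so it returns a 3-tuple
    # unless the transforms callable generates a mask.
    carla = CarlaStereo(root="./data")
    img_left, img_right, disparity = carla[0]
    assert isinstance(disparity, np.ndarray) and disparity.ndim == 3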
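
The ``transforms`` callable documented on ``StereoMatchingDataset`` receives ``(images, disparities, valid_masks)`` tuples whose elements may be None on test splits. A hypothetical transform matching that contract is sketched below; the tensor conversion shown is only one possible choice and is not part of this series.

    import torch

    from torchvision.transforms.functional import pil_to_tensor


    def stereo_to_tensor(images, disparities, valid_masks):
        # images: (PIL.Image, PIL.Image); disparities / valid_masks: tuples whose
        # elements are np.ndarray or None, per the StereoMatchingDataset docstring.
        images = tuple(pil_to_tensor(img).float() / 255.0 for img in images)
        disparities = tuple(
            torch.from_numpy(d).float() if d is not None else None for d in disparities
        )
        valid_masks = tuple(
            torch.from_numpy(m) if m is not None else None for m in valid_masks
        )
        return images, disparities, valid_masks


    # e.g. Kitti2012Stereo(root="./data", split="train", transforms=stereo_to_tensor)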