|
1 | 1 | import functools |
| 2 | +import os |
2 | 3 | from abc import ABC, abstractmethod |
3 | 4 | from glob import glob |
4 | 5 | from pathlib import Path |
@@ -465,3 +466,125 @@ def __getitem__(self, index: int) -> Tuple: |
465 | 466 | a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. |
466 | 467 | """ |
467 | 468 | return super().__getitem__(index) |
| 469 | + |
| 470 | + |
class SintelStereo(StereoMatchingDataset):
    """Sintel `Stereo Dataset <http://sintel.is.tue.mpg.de/stereo>`_.

    The dataset is expected to have the following structure: ::

        root
            Sintel
                training
                    final_left
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    final_right
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    clean_left
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    clean_right
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    disparities
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    occlusions
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    outofframe
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...

    The ``clean_left`` / ``clean_right`` folders are only read when ``pass_name`` is
    ``"clean"`` or ``"both"``; the ``final_left`` / ``final_right`` folders are only read
    when ``pass_name`` is ``"final"`` or ``"both"``.

    Args:
        root (string): Root directory where Sintel Stereo is located.
        pass_name (string): The name of the pass to use, either "final", "clean" or "both".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    # NOTE(review): presumably read by the base class to decide whether ``_read_disparity``
    # supplies its own validity mask -- confirm against ``StereoMatchingDataset``.
    _has_built_in_disparity_mask = True

    def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None):
        super().__init__(root, transforms)

        verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))

        root = Path(root) / "Sintel"
        pass_names = {
            "final": ["final"],
            "clean": ["clean"],
            "both": ["final", "clean"],
        }[pass_name]

        training_dir = root / "training"
        # the ground truth does not depend on the rendering pass, so the pattern is loop-invariant
        disparity_pattern = str(training_dir / "disparities" / "*" / "*.png")

        for p in pass_names:
            left_img_pattern = str(training_dir / f"{p}_left" / "*" / "*.png")
            right_img_pattern = str(training_dir / f"{p}_right" / "*" / "*.png")
            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)

            # the disparities are appended once per pass so that ``self._disparities``
            # grows in lockstep with ``self._images`` when both passes are requested
            self._disparities += self._scan_pairs(disparity_pattern, None)

    def _get_occlussion_mask_paths(self, file_path: str) -> Tuple[str, str]:
        """Return the occlusion and out-of-frame mask paths that belong to a disparity file.

        A disparity path looks like ``.../training/disparities/scene1/img1.png``; the matching
        masks live at ``.../training/occlusions/scene1/img1.png`` and
        ``.../training/outofframe/scene1/img1.png``.

        Args:
            file_path (str): Path to a disparity map file.

        Returns:
            tuple: ``(occlusion_path, outofframe_path)`` as strings.

        Raises:
            FileNotFoundError: If either of the two mask files does not exist.
        """
        fpath = Path(file_path)
        basename = fpath.name
        scenedir = fpath.parent
        # the parent of the scenedir is the "disparities" dir, so its grandparent is the
        # split dir ("training") that also holds the "occlusions" and "outofframe" dirs
        sampledir = scenedir.parent.parent

        occlusion_path = str(sampledir / "occlusions" / scenedir.name / basename)
        outofframe_path = str(sampledir / "outofframe" / scenedir.name / basename)

        if not os.path.exists(occlusion_path):
            raise FileNotFoundError(f"Occlusion mask {occlusion_path} does not exist")

        if not os.path.exists(outofframe_path):
            raise FileNotFoundError(f"Out of frame mask {outofframe_path} does not exist")

        return occlusion_path, outofframe_path

    def _read_disparity(self, file_path: Optional[str]) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Decode a Sintel disparity file and build its validity mask.

        Args:
            file_path (str, optional): Path to the encoded disparity PNG, or ``None``.

        Returns:
            tuple: ``(None, None)`` if ``file_path`` is ``None``. Otherwise
            ``(disparity_map, valid_mask)`` where ``disparity_map`` is a float32 array in
            (C, H, W) layout and ``valid_mask`` is a boolean array that is ``True`` where both
            the occlusion mask and the out-of-frame mask are zero.

        Raises:
            FileNotFoundError: If the occlusion or out-of-frame mask file is missing.
        """
        if file_path is None:
            return None, None

        # disparity decoding as per Sintel instructions in the README provided with the dataset
        # (context managers make sure the underlying file handles are released promptly)
        with Image.open(file_path) as disparity_img:
            disparity_map = np.asarray(disparity_img, dtype=np.float32)
        r, g, b = np.split(disparity_map, 3, axis=-1)
        disparity_map = r * 4 + g / (2**6) + b / (2**14)
        # reshape into (C, H, W) format
        disparity_map = np.transpose(disparity_map, (2, 0, 1))
        # find the appropriate file paths
        occluded_mask_path, out_of_frame_mask_path = self._get_occlussion_mask_paths(file_path)
        # occlusion masks: a pixel is usable when the mask value is 0
        with Image.open(occluded_mask_path) as occlusion_img:
            valid_mask = np.asarray(occlusion_img) == 0
        # out of frame masks: a pixel is usable when the mask value is 0
        with Image.open(out_of_frame_mask_path) as out_of_frame_img:
            off_mask = np.asarray(out_of_frame_img) == 0
        # combine the masks together
        valid_mask = np.logical_and(off_mask, valid_mask)
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
            the valid_mask is a numpy array of shape (H, W).
        """
        return super().__getitem__(index)
0 commit comments