Skip to content

Commit a12612a

Browse files
bottlerfacebook-github-bot
authored andcommitted
doc rgbd point cloud
Summary: docstring and shape fix Reviewed By: shapovalov Differential Revision: D42609661 fbshipit-source-id: fd50234872ad61b5452821eeb89d51344f70c957
1 parent d561f19 commit a12612a

File tree

2 files changed

+88
-3
lines changed

2 files changed

+88
-3
lines changed

pytorch3d/implicitron/tools/point_cloud_utils.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,33 @@ def get_rgbd_point_cloud(
2727
mask: Optional[torch.Tensor] = None,
2828
mask_thr: float = 0.5,
2929
mask_points: bool = True,
30+
euclidean: bool = False,
3031
) -> Pointclouds:
3132
"""
32-
Given a batch of images, depths, masks and cameras, generate a colored
33-
point cloud by unprojecting depth maps to the and coloring with the source
33+
Given a batch of images, depths, masks and cameras, generate a single colored
34+
point cloud by unprojecting depth maps and coloring with the source
3435
pixel colors.
36+
37+
Arguments:
38+
camera: Batch of N cameras
39+
image_rgb: Batch of N images of shape (N, C, H, W).
40+
For RGB images C=3.
41+
depth_map: Batch of N depth maps of shape (N, 1, H', W').
42+
Only positive values here are used to generate points.
43+
If euclidean=False (default) this contains perpendicular distances
44+
from each point to the camera plane (z-values).
45+
If euclidean=True, this contains distances from each point to
46+
the camera center.
47+
mask: If provided, batch of N masks of the same shape as depth_map.
48+
If provided, values in depth_map are ignored if the corresponding
49+
element of mask is smaller than mask_thr.
50+
mask_thr: used in interpreting mask
51+
euclidean: used in interpreting depth_map.
52+
53+
Returns:
54+
Pointclouds object containing one point cloud.
3555
"""
36-
imh, imw = image_rgb.shape[2:]
56+
imh, imw = depth_map.shape[2:]
3757

3858
# convert the depth maps to point clouds using the grid ray sampler
3959
pts_3d = ray_bundle_to_ray_points(
@@ -43,6 +63,7 @@ def get_rgbd_point_cloud(
4363
n_pts_per_ray=1,
4464
min_depth=1.0,
4565
max_depth=1.0,
66+
unit_directions=euclidean,
4667
)(camera)._replace(lengths=depth_map[:, 0, ..., None])
4768
)
4869

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
11+
12+
from pytorch3d.renderer.cameras import PerspectiveCameras
13+
from tests.common_testing import TestCaseMixin
14+
15+
16+
class TestPointCloudUtils(TestCaseMixin, unittest.TestCase):
17+
def setUp(self):
18+
torch.manual_seed(42)
19+
20+
def test_unproject(self):
21+
H, W = 50, 100
22+
23+
# Random RGBD image with depth 3
24+
# (depth 0 = at the camera)
25+
# and purple in the upper right corner
26+
27+
image = torch.rand(4, H, W)
28+
depth = 3
29+
image[3] = depth
30+
image[1, H // 2 :, W // 2 :] *= 0.4
31+
32+
# two ways to define the same camera:
33+
# at the origin facing the positive z axis
34+
ndc_camera = PerspectiveCameras(focal_length=1.0)
35+
screen_camera = PerspectiveCameras(
36+
focal_length=H // 2,
37+
in_ndc=False,
38+
image_size=((H, W),),
39+
principal_point=((W / 2, H / 2),),
40+
)
41+
42+
for camera in (ndc_camera, screen_camera):
43+
# 1. z-depth
44+
cloud = get_rgbd_point_cloud(
45+
camera,
46+
image_rgb=image[:3][None],
47+
depth_map=image[3:][None],
48+
euclidean=False,
49+
)
50+
[points] = cloud.points_list()
51+
self.assertConstant(points[:, 2], depth) # constant depth
52+
extremes = depth * torch.tensor([W / H - 1 / H, 1 - 1 / H])
53+
self.assertClose(points[:, :2].min(0).values, -extremes)
54+
self.assertClose(points[:, :2].max(0).values, extremes)
55+
56+
# 2. euclidean
57+
cloud = get_rgbd_point_cloud(
58+
camera,
59+
image_rgb=image[:3][None],
60+
depth_map=image[3:][None],
61+
euclidean=True,
62+
)
63+
[points] = cloud.points_list()
64+
self.assertConstant(torch.norm(points, dim=1), depth, atol=1e-5)

0 commit comments

Comments
 (0)