Skip to content

Commit 12f20d7

Browse files
davnov134 and facebook-github-bot
authored and committed
Convert from Pytorch3D NDC coordinates to grid_sample coordinates.
Summary: Implements a utility function to convert 2D coordinates from PyTorch3D NDC space to the coordinate convention used by grid_sample. Reviewed By: shapovalov. Differential Revision: D33741394. fbshipit-source-id: 88981653356588fe646e6dea48fe7f7298738437
1 parent 47c0997 commit 12f20d7

File tree

3 files changed

+260
-3
lines changed

3 files changed

+260
-3
lines changed

pytorch3d/renderer/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,12 @@
7070
PulsarPointsRenderer,
7171
rasterize_points,
7272
)
73-
from .utils import TensorProperties, convert_to_tensors_and_broadcast
73+
from .utils import (
74+
TensorProperties,
75+
convert_to_tensors_and_broadcast,
76+
ndc_to_grid_sample_coords,
77+
ndc_grid_sample,
78+
)
7479

7580

7681
__all__ = [k for k in globals().keys() if not k.startswith("_")]

pytorch3d/renderer/utils.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import copy
99
import inspect
1010
import warnings
11-
from typing import Any, Optional, Union
11+
from typing import Any, Optional, Union, Tuple
1212

1313
import numpy as np
1414
import torch
@@ -350,3 +350,80 @@ def convert_to_tensors_and_broadcast(
350350
args_Nd.append(c.expand(*expand_sizes))
351351

352352
return args_Nd
353+
354+
355+
def ndc_grid_sample(
    input: torch.Tensor,
    grid_ndc: torch.Tensor,
    **grid_sample_kwargs,
) -> torch.Tensor:
    """
    Sample a tensor `input` of shape `(B, dim, H, W)` at the 2D locations
    `grid_ndc` of shape `(B, ..., 2)`, expressed in the PyTorch3D NDC
    convention, using `torch.nn.functional.grid_sample`.

    Args:
        input: Tensor of shape `(B, dim, H, W)` to sample from.
        grid_ndc: Tensor of shape `(B, ..., 2)` with the 2D sampling
            locations in PyTorch3D NDC coordinates. See [1] for a detailed
            description of the NDC convention.
        grid_sample_kwargs: Extra keyword arguments forwarded to
            `torch.nn.functional.grid_sample`; see its docstring for the
            available options.

    Returns:
        sampled_input: Tensor of shape `(B, dim, ...)` holding the samples
            of `input` at the locations `grid_ndc`.

    References:
        [1] https://pytorch3d.org/docs/cameras
    """

    batch_size, *sample_shape, coord_dim = grid_ndc.shape

    # Validate the inputs before reshaping anything.
    if batch_size != input.shape[0]:
        raise ValueError("'input' and 'grid_ndc' have to have the same batch size.")
    if input.ndim != 4:
        raise ValueError("'input' has to be a 4-dimensional Tensor.")
    if coord_dim != 2:
        raise ValueError("The last dimension of 'grid_ndc' has to be == 2.")

    # Flatten all intermediate sample dimensions into one so that
    # grid_sample sees a (B, N, 1, 2) grid, then convert the coordinates
    # from the NDC convention to the grid_sample convention.
    flat_grid_ndc = grid_ndc.reshape(batch_size, -1, 1, 2)
    flat_grid = ndc_to_grid_sample_coords(flat_grid_ndc, input.shape[2:])

    flat_samples = torch.nn.functional.grid_sample(
        input, flat_grid, **grid_sample_kwargs
    )

    # Restore the caller's original sample dimensions.
    sampled_input = flat_samples.reshape([batch_size, input.shape[1], *sample_shape])

    return sampled_input
403+
404+
def ndc_to_grid_sample_coords(
    xy_ndc: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.Tensor:
    """
    Convert 2D points from the PyTorch3D NDC convention to the coordinate
    convention of `torch.nn.functional.grid_sample`.

    Args:
        xy_ndc: Tensor of shape `(..., 2)` holding 2D points expressed in
            PyTorch3D NDC coordinates.
        image_size_hw: A tuple `(image_height, image_width)` giving the
            size of the image tensor that will be sampled.

    Returns:
        xy_grid_sample: Tensor of shape `(..., 2)` holding the same points
            expressed in `torch.nn.functional.grid_sample` coordinates.

    Raises:
        ValueError: If `image_size_hw` is not a 2-tuple of positive numbers.
    """
    if len(image_size_hw) != 2 or any(s <= 0 for s in image_size_hw):
        raise ValueError("'image_size_hw' has to be a 2-tuple of positive integers")
    height, width = image_size_hw
    # PyTorch3D NDC maps the shorter image side to [-1, 1] while
    # grid_sample maps both sides to [-1, 1]; the longer NDC axis therefore
    # has to be compressed by the aspect ratio.
    aspect = min(height, width) / max(height, width)
    scale_x, scale_y = (1.0, aspect) if height >= width else (aspect, 1.0)
    # Both NDC axes point opposite to grid_sample's axes, hence the sign flip.
    xy_grid_sample = torch.stack(
        [-xy_ndc[..., 0] * scale_x, -xy_ndc[..., 1] * scale_y],
        dim=-1,
    )
    return xy_grid_sample

tests/test_rendering_utils.py

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,20 @@
1010
import numpy as np
1111
import torch
1212
from common_testing import TestCaseMixin
13-
from pytorch3d.renderer.utils import TensorProperties
13+
from pytorch3d.ops import eyes
14+
from pytorch3d.renderer import (
15+
PerspectiveCameras,
16+
AlphaCompositor,
17+
PointsRenderer,
18+
PointsRasterizationSettings,
19+
PointsRasterizer,
20+
)
21+
from pytorch3d.renderer.utils import (
22+
TensorProperties,
23+
ndc_to_grid_sample_coords,
24+
ndc_grid_sample,
25+
)
26+
from pytorch3d.structures import Pointclouds
1427

1528

1629
# Example class for testing
@@ -96,3 +109,165 @@ def test_gather_props(self):
96109
# the input.
97110
self.assertClose(test_class_gathered.x[inds].mean(dim=0), x[i, ...])
98111
self.assertClose(test_class_gathered.y[inds].mean(dim=0), y[i, ...])
112+
113+
def test_ndc_grid_sample_rendering(self):
    """
    Render a colored point cloud with a nearest-neighbor point renderer,
    sample the rendered image with `ndc_grid_sample` at the 2D projections
    of the points, and assert that the sampled colors equal the original
    point colors.

    A nearest-neighbor compositor (no soft splatting) is used so that each
    rendered pixel carries the unmodified color of a single point.
    """

    # Build a regular grid of 3D points lying on the z = plane_z plane.
    pts_per_side = 10
    grid_extent = 0.9
    plane_z = 2.0
    render_size = [128, 128]
    radius = 0.015
    num_points = pts_per_side ** 2
    grid_xy = torch.stack(
        torch.meshgrid(
            [torch.linspace(-grid_extent, grid_extent, pts_per_side)] * 2,
            indexing="ij",
        ),
        dim=-1,
    )
    points = torch.cat(
        [grid_xy, plane_z * torch.ones_like(grid_xy[..., :1])], dim=-1
    )
    points = points.reshape(1, num_points, 3)

    # Assign a random color to every point.
    colors = torch.rand(1, num_points, 3)

    # Identity-pose rendering cameras.
    cameras = PerspectiveCameras(
        R=eyes(dim=3, N=1),
        device=points.device,
        T=torch.zeros(1, 3, dtype=torch.float32, device=points.device),
    )

    # Render the point cloud with nearest-neighbor color assignment.
    point_cloud = Pointclouds(points=points, features=colors)
    renderer = NearestNeighborPointsRenderer(
        rasterizer=PointsRasterizer(
            cameras=cameras,
            raster_settings=PointsRasterizationSettings(
                image_size=render_size,
                radius=radius,
                points_per_pixel=1,
            ),
        ),
        compositor=AlphaCompositor(),
    )
    rendered = renderer(point_cloud)

    # Sample the rendering at the 2D projections of the points.
    projected = cameras.transform_points(point_cloud.points_padded())[..., :2]
    sampled_colors = ndc_grid_sample(
        rendered,
        projected,
        mode="nearest",
        align_corners=False,
    ).permute(0, 2, 1)

    # The sampled colors have to match the original point colors.
    self.assertClose(colors, sampled_colors, atol=1e-4)
176+
177+
def test_ndc_to_grid_sample_coords(self):
    """
    Check `ndc_to_grid_sample_coords` against known input/output pairs,
    for square and non-square images, both batched and non-batched.
    """

    # Square images: the conversion is a plain sign flip.
    image_size_square = [100, 100]
    xy_pairs_square = torch.FloatTensor(
        [
            # 4 corners
            [[-1.0, -1.0], [1.0, 1.0]],
            [[1.0, 1.0], [-1.0, -1.0]],
            [[1.0, -1.0], [-1.0, 1.0]],
            [[1.0, 1.0], [-1.0, -1.0]],
            # center
            [[0.0, 0.0], [0.0, 0.0]],
        ]
    )

    # Non-batched conversion.
    for xy_in, xy_expected in xy_pairs_square:
        self.assertClose(
            ndc_to_grid_sample_coords(xy_in, image_size_square),
            xy_expected,
        )

    # Batched conversion.
    xy_in, xy_expected = xy_pairs_square[:, 0], xy_pairs_square[:, 1]
    self.assertClose(
        ndc_to_grid_sample_coords(xy_in, image_size_square),
        xy_expected,
    )

    # Non-square images: the longer NDC axis is additionally rescaled.
    image_size = [100, 200]
    xy_pairs = torch.FloatTensor(
        [
            # 4 corners
            [[-2.0, -1.0], [1.0, 1.0]],
            [[2.0, -1.0], [-1.0, 1.0]],
            [[-2.0, 1.0], [1.0, -1.0]],
            [[2.0, 1.0], [-1.0, -1.0]],
            # center
            [[0.0, 0.0], [0.0, 0.0]],
            # non-corner points
            [[4.0, 0.5], [-2.0, -0.5]],
            [[1.0, -0.5], [-0.5, 0.5]],
        ]
    )

    # Exercise both H > W and (after flipping axes) W > H.
    for flip_axes in [False, True]:
        size = list(reversed(image_size)) if flip_axes else image_size

        # Non-batched conversion.
        for xy_in, xy_expected in xy_pairs:
            self.assertClose(
                ndc_to_grid_sample_coords(
                    xy_in.flip(dims=(-1,)) if flip_axes else xy_in,
                    size,
                ),
                xy_expected.flip(dims=(-1,)) if flip_axes else xy_expected,
            )

        # Batched conversion.
        xy_in, xy_expected = xy_pairs[:, 0], xy_pairs[:, 1]
        self.assertClose(
            ndc_to_grid_sample_coords(
                xy_in.flip(dims=(-1,)) if flip_axes else xy_in,
                size,
            ),
            xy_expected.flip(dims=(-1,)) if flip_axes else xy_expected,
        )
254+
255+
256+
class NearestNeighborPointsRenderer(PointsRenderer):
    """
    A points renderer that performs a trivial nearest-neighbor color
    assignment: every compositing weight is overridden with one, so no
    distance-based soft splatting takes place.
    """

    def forward(self, point_clouds, **kwargs) -> torch.Tensor:
        fragments = self.rasterizer(point_clouds, **kwargs)
        # Replace the distance-based weights with constant ones so the
        # compositor copies each nearest point's color verbatim.
        trivial_weights = torch.ones_like(fragments.dists.permute(0, 3, 1, 2))
        images = self.compositor(
            fragments.idx.long().permute(0, 3, 1, 2),
            trivial_weights,
            point_clouds.features_packed().permute(1, 0),
            **kwargs,
        )
        return images

0 commit comments

Comments
 (0)