Skip to content

Commit a1f2ded

Browse files
bottlerfacebook-github-bot
authored andcommitted
voxel_grid_implicit_function scaffold fixes
Summary: Fix indexing of directions after filtering of points by scaffold. Reviewed By: shapovalov Differential Revision: D40853482 fbshipit-source-id: 9cfdb981e97cb82edcd27632c5848537ed2c6837
1 parent e4a3298 commit a1f2ded

File tree

2 files changed

+44
-30
lines changed

2 files changed

+44
-30
lines changed

pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import math
88
import warnings
99
from dataclasses import fields
10-
from typing import Callable, Dict, Optional, Tuple, Union
10+
from typing import Callable, Dict, Optional, Tuple
1111

1212
import torch
1313

@@ -118,11 +118,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
118118
the calculation.)
119119
scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
120120
voxel grid which stores scaffold
121-
scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than
121+
scaffold_empty_space_threshold (float): if `self._get_density` evaluates to less than
122122
this it will be considered as empty space and the scaffold at that point would
123123
evaluate as empty space.
124124
scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
125-
at the same time. To calculate the scaffold we need to query `get_density()` at
125+
at the same time. To calculate the scaffold we need to query `_get_density()` at
126126
every voxel, this calculation can be split into scaffold depth number of xy plane
127127
calculations if you want the lowest memory usage, one calculation to calculate the
128128
whole scaffold, but with higher memory footprint or any other number of planes.
@@ -242,14 +242,16 @@ def forward(
242242
points = ray_bundle_to_ray_points(ray_bundle)
243243
directions = ray_bundle.directions.reshape(-1, 3)
244244
input_shape = points.shape
245+
num_points_per_ray = input_shape[-2]
245246
points = points.view(-1, 3)
247+
non_empty_points = None
246248

247249
# ########## filter the points using the scaffold ########## #
248250
if self._scaffold_ready and self.scaffold_filter_points:
249-
# pyre-ignore[29]
250-
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
251+
with torch.no_grad():
252+
# pyre-ignore[29]
253+
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
251254
points = points[non_empty_points]
252-
directions = directions[non_empty_points]
253255
if len(points) == 0:
254256
warnings.warn(
255257
"The scaffold has filtered all the points."
@@ -262,8 +264,8 @@ def forward(
262264
)
263265

264266
# ########## calculate color and density ########## #
265-
rays_densities, rays_colors = self.calculate_density_and_color(
266-
points, directions, camera
267+
rays_densities, rays_colors = self._calculate_density_and_color(
268+
points, directions, camera, non_empty_points, num_points_per_ray
267269
)
268270

269271
if not (self._scaffold_ready and self.scaffold_filter_points):
@@ -283,9 +285,8 @@ def forward(
283285
rays_colors_combined = rays_colors.new_zeros(
284286
(math.prod(input_shape[:-1]), rays_colors.shape[-1])
285287
)
286-
# pyre-ignore[61]
288+
assert non_empty_points is not None
287289
rays_densities_combined[non_empty_points] = rays_densities
288-
# pyre-ignore[61]
289290
rays_colors_combined[non_empty_points] = rays_colors
290291

291292
return (
@@ -294,23 +295,28 @@ def forward(
294295
{},
295296
)
296297

297-
def calculate_density_and_color(
298+
def _calculate_density_and_color(
298299
self,
299300
points: torch.Tensor,
300301
directions: torch.Tensor,
301-
camera: Optional[CamerasBase] = None,
302+
camera: Optional[CamerasBase],
303+
non_empty_points: Optional[torch.Tensor],
304+
num_points_per_ray: int,
302305
) -> Tuple[torch.Tensor, torch.Tensor]:
303306
"""
304307
Calculates density and color at `points`.
305308
If enabled use cuda streams.
306309
307310
Args:
308311
points: points at which to calculate density and color.
309-
Tensor of shape [..., 3].
310-
directions: from which directions are the points viewed
311-
Tensor of shape [..., 3].
312+
Tensor of shape [n_points, 3].
313+
directions: from which directions are the points viewed.
314+
One per ray. Tensor of shape [n_rays, 3].
312315
camera: A camera model which will be used to transform the viewing
313316
directions
317+
non_empty_points: indices of points which weren't filtered out;
318+
used for expanding directions
319+
num_points_per_ray: number of points per ray, needed to expand directions.
314320
Returns:
315321
Tuple of color (tensor of shape [..., 3]) and density
316322
(tensor of shape [..., 1])
@@ -323,20 +329,24 @@ def calculate_density_and_color(
323329
with torch.cuda.stream(other_stream):
324330
# rays_densities.shape =
325331
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
326-
rays_densities = self.get_density(points)
332+
rays_densities = self._get_density(points)
327333

328334
# rays_colors.shape =
329335
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
330-
rays_colors = self.get_color(points, camera, directions)
336+
rays_colors = self._get_color(
337+
points, camera, directions, non_empty_points, num_points_per_ray
338+
)
331339

332340
current_stream.wait_stream(other_stream)
333341
else:
334342
# Same calculation as above, just serial.
335-
rays_densities = self.get_density(points)
336-
rays_colors = self.get_color(points, camera, directions)
343+
rays_densities = self._get_density(points)
344+
rays_colors = self._get_color(
345+
points, camera, directions, non_empty_points, num_points_per_ray
346+
)
337347
return rays_densities, rays_colors
338348

339-
def get_density(self, points: torch.Tensor) -> torch.Tensor:
349+
def _get_density(self, points: torch.Tensor) -> torch.Tensor:
340350
"""
341351
Calculates density at points:
342352
1) Evaluates the voxel grid on points
@@ -356,11 +366,13 @@ def get_density(self, points: torch.Tensor) -> torch.Tensor:
356366
# shape = [..., density_dim]
357367
return self.decoder_density(harmonic_embedding_density)
358368

359-
def get_color(
369+
def _get_color(
360370
self,
361371
points: torch.Tensor,
362372
camera: Optional[CamerasBase],
363373
directions: torch.Tensor,
374+
non_empty_points: Optional[torch.Tensor],
375+
num_points_per_ray: int,
364376
) -> torch.Tensor:
365377
"""
366378
Calculates color at points using the viewing direction:
@@ -376,6 +388,9 @@ def get_color(
376388
directions
377389
directions: A tensor of shape `(..., 3)`
378390
containing the direction vectors of sampling rays in world coords.
391+
non_empty_points: indices of points which weren't filtered out;
392+
used for expanding directions
393+
num_points_per_ray: number of points per ray, needed to expand directions.
379394
"""
380395
# ########## transform direction ########## #
381396
if self.xyz_ray_dir_in_camera_coords:
@@ -400,12 +415,11 @@ def get_color(
400415
rays_directions_normed
401416
)
402417

403-
n_rays = directions.shape[0]
404-
points_per_ray: int = points.shape[0] // n_rays
405-
406418
harmonic_embedding_dir = torch.repeat_interleave(
407-
harmonic_embedding_dir, points_per_ray, dim=0
419+
harmonic_embedding_dir, num_points_per_ray, dim=0
408420
)
421+
if non_empty_points is not None:
422+
harmonic_embedding_dir = harmonic_embedding_dir[non_empty_points]
409423

410424
# total color embedding is concatenation of the harmonic embedding of voxel grid
411425
# output and harmonic embedding of the normalized direction
@@ -505,7 +519,7 @@ def _get_scaffold(self, epoch: int) -> bool:
505519
)
506520
for k in range(0, points.shape[-1], chunk_size):
507521
points_in_planes = points[..., k : k + chunk_size]
508-
planes.append(self.get_density(points_in_planes)[..., 0])
522+
planes.append(self._get_density(points_in_planes)[..., 0])
509523

510524
density_cube = torch.cat(planes, dim=-1)
511525
density_cube = torch.nn.functional.max_pool3d(

tests/implicitron/test_voxel_grid_implicit_function.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def new_density(points):
8989
out.append(torch.tensor([[0.0]]))
9090
return torch.cat(out).view(*inshape[:-1], 1).to(device)
9191

92-
func.get_density = new_density
92+
func._get_density = new_density
9393
func._get_scaffold(0)
9494

9595
points = torch.tensor(
@@ -136,15 +136,15 @@ def new_density(points):
136136
assert torch.all(scaffold(points)), (scaffold(points), points.shape)
137137
return points.sum(dim=-1, keepdim=True)
138138

139-
def new_color(points, camera, directions):
139+
def new_color(points, camera, directions, non_empty_points, num_points_per_ray):
140140
# check if all passed points should be passed here
141141
assert torch.all(scaffold(points)) # , (scaffold(points), points)
142142
return points * 2
143143

144144
# check both computation paths that they contain only points
145145
# which are not in empty space
146-
func.get_density = new_density
147-
func.get_color = new_color
146+
func._get_density = new_density
147+
func._get_color = new_color
148148
func.voxel_grid_scaffold.forward = scaffold
149149
func._scaffold_ready = True
150150

0 commit comments

Comments
 (0)