7
7
import math
8
8
import warnings
9
9
from dataclasses import fields
10
- from typing import Callable , Dict , Optional , Tuple , Union
10
+ from typing import Callable , Dict , Optional , Tuple
11
11
12
12
import torch
13
13
@@ -118,11 +118,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
118
118
the calculation.)
119
119
scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
120
120
voxel grid which stores scaffold
121
- scaffold_empty_space_threshold (float): if `self.get_density ` evaluates to less than
121
+ scaffold_empty_space_threshold (float): if `self._get_density ` evaluates to less than
122
122
this it will be considered as empty space and the scaffold at that point would
123
123
evaluate as empty space.
124
124
scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
125
- at the same time. To calculate the scaffold we need to query `get_density ()` at
125
+ at the same time. To calculate the scaffold we need to query `_get_density ()` at
126
126
every voxel, this calculation can be split into scaffold depth number of xy plane
127
127
calculations if you want the lowest memory usage, one calculation to calculate the
128
128
whole scaffold, but with higher memory footprint or any other number of planes.
@@ -242,14 +242,16 @@ def forward(
242
242
points = ray_bundle_to_ray_points (ray_bundle )
243
243
directions = ray_bundle .directions .reshape (- 1 , 3 )
244
244
input_shape = points .shape
245
+ num_points_per_ray = input_shape [- 2 ]
245
246
points = points .view (- 1 , 3 )
247
+ non_empty_points = None
246
248
247
249
# ########## filter the points using the scaffold ########## #
248
250
if self ._scaffold_ready and self .scaffold_filter_points :
249
- # pyre-ignore[29]
250
- non_empty_points = self .voxel_grid_scaffold (points )[..., 0 ] > 0
251
+ with torch .no_grad ():
252
+ # pyre-ignore[29]
253
+ non_empty_points = self .voxel_grid_scaffold (points )[..., 0 ] > 0
251
254
points = points [non_empty_points ]
252
- directions = directions [non_empty_points ]
253
255
if len (points ) == 0 :
254
256
warnings .warn (
255
257
"The scaffold has filtered all the points."
@@ -262,8 +264,8 @@ def forward(
262
264
)
263
265
264
266
# ########## calculate color and density ########## #
265
- rays_densities , rays_colors = self .calculate_density_and_color (
266
- points , directions , camera
267
+ rays_densities , rays_colors = self ._calculate_density_and_color (
268
+ points , directions , camera , non_empty_points , num_points_per_ray
267
269
)
268
270
269
271
if not (self ._scaffold_ready and self .scaffold_filter_points ):
@@ -283,9 +285,8 @@ def forward(
283
285
rays_colors_combined = rays_colors .new_zeros (
284
286
(math .prod (input_shape [:- 1 ]), rays_colors .shape [- 1 ])
285
287
)
286
- # pyre-ignore[61]
288
+ assert non_empty_points is not None
287
289
rays_densities_combined [non_empty_points ] = rays_densities
288
- # pyre-ignore[61]
289
290
rays_colors_combined [non_empty_points ] = rays_colors
290
291
291
292
return (
@@ -294,23 +295,28 @@ def forward(
294
295
{},
295
296
)
296
297
297
- def calculate_density_and_color (
298
+ def _calculate_density_and_color (
298
299
self ,
299
300
points : torch .Tensor ,
300
301
directions : torch .Tensor ,
301
- camera : Optional [CamerasBase ] = None ,
302
+ camera : Optional [CamerasBase ],
303
+ non_empty_points : Optional [torch .Tensor ],
304
+ num_points_per_ray : int ,
302
305
) -> Tuple [torch .Tensor , torch .Tensor ]:
303
306
"""
304
307
Calculates density and color at `points`.
305
308
If enabled use cuda streams.
306
309
307
310
Args:
308
311
points: points at which to calculate density and color.
309
- Tensor of shape [... , 3].
310
- directions: from which directions are the points viewed
311
- Tensor of shape [... , 3].
312
+ Tensor of shape [n_points , 3].
313
+ directions: from which directions are the points viewed.
314
+ One per ray. Tensor of shape [n_rays , 3].
312
315
camera: A camera model which will be used to transform the viewing
313
316
directions
317
+ non_empty_points: indices of points which weren't filtered out;
318
+ used for expanding directions
319
+ num_points_per_ray: number of points per ray, needed to expand directions.
314
320
Returns:
315
321
Tuple of color (tensor of shape [..., 3]) and density
316
322
(tensor of shape [..., 1])
@@ -323,20 +329,24 @@ def calculate_density_and_color(
323
329
with torch .cuda .stream (other_stream ):
324
330
# rays_densities.shape =
325
331
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
326
- rays_densities = self .get_density (points )
332
+ rays_densities = self ._get_density (points )
327
333
328
334
# rays_colors.shape =
329
335
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
330
- rays_colors = self .get_color (points , camera , directions )
336
+ rays_colors = self ._get_color (
337
+ points , camera , directions , non_empty_points , num_points_per_ray
338
+ )
331
339
332
340
current_stream .wait_stream (other_stream )
333
341
else :
334
342
# Same calculation as above, just serial.
335
- rays_densities = self .get_density (points )
336
- rays_colors = self .get_color (points , camera , directions )
343
+ rays_densities = self ._get_density (points )
344
+ rays_colors = self ._get_color (
345
+ points , camera , directions , non_empty_points , num_points_per_ray
346
+ )
337
347
return rays_densities , rays_colors
338
348
339
- def get_density (self , points : torch .Tensor ) -> torch .Tensor :
349
+ def _get_density (self , points : torch .Tensor ) -> torch .Tensor :
340
350
"""
341
351
Calculates density at points:
342
352
1) Evaluates the voxel grid on points
@@ -356,11 +366,13 @@ def get_density(self, points: torch.Tensor) -> torch.Tensor:
356
366
# shape = [..., density_dim]
357
367
return self .decoder_density (harmonic_embedding_density )
358
368
359
- def get_color (
369
+ def _get_color (
360
370
self ,
361
371
points : torch .Tensor ,
362
372
camera : Optional [CamerasBase ],
363
373
directions : torch .Tensor ,
374
+ non_empty_points : Optional [torch .Tensor ],
375
+ num_points_per_ray : int ,
364
376
) -> torch .Tensor :
365
377
"""
366
378
Calculates color at points using the viewing direction:
@@ -376,6 +388,9 @@ def get_color(
376
388
directions
377
389
directions: A tensor of shape `(..., 3)`
378
390
containing the direction vectors of sampling rays in world coords.
391
+ non_empty_points: indices of points which weren't filtered out;
392
+ used for expanding directions
393
+ num_points_per_ray: number of points per ray, needed to expand directions.
379
394
"""
380
395
# ########## transform direction ########## #
381
396
if self .xyz_ray_dir_in_camera_coords :
@@ -400,12 +415,11 @@ def get_color(
400
415
rays_directions_normed
401
416
)
402
417
403
- n_rays = directions .shape [0 ]
404
- points_per_ray : int = points .shape [0 ] // n_rays
405
-
406
418
harmonic_embedding_dir = torch .repeat_interleave (
407
- harmonic_embedding_dir , points_per_ray , dim = 0
419
+ harmonic_embedding_dir , num_points_per_ray , dim = 0
408
420
)
421
+ if non_empty_points is not None :
422
+ harmonic_embedding_dir = harmonic_embedding_dir [non_empty_points ]
409
423
410
424
# total color embedding is concatenation of the harmonic embedding of voxel grid
411
425
# output and harmonic embedding of the normalized direction
@@ -505,7 +519,7 @@ def _get_scaffold(self, epoch: int) -> bool:
505
519
)
506
520
for k in range (0 , points .shape [- 1 ], chunk_size ):
507
521
points_in_planes = points [..., k : k + chunk_size ]
508
- planes .append (self .get_density (points_in_planes )[..., 0 ])
522
+ planes .append (self ._get_density (points_in_planes )[..., 0 ])
509
523
510
524
density_cube = torch .cat (planes , dim = - 1 )
511
525
density_cube = torch .nn .functional .max_pool3d (
0 commit comments