Skip to content

Commit edee25a

Browse files
Darijan Gudeljfacebook-github-bot
Darijan Gudelj
authored andcommitted
voxel grids with interpolation
Summary: Added voxel grid classes from TensoRF, both in their factorized (CP and VM) and full form. Reviewed By: bottler Differential Revision: D38465419 fbshipit-source-id: 8b306338af58dc50ef47a682616022a0512c0047
1 parent af799fa commit edee25a

File tree

3 files changed

+1112
-0
lines changed

3 files changed

+1112
-0
lines changed

pytorch3d/implicitron/models/implicit_function/utils.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import Callable, Optional
88

99
import torch
10+
11+
import torch.nn.functional as F
1012
from pytorch3d.common.compat import prod
1113
from pytorch3d.renderer.cameras import CamerasBase
1214

@@ -88,3 +90,98 @@ def create_embeddings_for_implicit_function(
8890
embeds = broadcast_global_code(embeds, global_code)
8991

9092
return embeds
93+
94+
95+
def interpolate_line(
96+
points: torch.Tensor,
97+
source: torch.Tensor,
98+
**kwargs,
99+
) -> torch.Tensor:
100+
"""
101+
Linearly interpolates values of source grids. The first dimension of points represents
102+
number of points and the second coordinate, for example ([[x0], [x1], ...]). The first
103+
dimension of argument source represents feature and ones after that the spatial
104+
dimension.
105+
106+
Arguments:
107+
points: shape (n_grids, n_points, 1),
108+
source: tensor of shape (n_grids, features, width),
109+
Returns:
110+
interpolated tensor of shape (n_grids, n_points, features)
111+
"""
112+
# To enable sampling of the source using the torch.functional.grid_sample
113+
# points need to have 2 coordinates.
114+
expansion = points.new_zeros(points.shape)
115+
points = torch.cat((points, expansion), dim=-1)
116+
117+
source = source[:, :, None, :]
118+
points = points[:, :, None, :]
119+
120+
out = F.grid_sample(
121+
grid=points,
122+
input=source,
123+
**kwargs,
124+
)
125+
return out[:, :, :, 0].permute(0, 2, 1)
126+
127+
128+
def interpolate_plane(
129+
points: torch.Tensor,
130+
source: torch.Tensor,
131+
**kwargs,
132+
) -> torch.Tensor:
133+
"""
134+
Bilinearly interpolates values of source grids. The first dimension of points represents
135+
number of points and the second coordinates, for example ([[x0, y0], [x1, y1], ...]).
136+
The first dimension of argument source represents feature and ones after that the
137+
spatial dimension.
138+
139+
Arguments:
140+
points: shape (n_grids, n_points, 2),
141+
source: tensor of shape (n_grids, features, width, height),
142+
Returns:
143+
interpolated tensor of shape (n_grids, n_points, features)
144+
"""
145+
# permuting because torch.nn.functional.grid_sample works with
146+
# (features, height, width) and not
147+
# (features, width, height)
148+
source = source.permute(0, 1, 3, 2)
149+
points = points[:, :, None, :]
150+
151+
out = F.grid_sample(
152+
grid=points,
153+
input=source,
154+
**kwargs,
155+
)
156+
return out[:, :, :, 0].permute(0, 2, 1)
157+
158+
159+
def interpolate_volume(
160+
points: torch.Tensor, source: torch.Tensor, **kwargs
161+
) -> torch.Tensor:
162+
"""
163+
Interpolates values of source grids. The first dimension of points represents
164+
number of points and the second coordinates, for example
165+
[[x0, y0, z0], [x1, y1, z1], ...]. The first dimension of a source represents features
166+
and ones after that the spatial dimension.
167+
168+
Arguments:
169+
points: shape (n_grids, n_points, 3),
170+
source: tensor of shape (n_grids, features, width, height, depth),
171+
Returns:
172+
interpolated tensor of shape (n_grids, n_points, features)
173+
"""
174+
if "mode" in kwargs and kwargs["mode"] == "trilinear":
175+
kwargs = kwargs.copy()
176+
kwargs["mode"] = "bilinear"
177+
# permuting because torch.nn.functional.grid_sample works with
178+
# (features, depth, height, width) and not (features, width, height, depth)
179+
source = source.permute(0, 1, 4, 3, 2)
180+
grid = points[:, :, None, None, :]
181+
182+
out = F.grid_sample(
183+
grid=grid,
184+
input=source,
185+
**kwargs,
186+
)
187+
return out[:, :, :, 0, 0].permute(0, 2, 1)

0 commit comments

Comments
 (0)