Skip to content

Commit e7c609f

Browse files
Darijan Gudeljfacebook-github-bot
authored andcommitted
Decoding functions
Summary: Added replacable decoding functions which will be applied after the voxel grid to get color and density Reviewed By: bottler Differential Revision: D38829763 fbshipit-source-id: f21ce206c1c19548206ea2ce97d7ebea3de30a23
1 parent 24f5f4a commit e7c609f

File tree

3 files changed

+148
-47
lines changed

3 files changed

+148
-47
lines changed

pytorch3d/implicitron/models/implicit_function/decoding_functions.py

Lines changed: 112 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,66 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
"""
8+
This file contains
9+
- modules which get used by ImplicitFunction objects for decoding an embedding defined in
10+
space, e.g. to color or opacity.
11+
- DecoderFunctionBase and its subclasses, which wrap some of those modules, providing
12+
some such modules as an extension point which an ImplicitFunction object could use.
13+
"""
14+
715
import logging
816

917
from typing import Optional, Tuple
1018

1119
import torch
1220

21+
from pytorch3d.implicitron.tools.config import (
22+
Configurable,
23+
registry,
24+
ReplaceableBase,
25+
run_auto_creation,
26+
)
27+
1328
logger = logging.getLogger(__name__)
1429

1530

16-
class MLPWithInputSkips(torch.nn.Module):
31+
class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
32+
"""
33+
Decoding function is a torch.nn.Module which takes the embedding of a location in
34+
space and transforms it into the required quantity (for example density and color).
35+
"""
36+
37+
def __post_init__(self):
38+
super().__init__()
39+
40+
def forward(
41+
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
42+
) -> torch.Tensor:
43+
"""
44+
Args:
45+
features (torch.Tensor): tensor of shape (batch, ..., num_in_features)
46+
z: optional tensor to append to parts of the decoding function
47+
Returns:
48+
decoded_features (torch.Tensor) : tensor of
49+
shape (batch, ..., num_out_features)
50+
"""
51+
raise NotImplementedError()
52+
53+
54+
@registry.register
55+
class IdentityDecoder(DecoderFunctionBase):
56+
"""
57+
Decoding function which returns its input.
58+
"""
59+
60+
def forward(
61+
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
62+
) -> torch.Tensor:
63+
return features
64+
65+
66+
class MLPWithInputSkips(Configurable, torch.nn.Module):
1767
"""
1868
Implements the multi-layer perceptron architecture of the Neural Radiance Field.
1969
@@ -31,70 +81,68 @@ class MLPWithInputSkips(torch.nn.Module):
3181
and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
3282
NeRF: Representing Scenes as Neural Radiance Fields for View
3383
Synthesis, ECCV2020
84+
85+
Members:
86+
n_layers: The number of linear layers of the MLP.
87+
input_dim: The number of channels of the input tensor.
88+
output_dim: The number of channels of the output.
89+
skip_dim: The number of channels of the tensor `z` appended when
90+
evaluating the skip layers.
91+
hidden_dim: The number of hidden units of the MLP.
92+
input_skips: The list of layer indices at which we append the skip
93+
tensor `z`.
3494
"""
3595

36-
def _make_affine_layer(self, input_dim, hidden_dim):
37-
l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
38-
l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
39-
_xavier_init(l1)
40-
_xavier_init(l2)
41-
return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)
96+
n_layers: int = 8
97+
input_dim: int = 39
98+
output_dim: int = 256
99+
skip_dim: int = 39
100+
hidden_dim: int = 256
101+
input_skips: Tuple[int, ...] = (5,)
102+
skip_affine_trans: bool = False
103+
no_last_relu = False
42104

43-
def _apply_affine_layer(self, layer, x, z):
44-
mu_log_std = layer(z)
45-
mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
46-
std = torch.nn.functional.softplus(log_std)
47-
return (x - mu) * std
48-
49-
def __init__(
50-
self,
51-
n_layers: int = 8,
52-
input_dim: int = 39,
53-
output_dim: int = 256,
54-
skip_dim: int = 39,
55-
hidden_dim: int = 256,
56-
input_skips: Tuple[int, ...] = (5,),
57-
skip_affine_trans: bool = False,
58-
no_last_relu=False,
59-
):
60-
"""
61-
Args:
62-
n_layers: The number of linear layers of the MLP.
63-
input_dim: The number of channels of the input tensor.
64-
output_dim: The number of channels of the output.
65-
skip_dim: The number of channels of the tensor `z` appended when
66-
evaluating the skip layers.
67-
hidden_dim: The number of hidden units of the MLP.
68-
input_skips: The list of layer indices at which we append the skip
69-
tensor `z`.
70-
"""
105+
def __post_init__(self):
71106
super().__init__()
72107
layers = []
73108
skip_affine_layers = []
74-
for layeri in range(n_layers):
75-
dimin = hidden_dim if layeri > 0 else input_dim
76-
dimout = hidden_dim if layeri + 1 < n_layers else output_dim
109+
for layeri in range(self.n_layers):
110+
dimin = self.hidden_dim if layeri > 0 else self.input_dim
111+
dimout = self.hidden_dim if layeri + 1 < self.n_layers else self.output_dim
77112

78-
if layeri > 0 and layeri in input_skips:
79-
if skip_affine_trans:
113+
if layeri > 0 and layeri in self.input_skips:
114+
if self.skip_affine_trans:
80115
skip_affine_layers.append(
81-
self._make_affine_layer(skip_dim, hidden_dim)
116+
self._make_affine_layer(self.skip_dim, self.hidden_dim)
82117
)
83118
else:
84-
dimin = hidden_dim + skip_dim
119+
dimin = self.hidden_dim + self.skip_dim
85120

86121
linear = torch.nn.Linear(dimin, dimout)
87122
_xavier_init(linear)
88123
layers.append(
89124
torch.nn.Sequential(linear, torch.nn.ReLU(True))
90-
if not no_last_relu or layeri + 1 < n_layers
125+
if not self.no_last_relu or layeri + 1 < self.n_layers
91126
else linear
92127
)
93128
self.mlp = torch.nn.ModuleList(layers)
94-
if skip_affine_trans:
129+
if self.skip_affine_trans:
95130
self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
96-
self._input_skips = set(input_skips)
97-
self._skip_affine_trans = skip_affine_trans
131+
self._input_skips = set(self.input_skips)
132+
self._skip_affine_trans = self.skip_affine_trans
133+
134+
def _make_affine_layer(self, input_dim, hidden_dim):
135+
l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
136+
l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
137+
_xavier_init(l1)
138+
_xavier_init(l2)
139+
return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)
140+
141+
def _apply_affine_layer(self, layer, x, z):
142+
mu_log_std = layer(z)
143+
mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
144+
std = torch.nn.functional.softplus(log_std)
145+
return (x - mu) * std
98146

99147
def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
100148
"""
@@ -121,6 +169,24 @@ def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
121169
return y
122170

123171

172+
@registry.register
173+
class MLPDecoder(DecoderFunctionBase):
174+
"""
175+
Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
176+
"""
177+
178+
network: MLPWithInputSkips
179+
180+
def __post_init__(self):
181+
super().__post_init__()
182+
run_auto_creation(self)
183+
184+
def forward(
185+
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
186+
) -> torch.Tensor:
187+
return self.network(features, z)
188+
189+
124190
class TransformerWithInputSkips(torch.nn.Module):
125191
def __init__(
126192
self,

pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
12-
from pytorch3d.implicitron.tools.config import registry
12+
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
1313
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
1414
from pytorch3d.renderer.cameras import CamerasBase
1515
from pytorch3d.renderer.implicit import HarmonicEmbedding
@@ -214,6 +214,7 @@ class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
214214
append_xyz: Tuple[int, ...] = (5,)
215215

216216
def _construct_xyz_encoder(self, input_dim: int):
217+
expand_args_fields(MLPWithInputSkips)
217218
return MLPWithInputSkips(
218219
self.n_layers_xyz,
219220
input_dim,
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import unittest
9+
10+
import torch
11+
12+
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
13+
IdentityDecoder,
14+
MLPDecoder,
15+
)
16+
from pytorch3d.implicitron.tools.config import expand_args_fields
17+
18+
from tests.common_testing import TestCaseMixin
19+
20+
21+
class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
22+
def setUp(self):
23+
torch.manual_seed(42)
24+
expand_args_fields(IdentityDecoder)
25+
expand_args_fields(MLPDecoder)
26+
27+
def test_identity_function(self, in_shape=(33, 4, 1), n_tests=2):
28+
"""
29+
Test that identity function returns its input
30+
"""
31+
func = IdentityDecoder()
32+
for _ in range(n_tests):
33+
_in = torch.randn(in_shape)
34+
assert torch.allclose(func(_in), _in)

0 commit comments

Comments
 (0)