 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
7+ """
8+ This file contains
9+ - modules which get used by ImplicitFunction objects for decoding an embedding defined in
10+ space, e.g. to color or opacity.
11+ - DecoderFunctionBase and its subclasses, which wrap some of those modules, providing
12+ some such modules as an extension point which an ImplicitFunction object could use.
13+ """
14+
 import logging
 
 from typing import Optional, Tuple
 
 import torch
 
+from pytorch3d.implicitron.tools.config import (
+    Configurable,
+    registry,
+    ReplaceableBase,
+    run_auto_creation,
+)
+
 logger = logging.getLogger(__name__)
 
 
-class MLPWithInputSkips(torch.nn.Module):
+class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
+    """
+    A decoding function is a torch.nn.Module which takes the embedding of a location in
+    space and transforms it into the required quantity (for example density and color).
+    """
+
+    def __post_init__(self):
+        super().__init__()
+
+    def forward(
+        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            features (torch.Tensor): tensor of shape (batch, ..., num_in_features)
+            z: optional tensor to append to parts of the decoding function
+        Returns:
+            decoded_features (torch.Tensor): tensor of
+                shape (batch, ..., num_out_features)
+        """
+        raise NotImplementedError()
+
+
+@registry.register
+class IdentityDecoder(DecoderFunctionBase):
+    """
+    Decoding function which returns its input.
+    """
+
+    def forward(
+        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        return features
+
+
+class MLPWithInputSkips(Configurable, torch.nn.Module):
     """
     Implements the multi-layer perceptron architecture of the Neural Radiance Field.
 
@@ -31,70 +81,68 @@ class MLPWithInputSkips(torch.nn.Module):
             and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
             NeRF: Representing Scenes as Neural Radiance Fields for View
             Synthesis, ECCV2020
+
+    Members:
+        n_layers: The number of linear layers of the MLP.
+        input_dim: The number of channels of the input tensor.
+        output_dim: The number of channels of the output.
+        skip_dim: The number of channels of the tensor `z` appended when
+            evaluating the skip layers.
+        hidden_dim: The number of hidden units of the MLP.
+        input_skips: The list of layer indices at which we append the skip
+            tensor `z`.
     """
 
-    def _make_affine_layer(self, input_dim, hidden_dim):
-        l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
-        l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
-        _xavier_init(l1)
-        _xavier_init(l2)
-        return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)
+    n_layers: int = 8
+    input_dim: int = 39
+    output_dim: int = 256
+    skip_dim: int = 39
+    hidden_dim: int = 256
+    input_skips: Tuple[int, ...] = (5,)
+    skip_affine_trans: bool = False
+    no_last_relu: bool = False
 
-    def _apply_affine_layer(self, layer, x, z):
-        mu_log_std = layer(z)
-        mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
-        std = torch.nn.functional.softplus(log_std)
-        return (x - mu) * std
-
-    def __init__(
-        self,
-        n_layers: int = 8,
-        input_dim: int = 39,
-        output_dim: int = 256,
-        skip_dim: int = 39,
-        hidden_dim: int = 256,
-        input_skips: Tuple[int, ...] = (5,),
-        skip_affine_trans: bool = False,
-        no_last_relu=False,
-    ):
-        """
-        Args:
-            n_layers: The number of linear layers of the MLP.
-            input_dim: The number of channels of the input tensor.
-            output_dim: The number of channels of the output.
-            skip_dim: The number of channels of the tensor `z` appended when
-                evaluating the skip layers.
-            hidden_dim: The number of hidden units of the MLP.
-            input_skips: The list of layer indices at which we append the skip
-                tensor `z`.
-        """
+    def __post_init__(self):
         super().__init__()
         layers = []
         skip_affine_layers = []
-        for layeri in range(n_layers):
-            dimin = hidden_dim if layeri > 0 else input_dim
-            dimout = hidden_dim if layeri + 1 < n_layers else output_dim
+        for layeri in range(self.n_layers):
+            dimin = self.hidden_dim if layeri > 0 else self.input_dim
+            dimout = self.hidden_dim if layeri + 1 < self.n_layers else self.output_dim
 
-            if layeri > 0 and layeri in input_skips:
-                if skip_affine_trans:
+            if layeri > 0 and layeri in self.input_skips:
+                if self.skip_affine_trans:
                     skip_affine_layers.append(
-                        self._make_affine_layer(skip_dim, hidden_dim)
+                        self._make_affine_layer(self.skip_dim, self.hidden_dim)
                     )
                 else:
-                    dimin = hidden_dim + skip_dim
+                    dimin = self.hidden_dim + self.skip_dim
 
             linear = torch.nn.Linear(dimin, dimout)
             _xavier_init(linear)
             layers.append(
                 torch.nn.Sequential(linear, torch.nn.ReLU(True))
-                if not no_last_relu or layeri + 1 < n_layers
+                if not self.no_last_relu or layeri + 1 < self.n_layers
                 else linear
             )
         self.mlp = torch.nn.ModuleList(layers)
-        if skip_affine_trans:
+        if self.skip_affine_trans:
             self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
-        self._input_skips = set(input_skips)
-        self._skip_affine_trans = skip_affine_trans
+        self._input_skips = set(self.input_skips)
+        self._skip_affine_trans = self.skip_affine_trans
+
+    def _make_affine_layer(self, input_dim, hidden_dim):
+        l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
+        l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
+        _xavier_init(l1)
+        _xavier_init(l2)
+        return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)
+
+    def _apply_affine_layer(self, layer, x, z):
+        mu_log_std = layer(z)
+        mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
+        std = torch.nn.functional.softplus(log_std)
+        return (x - mu) * std
 
     def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
         """
@@ -121,6 +169,24 @@ def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
         return y
 
 
+@registry.register
+class MLPDecoder(DecoderFunctionBase):
+    """
+    Decoding function which uses `MLPWithInputSkips` to convert the embedding to output.
+    """
+
+    network: MLPWithInputSkips
+
+    def __post_init__(self):
+        super().__post_init__()
+        run_auto_creation(self)
+
+    def forward(
+        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        return self.network(features, z)
+
+
 class TransformerWithInputSkips(torch.nn.Module):
     def __init__(
         self,
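
The diff above turns `MLPWithInputSkips` into a `Configurable`, so its hyperparameters become class-level fields consumed by `__post_init__` rather than `__init__` arguments. The following is a minimal usage sketch, not part of the commit; it assumes the classes live in `pytorch3d.implicitron.models.implicit_function.decoder_functions` and that, as with other `Configurable` classes in implicitron, `expand_args_fields` from `pytorch3d.implicitron.tools.config` must be applied before direct instantiation. Both assumptions should be checked against the repo.

# Usage sketch (illustrative, not from the commit); the import path and the
# need for expand_args_fields are assumptions about the implicitron config system.
import torch

from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.models.implicit_function.decoder_functions import (
    MLPWithInputSkips,
)

# Materialize the Configurable class into an instantiable dataclass.
expand_args_fields(MLPWithInputSkips)

mlp = MLPWithInputSkips(
    n_layers=8,
    input_dim=39,  # e.g. a 39-channel harmonic embedding of 3D points
    output_dim=256,
    skip_dim=39,
    hidden_dim=256,
    input_skips=(5,),  # re-append the skip tensor before layer 5
)

embeds = torch.randn(2, 1024, 39)  # (batch, ..., input_dim)
features = mlp(embeds, z=embeds)   # `z` must have skip_dim channels
print(features.shape)              # torch.Size([2, 1024, 256])

Moving the arguments into Configurable fields means the same hyperparameters can also be filled in from an implicitron config file instead of being hard-coded at the call site.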
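Because `DecoderFunctionBase` extends `ReplaceableBase` and `IdentityDecoder`/`MLPDecoder` are added to the registry, a parent module can expose the decoder as a swappable member chosen by name. The sketch below follows the implicitron convention of a `<member>_class_type` field plus `run_auto_creation`; the parent class `ToyImplicitFunction` and its field names are hypothetical, not part of this commit, and the exact member-creation convention should be verified against the config system.

# Hypothetical parent module showing how an ImplicitFunction-like Configurable
# might pick a registered DecoderFunctionBase implementation by name.
import torch

from pytorch3d.implicitron.tools.config import (
    Configurable,
    expand_args_fields,
    run_auto_creation,
)
from pytorch3d.implicitron.models.implicit_function.decoder_functions import (
    DecoderFunctionBase,
)


class ToyImplicitFunction(Configurable, torch.nn.Module):
    # `decoder` is built by run_auto_creation from the registered subclass
    # named in `decoder_class_type` ("MLPDecoder" or "IdentityDecoder").
    decoder: DecoderFunctionBase
    decoder_class_type: str = "MLPDecoder"

    def __post_init__(self):
        super().__init__()
        run_auto_creation(self)

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        return self.decoder(embedding)


expand_args_fields(ToyImplicitFunction)
fn = ToyImplicitFunction()
out = fn(torch.randn(2, 1024, 39))  # decoded by the default MLPDecoder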