@@ -203,6 +203,29 @@ def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
         return out
 
 
+# TODO: Remove this? The pytorch version MUST have channels as last dimension.
+class LayerNorm(torch.nn.Module):
+    """
+    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
+    variance normalization over the channel dimension for inputs that have shape
+    (batch_size, channels, height, width).
+    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
+    """
+
+    def __init__(self, normalized_shape, eps=1e-6):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
+        self.bias = torch.nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x):
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
+
+
 class SimpleFeaturePyramidNetwork(nn.Module):
     """
@@ -220,7 +243,7 @@ class SimpleFeaturePyramidNetwork(nn.Module):
             be performed. It is expected to take the fpn features, the original
             features and the names of the original features as input, and returns
             a new list of feature maps and their corresponding names
-        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: LayerNorm
 
     Examples::
 
@@ -244,7 +267,7 @@ def __init__(
         in_channels: int,
         out_channels: int,
         extra_blocks: Optional[ExtraFPNBlock] = None,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = LayerNorm,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -257,26 +280,14 @@ def __init__(
 
             current_in_channels = in_channels
             if block_index == 0:
-                # This class and its uses is required because of:
-                # https://github.com/pytorch/pytorch/issues/71465
-                class Permute(nn.Module):
-                    def __init__(self, *dims):
-                        super().__init__()
-                        self.dims = dims
-
-                    def forward(self, x: Tensor) -> Tensor:
-                        return x.permute(self.dims)
-
                 layers.extend([
                     nn.ConvTranspose2d(
                         in_channels,
                         in_channels // 2,
                         kernel_size=2,
                         stride=2,
                     ),
-                    Permute(0, 2, 3, 1),
-                    nn.LayerNorm(in_channels // 2),
-                    Permute(0, 3, 1, 2),
+                    LayerNorm(in_channels // 2),
                     nn.GELU(),
                     nn.ConvTranspose2d(
                         in_channels // 2,