Commit 96931c1

Use custom LayerNorm layer.
1 parent 73f3e1a commit 96931c1

2 files changed: +27 -17 lines changed

torchvision/models/vision_transformer.py
Lines changed: 1 addition & 2 deletions

@@ -146,11 +146,10 @@ def __init__(
                 norm_layer,
             )
         self.layers = nn.Sequential(layers)
-        self.ln = norm_layer(hidden_dim)
 
     def forward(self, input: torch.Tensor):
         torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
-        return self.ln(self.layers(self.dropout(input)))
+        return self.layers(self.dropout(input))
 
 
 class VisionTransformer(nn.Module):

torchvision/ops/feature_pyramid_network.py
Lines changed: 26 additions & 15 deletions

@@ -203,6 +203,29 @@ def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
 
         return out
 
+# TODO: Remove this? The pytorch version MUST have channels as last dimension..
+class LayerNorm(torch.nn.Module):
+    """
+    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
+    variance normalization over the channel dimension for inputs that have shape
+    (batch_size, channels, height, width).
+    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
+    """
+
+    def __init__(self, normalized_shape, eps=1e-6):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
+        self.bias = torch.nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x):
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
+
 
 class SimpleFeaturePyramidNetwork(nn.Module):
     """
@@ -220,7 +243,7 @@ class SimpleFeaturePyramidNetwork(nn.Module):
             be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names
-        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: LayerNorm
 
     Examples::
 
@@ -244,7 +267,7 @@ def __init__(
         in_channels: int,
         out_channels: int,
         extra_blocks: Optional[ExtraFPNBlock] = None,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = LayerNorm,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -257,26 +280,14 @@ def __init__(
 
             current_in_channels = in_channels
             if block_index == 0:
-                # This class and its uses is required because of:
-                # https://github.com/pytorch/pytorch/issues/71465
-                class Permute(nn.Module):
-                    def __init__(self, *dims):
-                        super().__init__()
-                        self.dims = dims
-
-                    def forward(self, x: Tensor) -> Tensor:
-                        return x.permute(self.dims)
-
                 layers.extend([
                     nn.ConvTranspose2d(
                         in_channels,
                         in_channels // 2,
                         kernel_size=2,
                         stride=2,
                     ),
-                    Permute(0, 2, 3, 1),
-                    nn.LayerNorm(in_channels // 2),
-                    Permute(0, 3, 1, 2),
+                    LayerNorm(in_channels // 2),
                     nn.GELU(),
                     nn.ConvTranspose2d(
                         in_channels // 2,
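
For reference, the new module is intended as a drop-in replacement for the Permute → nn.LayerNorm → Permute sequence it removes: it normalizes each spatial position over the channel dimension of an (batch_size, channels, height, width) tensor. The snippet below is a minimal standalone sketch, not part of the commit; it re-implements the same forward math inline and checks it against the permute-based path, assuming eps=1e-6 on both sides.

import torch

def channels_first_layer_norm(x, weight, bias, eps=1e-6):
    # Same math as the LayerNorm added in this commit: per-position mean and
    # variance over the channel dimension (dim 1) of an (N, C, H, W) tensor.
    u = x.mean(1, keepdim=True)
    s = (x - u).pow(2).mean(1, keepdim=True)
    x = (x - u) / torch.sqrt(s + eps)
    return weight[:, None, None] * x + bias[:, None, None]

torch.manual_seed(0)
x = torch.randn(2, 8, 4, 4)                    # (batch_size, channels, height, width)
weight, bias = torch.ones(8), torch.zeros(8)

out_custom = channels_first_layer_norm(x, weight, bias)

# The code path this commit deletes: permute to channels-last, apply
# nn.LayerNorm over the last (channel) dimension, permute back.
ln = torch.nn.LayerNorm(8, eps=1e-6)
out_reference = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(torch.allclose(out_custom, out_reference, atol=1e-6))  # expected: True

The custom class is needed because nn.LayerNorm only normalizes over trailing dimensions, which is why the deleted code had to permute around it (the workaround for pytorch/pytorch#71465 cited in the removed comment).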
