Skip to content

Commit cafa02d

Browse files
committed
Simplify LayerNorm2d implementation.
1 parent 1ab9030 commit cafa02d

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

torchvision/prototype/models/convnext.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,11 @@
2929

3030

3131
class LayerNorm2d(nn.LayerNorm):
32-
def __init__(self, *args: Any, **kwargs: Any) -> None:
33-
self.channels_last = kwargs.pop("channels_last", False)
34-
super().__init__(*args, **kwargs)
35-
3632
def forward(self, x: Tensor) -> Tensor:
3733
# TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298
38-
if not self.channels_last:
39-
x = x.permute(0, 2, 3, 1)
34+
x = x.permute(0, 2, 3, 1)
4035
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
41-
if not self.channels_last:
42-
x = x.permute(0, 3, 1, 2)
36+
x = x.permute(0, 3, 1, 2)
4337
return x
4438

4539

0 commit comments

Comments
 (0)