File tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change 3
3
import torch
4
4
from torch import nn , Tensor
5
5
6
+ from .misc import FrozenBatchNorm2d
7
+
6
8
7
9
def _cat (tensors : List [Tensor ], dim : int = 0 ) -> Tensor :
8
10
"""
@@ -43,7 +45,13 @@ def split_normalization_params(
43
45
) -> Tuple [List [Tensor ], List [Tensor ]]:
44
46
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
45
47
if not norm_classes :
46
- norm_classes = [nn .modules .batchnorm ._BatchNorm , nn .LayerNorm , nn .GroupNorm ]
48
+ norm_classes = [
49
+ nn .modules .batchnorm ._BatchNorm ,
50
+ nn .LayerNorm ,
51
+ nn .GroupNorm ,
52
+ nn .modules .instancenorm ._InstanceNorm ,
53
+ nn .LocalResponseNorm ,
54
+ ]
47
55
48
56
for t in norm_classes :
49
57
if not issubclass (t , nn .Module ):
You can’t perform that action at this time.
0 commit comments