Skip to content

Commit 2231c4a

Browse files
authored
Merge branch 'main' into multi-crop
2 parents da1fa8f + 350a3e8 commit 2231c4a

File tree

10 files changed: 53 additions and 50 deletions.

10 files changed: 53 additions and 50 deletions.

torchvision/models/detection/faster_rcnn.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -383,15 +383,15 @@ def fasterrcnn_resnet50_fpn(
383383
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
384384
passed (the default) this value is set to 3.
385385
"""
386-
trainable_backbone_layers = _validate_trainable_layers(
387-
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
388-
)
386+
is_trained = pretrained or pretrained_backbone
387+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
388+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
389389

390390
if pretrained:
391391
# no need to download the backbone if pretrained is set
392392
pretrained_backbone = False
393393

394-
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
394+
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
395395
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
396396
model = FasterRCNN(backbone, num_classes, **kwargs)
397397
if pretrained:
@@ -410,16 +410,14 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
410410
trainable_backbone_layers=None,
411411
**kwargs,
412412
):
413-
trainable_backbone_layers = _validate_trainable_layers(
414-
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3
415-
)
413+
is_trained = pretrained or pretrained_backbone
414+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
415+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
416416

417417
if pretrained:
418418
pretrained_backbone = False
419419

420-
backbone = mobilenet_v3_large(
421-
pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d
422-
)
420+
backbone = mobilenet_v3_large(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
423421
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
424422

425423
anchor_sizes = (

torchvision/models/detection/fcos.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -686,15 +686,15 @@ def fcos_resnet50_fpn(
686686
from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
687687
trainable. If ``None`` is passed (the default) this value is set to 3. Default: None
688688
"""
689-
trainable_backbone_layers = _validate_trainable_layers(
690-
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
691-
)
689+
is_trained = pretrained or pretrained_backbone
690+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
691+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
692692

693693
if pretrained:
694694
# no need to download the backbone if pretrained is set
695695
pretrained_backbone = False
696696

697-
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
697+
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
698698
backbone = _resnet_fpn_extractor(
699699
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
700700
)

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -365,15 +365,15 @@ def keypointrcnn_resnet50_fpn(
365365
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
366366
passed (the default) this value is set to 3.
367367
"""
368-
trainable_backbone_layers = _validate_trainable_layers(
369-
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
370-
)
368+
is_trained = pretrained or pretrained_backbone
369+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
370+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
371371

372372
if pretrained:
373373
# no need to download the backbone if pretrained is set
374374
pretrained_backbone = False
375375

376-
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
376+
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
377377
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
378378
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
379379
if pretrained:

torchvision/models/detection/mask_rcnn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,15 @@ def maskrcnn_resnet50_fpn(
360360
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
361361
passed (the default) this value is set to 3.
362362
"""
363-
trainable_backbone_layers = _validate_trainable_layers(
364-
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
365-
)
363+
is_trained = pretrained or pretrained_backbone
364+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
365+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
366366

367367
if pretrained:
368368
# no need to download the backbone if pretrained is set
369369
pretrained_backbone = False
370370

371-
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
371+
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
372372
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
373373
model = MaskRCNN(backbone, num_classes, **kwargs)
374374
if pretrained:

torchvision/models/detection/retinanet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -626,15 +626,15 @@ def retinanet_resnet50_fpn(
626626
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
627627
passed (the default) this value is set to 3.
628628
"""
629-
trainable_backbone_layers = _validate_trainable_layers(
630-
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
631-
)
629+
is_trained = pretrained or pretrained_backbone
630+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
631+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
632632

633633
if pretrained:
634634
# no need to download the backbone if pretrained is set
635635
pretrained_backbone = False
636636

637-
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
637+
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
638638
# skip P2 because it generates too many anchors (according to their paper)
639639
backbone = _resnet_fpn_extractor(
640640
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)

torchvision/prototype/models/detection/faster_rcnn.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Optional, Union
22

3+
from torch import nn
34
from torchvision.prototype.transforms import CocoEval
45
from torchvision.transforms.functional import InterpolationMode
56

@@ -103,11 +104,11 @@ def fasterrcnn_resnet50_fpn(
103104
elif num_classes is None:
104105
num_classes = 91
105106

106-
trainable_backbone_layers = _validate_trainable_layers(
107-
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
108-
)
107+
is_trained = weights is not None or weights_backbone is not None
108+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
109+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
109110

110-
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
111+
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
111112
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
112113
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
113114

@@ -134,11 +135,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
134135
elif num_classes is None:
135136
num_classes = 91
136137

137-
trainable_backbone_layers = _validate_trainable_layers(
138-
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3
139-
)
138+
is_trained = weights is not None or weights_backbone is not None
139+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
140+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
140141

141-
backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
142+
backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
142143
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
143144
anchor_sizes = (
144145
(

torchvision/prototype/models/detection/fcos.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Optional
22

3+
from torch import nn
34
from torchvision.prototype.transforms import CocoEval
45
from torchvision.transforms.functional import InterpolationMode
56

@@ -63,11 +64,11 @@ def fcos_resnet50_fpn(
6364
elif num_classes is None:
6465
num_classes = 91
6566

66-
trainable_backbone_layers = _validate_trainable_layers(
67-
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
68-
)
67+
is_trained = weights is not None or weights_backbone is not None
68+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
69+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
6970

70-
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
71+
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
7172
backbone = _resnet_fpn_extractor(
7273
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
7374
)

torchvision/prototype/models/detection/keypoint_rcnn.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Optional
22

3+
from torch import nn
34
from torchvision.prototype.transforms import CocoEval
45
from torchvision.transforms.functional import InterpolationMode
56

@@ -91,11 +92,11 @@ def keypointrcnn_resnet50_fpn(
9192
if num_keypoints is None:
9293
num_keypoints = 17
9394

94-
trainable_backbone_layers = _validate_trainable_layers(
95-
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
96-
)
95+
is_trained = weights is not None or weights_backbone is not None
96+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
97+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
9798

98-
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
99+
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
99100
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
100101
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
101102

torchvision/prototype/models/detection/mask_rcnn.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Optional
22

3+
from torch import nn
34
from torchvision.prototype.transforms import CocoEval
45
from torchvision.transforms.functional import InterpolationMode
56

@@ -64,11 +65,11 @@ def maskrcnn_resnet50_fpn(
6465
elif num_classes is None:
6566
num_classes = 91
6667

67-
trainable_backbone_layers = _validate_trainable_layers(
68-
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
69-
)
68+
is_trained = weights is not None or weights_backbone is not None
69+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
70+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
7071

71-
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
72+
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
7273
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
7374
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
7475

torchvision/prototype/models/detection/retinanet.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Optional
22

3+
from torch import nn
34
from torchvision.prototype.transforms import CocoEval
45
from torchvision.transforms.functional import InterpolationMode
56

@@ -64,11 +65,11 @@ def retinanet_resnet50_fpn(
6465
elif num_classes is None:
6566
num_classes = 91
6667

67-
trainable_backbone_layers = _validate_trainable_layers(
68-
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
69-
)
68+
is_trained = weights is not None or weights_backbone is not None
69+
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
70+
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
7071

71-
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
72+
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
7273
# skip P2 because it generates too many anchors (according to their paper)
7374
backbone = _resnet_fpn_extractor(
7475
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)

0 commit comments

Comments
 (0)