Add typing annotations to detection/backbone_utils #4603

Merged: 11 commits, Oct 19, 2021
4 changes: 0 additions & 4 deletions mypy.ini
@@ -21,10 +21,6 @@ ignore_errors=True

 ignore_errors = True

-[mypy-torchvision.models.detection.backbone_utils]
-
-ignore_errors = True
-
 [mypy-torchvision.models.detection.transform]

 ignore_errors = True
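With the override gone, mypy now type-checks this module instead of skipping it. A minimal local check, assuming mypy is installed and run from the repository root (mypy.api.run is mypy's documented programmatic entry point):

from mypy import api

# Returns (stdout, stderr, exit_status); exit_status 0 means no type errors.
stdout, stderr, exit_status = api.run(
    ["--config-file", "mypy.ini", "torchvision/models/detection/backbone_utils.py"]
)
print(stdout)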
62 changes: 38 additions & 24 deletions torchvision/models/detection/backbone_utils.py
@@ -1,7 +1,7 @@
 import warnings
-from typing import List, Optional
+from typing import Callable, Dict, Optional, List

-from torch import nn
+from torch import nn, Tensor
 from torchvision.ops import misc as misc_nn_ops
 from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock

@@ -29,7 +29,14 @@ class BackboneWithFPN(nn.Module):
         out_channels (int): the number of channels in the FPN
     """

-    def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        return_layers: Dict[str, str],
+        in_channels_list: List[int],
+        out_channels: int,
+        extra_blocks: Optional[ExtraFPNBlock] = None,
+    ) -> None:
         super(BackboneWithFPN, self).__init__()

         if extra_blocks is None:
@@ -43,20 +50,20 @@ def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
         )
         self.out_channels = out_channels

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
         x = self.body(x)
         x = self.fpn(x)
         return x


 def resnet_fpn_backbone(
-    backbone_name,
-    pretrained,
-    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
-    trainable_layers=3,
-    returned_layers=None,
-    extra_blocks=None,
-):
+    backbone_name: str,
+    pretrained: bool,
+    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers: int = 3,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> BackboneWithFPN:
     """
     Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
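The annotated forward makes the contract explicit: one batch Tensor in, an ordered dict of FPN feature maps out. A minimal sketch of that contract (the keys '0'..'3' and 'pool' come from the default returned layers; the shape shown assumes a 224x224 input):

import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone("resnet50", pretrained=False)
feats = backbone(torch.rand(1, 3, 224, 224))  # Dict[str, Tensor]
for name, feat in feats.items():
    print(name, tuple(feat.shape))  # e.g. '0' -> (1, 256, 56, 56)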

@@ -80,7 +87,7 @@ def resnet_fpn_backbone(
         backbone_name (string): resnet architecture. Possible values are 'ResNet', 'resnet18', 'resnet34', 'resnet50',
             'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
         pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
-        norm_layer (torchvision.ops): it is recommended to use the default value. For details visit:
+        norm_layer (callable): it is recommended to use the default value. For details visit:
             (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
         trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
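For reference, a sketch of overriding the now-annotated norm_layer parameter; any Callable[..., nn.Module] that builds a normalization layer satisfies the annotation:

from torch import nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

# Plain BatchNorm instead of the default FrozenBatchNorm2d; generally only
# advisable when fine-tuning with reasonably large batch sizes.
backbone = resnet_fpn_backbone(
    "resnet18", pretrained=False, norm_layer=nn.BatchNorm2d, trainable_layers=5
)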
@@ -101,7 +108,8 @@ def _resnet_backbone_config(
     trainable_layers: int,
     returned_layers: Optional[List[int]],
     extra_blocks: Optional[ExtraFPNBlock],
-):
+) -> BackboneWithFPN:
+
     # select layers that wont be frozen
     assert 0 <= trainable_layers <= 5
     layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
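A quick illustration of the slice above, showing which blocks stay trainable for the default trainable_layers=3; everything outside this list is frozen in the collapsed code that follows:

layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:3]
assert layers_to_train == ["layer4", "layer3", "layer2"]  # earlier blocks stay frozen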
@@ -125,8 +133,13 @@ def _resnet_backbone_config(
     return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)


-def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
-    # dont freeze any layers if pretrained model or backbone is not used
+def _validate_trainable_layers(
+    pretrained: bool,
+    trainable_backbone_layers: Optional[int],
+    max_value: int,
+    default_value: int,
+) -> int:
+    # don't freeze any layers if pretrained model or backbone is not used
     if not pretrained:
         if trainable_backbone_layers is not None:
             warnings.warn(
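The resolution rule, in brief: without pretrained weights every layer stays trainable, and an explicit value only triggers the warning above. A small sketch of the intended behavior (the default/fallback branch sits in the collapsed part of the diff, so the second case is an assumption based on it):

# pretrained=True and no explicit value -> the caller-supplied default is used
layers = _validate_trainable_layers(
    pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3
)
assert layers == 3

# no pretrained weights -> warns and falls back to max_value (all layers trainable)
layers = _validate_trainable_layers(
    pretrained=False, trainable_backbone_layers=2, max_value=5, default_value=3
)
assert layers == 5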
@@ -144,14 +157,15 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):


 def mobilenet_backbone(
-    backbone_name,
-    pretrained,
-    fpn,
-    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
-    trainable_layers=2,
-    returned_layers=None,
-    extra_blocks=None,
-):
+    backbone_name: str,
+    pretrained: bool,
+    fpn: bool,
+    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers: int = 2,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> nn.Module:
+
     backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

     # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -185,5 +199,5 @@ def mobilenet_backbone(
         # depthwise linear combination of channels to reduce their size
         nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
     )
-    m.out_channels = out_channels
+    m.out_channels = out_channels  # type: ignore[assignment]
     return m
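The ignore comment is needed because out_channels is attached to the nn.Sequential dynamically at runtime rather than declared on the class, so mypy cannot verify the assignment. A sketch of the non-FPN path it guards (mobilenet_v2 is just one valid backbone_name):

from torchvision.models.detection.backbone_utils import mobilenet_backbone

m = mobilenet_backbone("mobilenet_v2", pretrained=False, fpn=False)
print(m.out_channels)  # the reduced channel count set by the 1x1 conv above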
22 changes: 14 additions & 8 deletions torchvision/prototype/models/detection/backbone_utils.py
@@ -1,14 +1,20 @@
-from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config
+from typing import Callable, Optional, List
+
+from torch import nn
+
+from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config, BackboneWithFPN, ExtraFPNBlock
 from .. import resnet
 from .._api import Weights


 def resnet_fpn_backbone(
-    backbone_name,
-    weights,
-    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
-    trainable_layers=3,
-    returned_layers=None,
-    extra_blocks=None,
-):
+    backbone_name: str,
+    weights: Optional[Weights],
+    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers: int = 3,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> BackboneWithFPN:
+
     backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
     return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
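A hedged usage sketch of the prototype builder; the prototype namespace and its Weights enum were still experimental at this point, so weights=None is the only value assumed here:

from torchvision.prototype.models.detection.backbone_utils import resnet_fpn_backbone

# weights=None builds a randomly initialized backbone; a Weights member from
# torchvision.prototype.models.resnet would load pretrained parameters instead.
backbone = resnet_fpn_backbone("resnet50", weights=None)
print(backbone.out_channels)  # 256, the FPN channel width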