diff --git a/mypy.ini b/mypy.ini index 4820cdf1661..d33ec7fcd21 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,7 +17,23 @@ ignore_errors = True ignore_errors=True -[mypy-torchvision.models.detection.*] +[mypy-torchvision.models.detection._utils] + +ignore_errors = True + +[mypy-torchvision.models.detection.anchor_utils] + +ignore_errors = True + +[mypy-torchvision.models.detection.backbone_utils] + +ignore_errors = True + +[mypy-torchvision.models.detection.roi_heads] + +ignore_errors = True + +[mypy-torchvision.models.detection.ssdlite] ignore_errors = True diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 02da39e8c73..b4ec18df283 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -1,3 +1,5 @@ +from typing import Optional, Tuple, Any, cast, List + import torch.nn.functional as F from torch import nn from torchvision.ops import MultiScaleRoIAlign @@ -55,10 +57,10 @@ class FasterRCNN(GeneralizedRCNN): If box_predictor is specified, num_classes should be None. min_size (int): minimum size of the image to be rescaled before feeding it to the backbone max_size (int): maximum size of the image to be rescaled before feeding it to the backbone - image_mean (Tuple[float, float, float]): mean values used for input normalization. + image_mean (List[float]): mean values used for input normalization. They are generally the mean values of the dataset on which the backbone has been trained on - image_std (Tuple[float, float, float]): std values used for input normalization. + image_std (List[float]): std values used for input normalization. They are generally the std values of the dataset on which the backbone has been trained on rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature maps. 
@@ -143,39 +145,39 @@ class FasterRCNN(GeneralizedRCNN): def __init__( self, - backbone, - num_classes=None, + backbone: nn.Module, + num_classes: Optional[int] = None, # transform parameters - min_size=800, - max_size=1333, - image_mean=None, - image_std=None, + min_size: int = 800, + max_size: int = 1333, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, # RPN parameters - rpn_anchor_generator=None, - rpn_head=None, - rpn_pre_nms_top_n_train=2000, - rpn_pre_nms_top_n_test=1000, - rpn_post_nms_top_n_train=2000, - rpn_post_nms_top_n_test=1000, - rpn_nms_thresh=0.7, - rpn_fg_iou_thresh=0.7, - rpn_bg_iou_thresh=0.3, - rpn_batch_size_per_image=256, - rpn_positive_fraction=0.5, - rpn_score_thresh=0.0, + rpn_anchor_generator: Optional[AnchorGenerator] = None, + rpn_head: Optional[nn.Module] = None, + rpn_pre_nms_top_n_train: int = 2000, + rpn_pre_nms_top_n_test: int = 1000, + rpn_post_nms_top_n_train: int = 2000, + rpn_post_nms_top_n_test: int = 1000, + rpn_nms_thresh: float = 0.7, + rpn_fg_iou_thresh: float = 0.7, + rpn_bg_iou_thresh: float = 0.3, + rpn_batch_size_per_image: int = 256, + rpn_positive_fraction: float = 0.5, + rpn_score_thresh: float = 0.0, # Box parameters - box_roi_pool=None, - box_head=None, - box_predictor=None, - box_score_thresh=0.05, - box_nms_thresh=0.5, - box_detections_per_img=100, - box_fg_iou_thresh=0.5, - box_bg_iou_thresh=0.5, - box_batch_size_per_image=512, - box_positive_fraction=0.25, - bbox_reg_weights=None, - ): + box_roi_pool: Optional[MultiScaleRoIAlign] = None, + box_head: Optional[nn.Module] = None, + box_predictor: Optional[nn.Module] = None, + box_score_thresh: float = 0.05, + box_nms_thresh: float = 0.5, + box_detections_per_img: int = 100, + box_fg_iou_thresh: float = 0.5, + box_bg_iou_thresh: float = 0.5, + box_batch_size_per_image: int = 512, + box_positive_fraction: float = 0.25, + bbox_reg_weights: Optional[Tuple[float, ...]] = None, + ) -> None: if not hasattr(backbone, "out_channels"): raise ValueError( @@ -194,7 +196,7 @@ def __init__( if box_predictor is None: raise ValueError("num_classes should not be None when box_predictor " "is not specified") - out_channels = backbone.out_channels + out_channels = cast(int, backbone.out_channels) if rpn_anchor_generator is None: anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) @@ -229,7 +231,7 @@ def __init__( if box_predictor is None: representation_size = 1024 - box_predictor = FastRCNNPredictor(representation_size, num_classes) + box_predictor = FastRCNNPredictor(representation_size, num_classes) # type: ignore[arg-type] roi_heads = RoIHeads( # Box @@ -264,7 +266,7 @@ class TwoMLPHead(nn.Module): representation_size (int): size of the intermediate representation """ - def __init__(self, in_channels, representation_size): + def __init__(self, in_channels: int, representation_size: int) -> None: super(TwoMLPHead, self).__init__() self.fc6 = nn.Linear(in_channels, representation_size) @@ -289,7 +291,7 @@ class FastRCNNPredictor(nn.Module): num_classes (int): number of output classes (including background) """ - def __init__(self, in_channels, num_classes): + def __init__(self, in_channels: int, num_classes: int) -> None: super(FastRCNNPredictor, self).__init__() self.cls_score = nn.Linear(in_channels, num_classes) self.bbox_pred = nn.Linear(in_channels, num_classes * 4) @@ -312,8 +314,13 @@ def forward(self, x): def fasterrcnn_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + 
pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: """ Constructs a Faster R-CNN model with a ResNet-50-FPN backbone. @@ -395,14 +402,15 @@ def fasterrcnn_resnet50_fpn( def _fasterrcnn_mobilenet_v3_large_fpn( - weights_name, - pretrained=False, - progress=True, - num_classes=91, - pretrained_backbone=True, - trainable_backbone_layers=None, - **kwargs, -): + weights_name: str, + pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: + trainable_backbone_layers = _validate_trainable_layers( pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3 ) @@ -436,8 +444,14 @@ def _fasterrcnn_mobilenet_v3_large_fpn( def fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: + """ Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tunned for mobile use-cases. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See @@ -481,8 +495,14 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( def fasterrcnn_mobilenet_v3_large_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: + """ Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. 
See diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index c77c892e63e..6826d8ec4b2 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -23,7 +23,13 @@ class GeneralizedRCNN(nn.Module): the model """ - def __init__(self, backbone, rpn, roi_heads, transform): + def __init__( + self, + backbone: nn.Module, + rpn: nn.Module, + roi_heads: nn.Module, + transform: nn.Module, + ) -> None: super(GeneralizedRCNN, self).__init__() self.transform = transform self.backbone = backbone @@ -32,16 +38,23 @@ def __init__(self, backbone, rpn, roi_heads, transform): # used only on torchscript mode self._has_warned = False - @torch.jit.unused - def eager_outputs(self, losses, detections): - # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]] + def eager_outputs( + self, + losses: Dict[str, Tensor], + detections: List[Dict[str, Tensor]], + ) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]: + if self.training: return losses return detections - def forward(self, images, targets=None): - # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + def forward( + self, + images: List[Tensor], + targets: Optional[List[Dict[str, Tensor]]] = None, + ) -> Union[Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]], Dict[str, Tensor], List[Dict[str, Tensor]]]: + """ Args: images (list[Tensor]): images to be processed @@ -68,11 +81,11 @@ def forward(self, images, targets=None): else: raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) - original_image_sizes: List[Tuple[int, int]] = [] + original_image_sizes: List[List[int]] = [] for img in images: val = img.shape[-2:] assert len(val) == 2 - original_image_sizes.append((val[0], val[1])) + original_image_sizes.append([val[0], val[1]]) images, targets = self.transform(images, targets) diff --git a/torchvision/models/detection/image_list.py b/torchvision/models/detection/image_list.py index 333d3b569f2..3c46d9705c9 100644 --- a/torchvision/models/detection/image_list.py +++ b/torchvision/models/detection/image_list.py @@ -12,7 +12,7 @@ class ImageList(object): and storing in a field the original sizes of each image """ - def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]): + def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None: """ Args: tensors (tensor) diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 7cd975ea6a0..52c4763ce70 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -1,9 +1,12 @@ +from typing import Any, Optional, Tuple, List, cast + import torch from torch import nn from torchvision.ops import MultiScaleRoIAlign from ..._internally_replaced_utils import load_state_dict_from_url from ._utils import overwrite_eps +from .anchor_utils import AnchorGenerator from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers from .faster_rcnn import FasterRCNN @@ -51,10 +54,10 @@ class KeypointRCNN(FasterRCNN): If box_predictor is specified, num_classes should be None. 
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone max_size (int): maximum size of the image to be rescaled before feeding it to the backbone - image_mean (Tuple[float, float, float]): mean values used for input normalization. + image_mean (List[float]): mean values used for input normalization. They are generally the mean values of the dataset on which the backbone has been trained on - image_std (Tuple[float, float, float]): std values used for input normalization. + image_std (List[float]): std values used for input normalization. They are generally the std values of the dataset on which the backbone has been trained on rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature maps. @@ -150,44 +153,44 @@ class KeypointRCNN(FasterRCNN): def __init__( self, - backbone, - num_classes=None, + backbone: nn.Module, + num_classes: Optional[int] = None, # transform parameters - min_size=None, - max_size=1333, - image_mean=None, - image_std=None, + min_size: Optional[Tuple[int, ...]] = None, + max_size: int = 1333, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, # RPN parameters - rpn_anchor_generator=None, - rpn_head=None, - rpn_pre_nms_top_n_train=2000, - rpn_pre_nms_top_n_test=1000, - rpn_post_nms_top_n_train=2000, - rpn_post_nms_top_n_test=1000, - rpn_nms_thresh=0.7, - rpn_fg_iou_thresh=0.7, - rpn_bg_iou_thresh=0.3, - rpn_batch_size_per_image=256, - rpn_positive_fraction=0.5, - rpn_score_thresh=0.0, + rpn_anchor_generator: Optional[AnchorGenerator] = None, + rpn_head: Optional[nn.Module] = None, + rpn_pre_nms_top_n_train: int = 2000, + rpn_pre_nms_top_n_test: int = 1000, + rpn_post_nms_top_n_train: int = 2000, + rpn_post_nms_top_n_test: int = 1000, + rpn_nms_thresh: float = 0.7, + rpn_fg_iou_thresh: float = 0.7, + rpn_bg_iou_thresh: float = 0.3, + rpn_batch_size_per_image: int = 256, + rpn_positive_fraction: float = 0.5, + rpn_score_thresh: float = 0.0, # Box parameters - box_roi_pool=None, - box_head=None, - box_predictor=None, - box_score_thresh=0.05, - box_nms_thresh=0.5, - box_detections_per_img=100, - box_fg_iou_thresh=0.5, - box_bg_iou_thresh=0.5, - box_batch_size_per_image=512, - box_positive_fraction=0.25, - bbox_reg_weights=None, + box_roi_pool: Optional[MultiScaleRoIAlign] = None, + box_head: Optional[nn.Module] = None, + box_predictor: Optional[nn.Module] = None, + box_score_thresh: float = 0.05, + box_nms_thresh: float = 0.5, + box_detections_per_img: int = 100, + box_fg_iou_thresh: float = 0.5, + box_bg_iou_thresh: float = 0.5, + box_batch_size_per_image: int = 512, + box_positive_fraction: float = 0.25, + bbox_reg_weights: Optional[Tuple[float, ...]] = None, # keypoint parameters - keypoint_roi_pool=None, - keypoint_head=None, - keypoint_predictor=None, - num_keypoints=17, - ): + keypoint_roi_pool: Optional[MultiScaleRoIAlign] = None, + keypoint_head: Optional[nn.Module] = None, + keypoint_predictor: Optional[nn.Module] = None, + num_keypoints: int = 17, + ) -> None: assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))) if min_size is None: @@ -197,7 +200,7 @@ def __init__( if keypoint_predictor is not None: raise ValueError("num_classes should be None when keypoint_predictor is specified") - out_channels = backbone.out_channels + out_channels = cast(int, backbone.out_channels) if keypoint_roi_pool is None: keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2) @@ -251,8 +254,13 @@ def 
__init__( class KeypointRCNNHeads(nn.Sequential): - def __init__(self, in_channels, layers): - d = [] + def __init__( + self, + in_channels: int, + layers: Tuple[int, ...], + ) -> None: + + d = nn.ModuleList() next_feature = in_channels for out_channels in layers: d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1)) @@ -266,7 +274,11 @@ def __init__(self, in_channels, layers): class KeypointRCNNPredictor(nn.Module): - def __init__(self, in_channels, num_keypoints): + def __init__( + self, + in_channels: int, + num_keypoints: int, + ) -> None: super(KeypointRCNNPredictor, self).__init__() input_features = in_channels deconv_kernel = 4 @@ -297,13 +309,13 @@ def forward(self, x): def keypointrcnn_resnet50_fpn( - pretrained=False, - progress=True, - num_classes=2, - num_keypoints=17, - pretrained_backbone=True, - trainable_backbone_layers=None, - **kwargs, + pretrained: bool = False, + progress: bool = True, + num_classes: int = 2, + num_keypoints: int = 17, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, ): """ Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone. diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 6b8208b19d8..efd4d49ab8c 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -1,10 +1,12 @@ from collections import OrderedDict +from typing import Optional, Tuple, Any, List, cast from torch import nn from torchvision.ops import MultiScaleRoIAlign from ..._internally_replaced_utils import load_state_dict_from_url from ._utils import overwrite_eps +from .anchor_utils import AnchorGenerator from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers from .faster_rcnn import FasterRCNN @@ -53,10 +55,10 @@ class MaskRCNN(FasterRCNN): If box_predictor is specified, num_classes should be None. min_size (int): minimum size of the image to be rescaled before feeding it to the backbone max_size (int): maximum size of the image to be rescaled before feeding it to the backbone - image_mean (Tuple[float, float, float]): mean values used for input normalization. + image_mean (List[float]): mean values used for input normalization. They are generally the mean values of the dataset on which the backbone has been trained on - image_std (Tuple[float, float, float]): std values used for input normalization. + image_std (List[float]): std values used for input normalization. They are generally the std values of the dataset on which the backbone has been trained on rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature maps.
@@ -151,43 +153,43 @@ class MaskRCNN(FasterRCNN): def __init__( self, - backbone, - num_classes=None, + backbone: nn.Module, + num_classes: Optional[int] = None, # transform parameters - min_size=800, - max_size=1333, - image_mean=None, - image_std=None, + min_size: int = 800, + max_size: int = 1333, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, # RPN parameters - rpn_anchor_generator=None, - rpn_head=None, - rpn_pre_nms_top_n_train=2000, - rpn_pre_nms_top_n_test=1000, - rpn_post_nms_top_n_train=2000, - rpn_post_nms_top_n_test=1000, - rpn_nms_thresh=0.7, - rpn_fg_iou_thresh=0.7, - rpn_bg_iou_thresh=0.3, - rpn_batch_size_per_image=256, - rpn_positive_fraction=0.5, - rpn_score_thresh=0.0, + rpn_anchor_generator: Optional[AnchorGenerator] = None, + rpn_head: Optional[nn.Module] = None, + rpn_pre_nms_top_n_train: int = 2000, + rpn_pre_nms_top_n_test: int = 1000, + rpn_post_nms_top_n_train: int = 2000, + rpn_post_nms_top_n_test: int = 1000, + rpn_nms_thresh: float = 0.7, + rpn_fg_iou_thresh: float = 0.7, + rpn_bg_iou_thresh: float = 0.3, + rpn_batch_size_per_image: int = 256, + rpn_positive_fraction: float = 0.5, + rpn_score_thresh: float = 0.0, # Box parameters - box_roi_pool=None, - box_head=None, - box_predictor=None, - box_score_thresh=0.05, - box_nms_thresh=0.5, - box_detections_per_img=100, - box_fg_iou_thresh=0.5, - box_bg_iou_thresh=0.5, - box_batch_size_per_image=512, - box_positive_fraction=0.25, - bbox_reg_weights=None, + box_roi_pool: Optional[MultiScaleRoIAlign] = None, + box_head: Optional[nn.Module] = None, + box_predictor: Optional[nn.Module] = None, + box_score_thresh: float = 0.05, + box_nms_thresh: float = 0.5, + box_detections_per_img: int = 100, + box_fg_iou_thresh: float = 0.5, + box_bg_iou_thresh: float = 0.5, + box_batch_size_per_image: int = 512, + box_positive_fraction: float = 0.25, + bbox_reg_weights: Optional[Tuple[float, ...]] = None, # Mask parameters - mask_roi_pool=None, - mask_head=None, - mask_predictor=None, - ): + mask_roi_pool: Optional[MultiScaleRoIAlign] = None, + mask_head: Optional[nn.Module] = None, + mask_predictor: Optional[nn.Module] = None, + ) -> None: assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))) @@ -195,20 +197,22 @@ def __init__( if mask_predictor is not None: raise ValueError("num_classes should be None when mask_predictor is specified") - out_channels = backbone.out_channels + out_channels = cast(int, backbone.out_channels) if mask_roi_pool is None: mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2) if mask_head is None: - mask_layers = (256, 256, 256, 256) + mask_layers = [256, 256, 256, 256] mask_dilation = 1 mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation) if mask_predictor is None: mask_predictor_in_channels = 256 # == mask_layers[-1] mask_dim_reduced = 256 - mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes) + mask_predictor = MaskRCNNPredictor( + mask_predictor_in_channels, mask_dim_reduced, num_classes + ) # type: ignore[arg-type] super(MaskRCNN, self).__init__( backbone, @@ -251,20 +255,20 @@ def __init__( class MaskRCNNHeads(nn.Sequential): - def __init__(self, in_channels, layers, dilation): + def __init__(self, in_channels: int, layers: List[int], dilation: int) -> None: """ Args: in_channels (int): number of input channels layers (list): feature dimensions of each FCN layer dilation (int): dilation rate of kernel """ - d = OrderedDict() + d = 
nn.ModuleDict() next_feature = in_channels for layer_idx, layer_features in enumerate(layers, 1): - d["mask_fcn{}".format(layer_idx)] = nn.Conv2d( + d[f"mask_fcn{layer_idx}"] = nn.Conv2d( next_feature, layer_features, kernel_size=3, stride=1, padding=dilation, dilation=dilation ) - d["relu{}".format(layer_idx)] = nn.ReLU(inplace=True) + d[f"relu{layer_idx}"] = nn.ReLU(inplace=True) next_feature = layer_features super(MaskRCNNHeads, self).__init__(d) @@ -276,7 +280,7 @@ def __init__(self, in_channels, layers, dilation): class MaskRCNNPredictor(nn.Sequential): - def __init__(self, in_channels, dim_reduced, num_classes): + def __init__(self, in_channels: int, dim_reduced: int, num_classes: int) -> None: super(MaskRCNNPredictor, self).__init__( OrderedDict( [ @@ -300,8 +304,13 @@ def __init__(self, in_channels, dim_reduced, num_classes): def maskrcnn_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> MaskRCNN: """ Constructs a Mask R-CNN model with a ResNet-50-FPN backbone. diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index eb05144cb0c..c85c3edd8b7 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -1,14 +1,14 @@ import math import warnings from collections import OrderedDict -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Tuple, Optional, Any, Union, cast import torch from torch import nn, Tensor from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops import sigmoid_focal_loss from ...ops import boxes as box_ops +from ...ops import sigmoid_focal_loss from ...ops.feature_pyramid_network import LastLevelP6P7 from . 
import _utils as det_utils from ._utils import overwrite_eps @@ -16,7 +16,6 @@ from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers from .transform import GeneralizedRCNNTransform - __all__ = ["RetinaNet", "retinanet_resnet50_fpn"] @@ -37,20 +36,32 @@ class RetinaNetHead(nn.Module): num_classes (int): number of classes to be predicted """ - def __init__(self, in_channels, num_anchors, num_classes): + def __init__( + self, + in_channels: int, + num_anchors: int, + num_classes: int, + ) -> None: super().__init__() self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes) self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors) - def compute_loss(self, targets, head_outputs, anchors, matched_idxs): - # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor] + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + matched_idxs: List[Tensor], + ) -> Dict[str, Tensor]: return { "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs), "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), } - def forward(self, x): - # type: (List[Tensor]) -> Dict[str, Tensor] + def forward( + self, + x: List[Tensor], + ) -> Dict[str, Tensor]: return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)} @@ -62,12 +73,13 @@ class RetinaNetClassificationHead(nn.Module): in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted num_classes (int): number of classes to be predicted + prior_probability (float): probability of priors to initialise the head """ - def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01): + def __init__(self, in_channels: int, num_anchors: int, num_classes: int, prior_probability: float = 0.01) -> None: super().__init__() - conv = [] + conv = nn.ModuleList() for _ in range(4): conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) conv.append(nn.ReLU()) @@ -90,8 +102,12 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01 # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584 self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS - def compute_loss(self, targets, head_outputs, matched_idxs): - # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + matched_idxs: List[Tensor], + ) -> Tensor: losses = [] cls_logits = head_outputs["cls_logits"] @@ -123,8 +139,10 @@ def compute_loss(self, targets, head_outputs, matched_idxs): return _sum(losses) / len(targets) - def forward(self, x): - # type: (List[Tensor]) -> Tensor + def forward( + self, + x: List[Tensor], + ) -> Tensor: all_cls_logits = [] for features in x: @@ -155,10 +173,14 @@ class RetinaNetRegressionHead(nn.Module): "box_coder": det_utils.BoxCoder, } - def __init__(self, in_channels, num_anchors): + def __init__( + self, + in_channels: int, + num_anchors: int, + ) -> None: super().__init__() - conv = [] + conv = nn.ModuleList() for _ in range(4): conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) conv.append(nn.ReLU()) @@ -175,8 +197,13 @@ def __init__(self, in_channels, num_anchors): self.box_coder = 
det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) - def compute_loss(self, targets, head_outputs, anchors, matched_idxs): - # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + matched_idxs: List[Tensor], + ) -> Tensor: losses = [] bbox_regression = head_outputs["bbox_regression"] @@ -204,8 +231,10 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): return _sum(losses) / max(1, len(targets)) - def forward(self, x): - # type: (List[Tensor]) -> Tensor + def forward( + self, + x: List[Tensor], + ) -> Tensor: all_bbox_regression = [] for features in x: @@ -257,15 +286,17 @@ class RetinaNet(nn.Module): num_classes (int): number of output classes of the model (including the background). min_size (int): minimum size of the image to be rescaled before feeding it to the backbone max_size (int): maximum size of the image to be rescaled before feeding it to the backbone - image_mean (Tuple[float, float, float]): mean values used for input normalization. + image_mean (List[float]): mean values used for input normalization. They are generally the mean values of the dataset on which the backbone has been trained on - image_std (Tuple[float, float, float]): std values used for input normalization. + image_std (List[float]): std values used for input normalization. They are generally the std values of the dataset on which the backbone has been trained on anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature maps. head (nn.Module): Module run on top of the feature pyramid. Defaults to a module containing a classification and regression module. + proposal_matcher (object): Module that matches the predicted anchors against the ground-truth boxes. + Defaults to a Matcher module. score_thresh (float): Score threshold used for postprocessing the detections. nms_thresh (float): NMS threshold used for postprocessing the detections. detections_per_img (int): Number of best detections to keep after NMS.
@@ -315,24 +346,24 @@ class RetinaNet(nn.Module): def __init__( self, - backbone, - num_classes, + backbone: nn.Module, + num_classes: int, # transform parameters - min_size=800, - max_size=1333, - image_mean=None, - image_std=None, + min_size: int = 800, + max_size: int = 1333, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, # Anchor parameters - anchor_generator=None, - head=None, - proposal_matcher=None, - score_thresh=0.05, - nms_thresh=0.5, - detections_per_img=300, - fg_iou_thresh=0.5, - bg_iou_thresh=0.4, - topk_candidates=1000, - ): + anchor_generator: Optional[AnchorGenerator] = None, + head: Optional[nn.Module] = None, + proposal_matcher: Optional[object] = None, + score_thresh: float = 0.05, + nms_thresh: float = 0.5, + detections_per_img: int = 300, + fg_iou_thresh: float = 0.5, + bg_iou_thresh: float = 0.4, + topk_candidates: int = 1000, + ) -> None: super().__init__() if not hasattr(backbone, "out_channels"): @@ -352,7 +383,8 @@ def __init__( self.anchor_generator = anchor_generator if head is None: - head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes) + out_channels = cast(int, backbone.out_channels) + head = RetinaNetHead(out_channels, anchor_generator.num_anchors_per_location()[0], num_classes) self.head = head if proposal_matcher is None: @@ -379,16 +411,23 @@ def __init__( # used only on torchscript mode self._has_warned = False - @torch.jit.unused - def eager_outputs(self, losses, detections): - # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + def eager_outputs( + self, + losses: Dict[str, Tensor], + detections: List[Dict[str, Tensor]], + ) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]: + if self.training: return losses return detections - def compute_loss(self, targets, head_outputs, anchors): - # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor] + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + ) -> Dict[str, Tensor]: matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): if targets_per_image["boxes"].numel() == 0: @@ -398,12 +437,16 @@ def compute_loss(self, targets, head_outputs, anchors): continue match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image) - matched_idxs.append(self.proposal_matcher(match_quality_matrix)) + matched_idxs.append(self.proposal_matcher(match_quality_matrix)) # type: ignore[operator] - return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) + return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) # type: ignore[operator] - def postprocess_detections(self, head_outputs, anchors, image_shapes): - # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]] + def postprocess_detections( + self, + head_outputs: Dict[str, List[Tensor]], + anchors: List[List[Tensor]], + image_shapes: List[Tuple[int, int]], + ) -> List[Dict[str, Tensor]]: class_logits = head_outputs["cls_logits"] box_regression = head_outputs["bbox_regression"] @@ -448,26 +491,31 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): image_scores.append(scores_per_level) image_labels.append(labels_per_level) - image_boxes = torch.cat(image_boxes, dim=0) - image_scores = torch.cat(image_scores, dim=0) - image_labels = torch.cat(image_labels, dim=0) + 
image_boxes_concated = torch.cat(image_boxes, dim=0) + image_scores_concated = torch.cat(image_scores, dim=0) + image_labels_concated = torch.cat(image_labels, dim=0) # non-maximum suppression - keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh) + keep = box_ops.batched_nms( + image_boxes_concated, image_scores_concated, image_labels_concated, self.nms_thresh + ) keep = keep[: self.detections_per_img] detections.append( { - "boxes": image_boxes[keep], - "scores": image_scores[keep], - "labels": image_labels[keep], + "boxes": image_boxes_concated[keep], + "scores": image_scores_concated[keep], + "labels": image_labels_concated[keep], } ) return detections - def forward(self, images, targets=None): - # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + def forward( + self, + images: List[Tensor], + targets: Optional[List[Dict[str, Tensor]]] = None, + ) -> Union[Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]], Dict[str, Tensor], List[Dict[str, Tensor]]]: """ Args: images (list[Tensor]): images to be processed @@ -496,11 +544,11 @@ def forward(self, images, targets=None): raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) # get the original image sizes - original_image_sizes: List[Tuple[int, int]] = [] + original_image_sizes: List[List[int]] = [] for img in images: val = img.shape[-2:] assert len(val) == 2 - original_image_sizes.append((val[0], val[1])) + original_image_sizes.append([val[0], val[1]]) # transform the input images, targets = self.transform(images, targets) @@ -575,8 +623,13 @@ def forward(self, images, targets=None): def retinanet_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> RetinaNet: """ Constructs a RetinaNet model with a ResNet-50-FPN backbone. 
diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 35aee4b7d54..1ef1ff9d865 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -469,11 +469,10 @@ def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w): return res_append -def paste_masks_in_image(masks, boxes, img_shape, padding=1): - # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor +def paste_masks_in_image(masks: Tensor, boxes: Tensor, img_shape: List[int], padding: int = 1) -> Tensor: masks, scale = expand_masks(masks, padding=padding) boxes = expand_boxes(boxes, scale).to(dtype=torch.int64) - im_h, im_w = img_shape + im_h, im_w = img_shape[0], img_shape[1] if torchvision._is_tracing(): return _onnx_paste_masks_in_image_loop( diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index c58e1a37af0..be675b48bb1 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Dict, Tuple +from typing import List, Optional, Dict, Tuple, cast import torch import torchvision @@ -14,14 +14,15 @@ @torch.jit.unused -def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): - # type: (Tensor, int) -> Tuple[int, int] +def _onnx_get_num_anchors_and_pre_nms_top_n(ob: Tensor, orig_pre_nms_top_n: int) -> Tuple[int, int]: from torch.onnx import operators num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0) pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0)) - return num_anchors, pre_nms_top_n + # For mypy we cast at runtime + + return cast(int, num_anchors), cast(int, pre_nms_top_n) class RPNHead(nn.Module): @@ -33,7 +34,11 @@ class RPNHead(nn.Module): num_anchors (int): number of anchors to be predicted """ - def __init__(self, in_channels, num_anchors): + def __init__( + self, + in_channels: int, + num_anchors: int, + ) -> None: super(RPNHead, self).__init__() self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1) @@ -43,8 +48,10 @@ def __init__(self, in_channels, num_anchors): torch.nn.init.normal_(layer.weight, std=0.01) torch.nn.init.constant_(layer.bias, 0) - def forward(self, x): - # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]] + def forward( + self, + x: List[Tensor], + ) -> Tuple[List[Tensor], List[Tensor]]: logits = [] bbox_reg = [] for feature in x: @@ -54,16 +61,26 @@ def forward(self, x): return logits, bbox_reg -def permute_and_flatten(layer, N, A, C, H, W): - # type: (Tensor, int, int, int, int, int) -> Tensor +def permute_and_flatten( + layer: Tensor, + N: int, + A: int, + C: int, + H: int, + W: int, +) -> Tensor: + layer = layer.view(N, -1, C, H, W) layer = layer.permute(0, 3, 4, 1, 2) layer = layer.reshape(N, -1, C) return layer -def concat_box_prediction_layers(box_cls, box_regression): - # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] +def concat_box_prediction_layers( + box_cls: List[Tensor], + box_regression: List[Tensor], +) -> Tuple[Tensor, Tensor]: + box_cls_flattened = [] box_regression_flattened = [] # for each feature level, permute the outputs to make them be in the @@ -83,9 +100,9 @@ def concat_box_prediction_layers(box_cls, box_regression): # concatenate on the first dimension (representing the feature levels), to # take into account the way the labels were generated (with all 
feature maps # being concatenated as well) - box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2) - box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4) - return box_cls, box_regression + box_cls_concated = torch.cat(box_cls_flattened, dim=1).flatten(0, -2) + box_regression_concated = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4) + return box_cls_concated, box_regression_concated class RegionProposalNetwork(torch.nn.Module): @@ -104,13 +121,14 @@ class RegionProposalNetwork(torch.nn.Module): for computing the loss positive_fraction (float): proportion of positive anchors in a mini-batch during training of the RPN - pre_nms_top_n (Dict[int]): number of proposals to keep before applying NMS. It should + pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should contain two fields: training and testing, to allow for different values depending on training or evaluation - post_nms_top_n (Dict[int]): number of proposals to keep after applying NMS. It should + post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should contain two fields: training and testing, to allow for different values depending on training or evaluation - nms_thresh (float): NMS threshold used for postprocessing the RPN proposals + nms_thresh (float): NMS threshold used for postprocessing the RPN proposals. + score_thresh (float): NMS Score threshold for postprocessing boxes. """ @@ -124,19 +142,17 @@ class RegionProposalNetwork(torch.nn.Module): def __init__( self, - anchor_generator, - head, - # - fg_iou_thresh, - bg_iou_thresh, - batch_size_per_image, - positive_fraction, - # - pre_nms_top_n, - post_nms_top_n, - nms_thresh, - score_thresh=0.0, - ): + anchor_generator: AnchorGenerator, + head: nn.Module, + fg_iou_thresh: float, + bg_iou_thresh: float, + batch_size_per_image: int, + positive_fraction: float, + pre_nms_top_n: Dict[str, int], + post_nms_top_n: Dict[str, int], + nms_thresh: float, + score_thresh: float = 0.0, + ) -> None: super(RegionProposalNetwork, self).__init__() self.anchor_generator = anchor_generator self.head = head @@ -159,18 +175,22 @@ def __init__( self.score_thresh = score_thresh self.min_size = 1e-3 - def pre_nms_top_n(self): + def pre_nms_top_n(self) -> int: if self.training: return self._pre_nms_top_n["training"] return self._pre_nms_top_n["testing"] - def post_nms_top_n(self): + def post_nms_top_n(self) -> int: if self.training: return self._post_nms_top_n["training"] return self._post_nms_top_n["testing"] - def assign_targets_to_anchors(self, anchors, targets): - # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]] + def assign_targets_to_anchors( + self, + anchors: List[Tensor], + targets: List[Dict[str, Tensor]], + ) -> Tuple[List[Tensor], List[Tensor]]: + labels = [] matched_gt_boxes = [] for anchors_per_image, targets_per_image in zip(anchors, targets): @@ -205,8 +225,12 @@ def assign_targets_to_anchors(self, anchors, targets): matched_gt_boxes.append(matched_gt_boxes_per_image) return labels, matched_gt_boxes - def _get_top_n_idx(self, objectness, num_anchors_per_level): - # type: (Tensor, List[int]) -> Tensor + def _get_top_n_idx( + self, + objectness: Tensor, + num_anchors_per_level: List[int], + ) -> Tensor: + r = [] offset = 0 for ob in objectness.split(num_anchors_per_level, 1): @@ -220,8 +244,14 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level): offset += num_anchors return torch.cat(r, dim=1) - def filter_proposals(self, proposals, 
objectness, image_shapes, num_anchors_per_level): - # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]] + def filter_proposals( + self, + proposals: Tensor, + objectness: Tensor, + image_shapes: List[Tuple[int, int]], + num_anchors_per_level: List[int], + ) -> Tuple[List[Tensor], List[Tensor]]: + num_images = proposals.shape[0] device = proposals.device # do not backprop through objectness @@ -271,8 +301,14 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_ final_scores.append(scores) return final_boxes, final_scores - def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets): - # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + def compute_loss( + self, + objectness: Tensor, + pred_bbox_deltas: Tensor, + labels: List[Tensor], + regression_targets: List[Tensor], + ) -> Tuple[Tensor, Tensor]: + """ Args: objectness (Tensor) @@ -312,11 +348,10 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) def forward( self, - images, # type: ImageList - features, # type: Dict[str, Tensor] - targets=None, # type: Optional[List[Dict[str, Tensor]]] - ): - # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]] + images: ImageList, + features: Dict[str, Tensor], + targets: Optional[List[Dict[str, Tensor]]] = None, + ) -> Tuple[List[Tensor], Dict[str, Tensor]]: """ Args: images (ImageList): images for which we want to compute the predictions @@ -330,7 +365,7 @@ def forward( Returns: boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per image. - losses (Dict[Tensor]): the losses for the model during training. During + losses (Dict[str, Tensor]): the losses for the model during training. During testing, it is an empty dict. 
""" # RPN uses all feature maps that are available diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index ef68c0279be..7803d911041 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -1,6 +1,6 @@ import warnings from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, cast, Union import torch import torch.nn.functional as F @@ -27,7 +27,7 @@ } -def _xavier_init(conv: nn.Module): +def _xavier_init(conv: nn.Module) -> None: for layer in conv.modules(): if isinstance(layer, nn.Conv2d): torch.nn.init.xavier_uniform_(layer.weight) @@ -36,7 +36,7 @@ def _xavier_init(conv: nn.Module): class SSDHead(nn.Module): - def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int): + def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int) -> None: super().__init__() self.classification_head = SSDClassificationHead(in_channels, num_anchors, num_classes) self.regression_head = SSDRegressionHead(in_channels, num_anchors) @@ -49,7 +49,7 @@ def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: class SSDScoringHead(nn.Module): - def __init__(self, module_list: nn.ModuleList, num_columns: int): + def __init__(self, module_list: nn.ModuleList, num_columns: int) -> None: super().__init__() self.module_list = module_list self.num_columns = num_columns @@ -86,7 +86,7 @@ def forward(self, x: List[Tensor]) -> Tensor: class SSDClassificationHead(SSDScoringHead): - def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int): + def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int) -> None: cls_logits = nn.ModuleList() for channels, anchors in zip(in_channels, num_anchors): cls_logits.append(nn.Conv2d(channels, num_classes * anchors, kernel_size=3, padding=1)) @@ -95,7 +95,7 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: class SSDRegressionHead(SSDScoringHead): - def __init__(self, in_channels: List[int], num_anchors: List[int]): + def __init__(self, in_channels: List[int], num_anchors: List[int]) -> None: bbox_reg = nn.ModuleList() for channels, anchors in zip(in_channels, num_anchors): bbox_reg.append(nn.Conv2d(channels, 4 * anchors, kernel_size=3, padding=1)) @@ -140,10 +140,10 @@ class SSD(nn.Module): size (Tuple[int, int]): the width and height to which images will be rescaled before feeding them to the backbone. num_classes (int): number of output classes of the model (including the background). - image_mean (Tuple[float, float, float]): mean values used for input normalization. + image_mean (List[float]): mean values used for input normalization. They are generally the mean values of the dataset on which the backbone has been trained on - image_std (Tuple[float, float, float]): std values used for input normalization. + image_std (List[float]): std values used for input normalization. They are generally the std values of the dataset on which the backbone has been trained on head (nn.Module, optional): Module run on top of the backbone features. Defaults to a module containing a classification and regression module. 
@@ -178,18 +178,24 @@ def __init__( iou_thresh: float = 0.5, topk_candidates: int = 400, positive_fraction: float = 0.25, - ): - super().__init__() + ) -> None: + super().__init__() self.backbone = backbone - self.anchor_generator = anchor_generator - self.box_coder = det_utils.BoxCoder(weights=(10.0, 10.0, 5.0, 5.0)) if head is None: if hasattr(backbone, "out_channels"): - out_channels = backbone.out_channels + # Other models use int as out_channel, we need compatibility. + # Also out_channels needs a runtime cast to int for mypy. + out_channels_calc = cast(int, backbone.out_channels) + if isinstance(out_channels_calc, int): + out_channels = [out_channels_calc] + elif isinstance(out_channels_calc, list): + out_channels = out_channels_calc + else: + raise ValueError("Backbone should have out_channels either int or List[int] ") else: out_channels = det_utils.retrieve_out_channels(backbone, size) @@ -218,10 +224,10 @@ def __init__( # used only on torchscript mode self._has_warned = False - @torch.jit.unused def eager_outputs( self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]] - ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + ) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]: + if self.training: return losses @@ -300,7 +306,8 @@ def compute_loss( def forward( self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None - ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + ) -> Union[Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]], Dict[str, Tensor], List[Dict[str, Tensor]]]: + if self.training and targets is None: raise ValueError("In training mode, targets should be passed") @@ -317,11 +324,11 @@ def forward( raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) # get the original image sizes - original_image_sizes: List[Tuple[int, int]] = [] + original_image_sizes: List[List[int]] = [] for img in images: val = img.shape[-2:] assert len(val) == 2 - original_image_sizes.append((val[0], val[1])) + original_image_sizes.append([val[0], val[1]]) # transform the input images, targets = self.transform(images, targets) @@ -433,7 +440,7 @@ def postprocess_detections( class SSDFeatureExtractorVGG(nn.Module): - def __init__(self, backbone: nn.Module, highres: bool): + def __init__(self, backbone: nn.Module, highres: bool) -> None: super().__init__() _, _, maxpool3_pos, maxpool4_pos, _ = (i for i, layer in enumerate(backbone) if isinstance(layer, nn.MaxPool2d)) @@ -519,7 +526,9 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int): +def _vgg_extractor( + backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int +) -> SSDFeatureExtractorVGG: if backbone_name in backbone_urls: # Use custom backbones more appropriate for SSD arch = backbone_name.split("_")[0] @@ -553,7 +562,7 @@ def ssd300_vgg16( pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, -): +) -> SSD: """Constructs an SSD model with input size 300x300 and a VGG16 backbone. Reference: `"SSD: Single Shot MultiBox Detector" `_. 
diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index e4a1134b85c..f64b2c6d70e 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -1,5 +1,5 @@ import math -from typing import List, Tuple, Dict, Optional +from typing import List, Tuple, Dict, Optional, cast import torch import torchvision @@ -19,7 +19,8 @@ def _get_shape_onnx(image: Tensor) -> Tensor: @torch.jit.unused def _fake_cast_onnx(v: Tensor) -> float: # ONNX requires a tensor but here we fake its type for JIT. - return v + # cast is no-op at runtime and it's there only to help mypy. + return cast(float, v) def _resize_image_and_masks( @@ -29,6 +30,7 @@ def _resize_image_and_masks( target: Optional[Dict[str, Tensor]] = None, fixed_size: Optional[Tuple[int, int]] = None, ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + if torchvision._is_tracing(): im_shape = _get_shape_onnx(image) else: @@ -91,7 +93,8 @@ def __init__( image_std: List[float], size_divisible: int = 32, fixed_size: Optional[Tuple[int, int]] = None, - ): + ) -> None: + super(GeneralizedRCNNTransform, self).__init__() if not isinstance(min_size, (list, tuple)): min_size = (min_size,) @@ -133,13 +136,13 @@ def forward( targets[i] = target_index image_sizes = [img.shape[-2:] for img in images] - images = self.batch_images(images, size_divisible=self.size_divisible) + batched_images = self.batch_images(images, size_divisible=self.size_divisible) image_sizes_list: List[Tuple[int, int]] = [] for image_size in image_sizes: assert len(image_size) == 2 image_sizes_list.append((image_size[0], image_size[1])) - image_list = ImageList(images, image_sizes_list) + image_list = ImageList(batched_images, image_sizes_list) return image_list, targets def normalize(self, image: Tensor) -> Tensor: @@ -179,12 +182,12 @@ def resize( return image, target bbox = target["boxes"] - bbox = resize_boxes(bbox, (h, w), image.shape[-2:]) + bbox = resize_boxes(bbox, [h, w], list(image.shape[-2:])) target["boxes"] = bbox if "keypoints" in target: keypoints = target["keypoints"] - keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:]) + keypoints = resize_keypoints(keypoints, [h, w], list(image.shape[-2:])) target["keypoints"] = keypoints return image, target @@ -242,8 +245,8 @@ def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor def postprocess( self, result: List[Dict[str, Tensor]], - image_shapes: List[Tuple[int, int]], - original_image_sizes: List[Tuple[int, int]], + image_shapes: List[List[int]], + original_image_sizes: List[List[int]], ) -> List[Dict[str, Tensor]]: if self.training: return result
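The snippet below is a minimal usage sketch, not part of the patch, showing how the annotated detection API reads once these changes land. It assumes a torchvision build with this patch applied; the ResNet-50 FPN backbone and the normalization values are illustrative assumptions only.

# Hypothetical usage sketch (not part of the diff) of the annotated constructors.
from typing import List, Optional

import torch
from torch import nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import FasterRCNN

# resnet_fpn_backbone returns an nn.Module exposing `out_channels`,
# which FasterRCNN.__init__ now casts to int for mypy.
backbone: nn.Module = resnet_fpn_backbone("resnet50", pretrained=False)

# image_mean / image_std are documented as List[float] (previously Tuple[float, float, float]).
image_mean: Optional[List[float]] = [0.485, 0.456, 0.406]
image_std: Optional[List[float]] = [0.229, 0.224, 0.225]

model = FasterRCNN(
    backbone,
    num_classes=91,
    image_mean=image_mean,
    image_std=image_std,
)
model.eval()

# In eval mode the forward pass returns List[Dict[str, Tensor]],
# matching the widened Union return annotation on GeneralizedRCNN.forward.
with torch.no_grad():
    detections = model([torch.rand(3, 320, 320)])
print(detections[0].keys())  # boxes, labels, scores

With the per-module sections added to mypy.ini, a converted file can be checked directly (for example, mypy torchvision/models/detection/faster_rcnn.py), while the modules still listed with ignore_errors remain excluded from checking.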