diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py
index 1d3979caa3f..27a7934a925 100644
--- a/torchvision/models/detection/generalized_rcnn.py
+++ b/torchvision/models/detection/generalized_rcnn.py
@@ -22,7 +22,13 @@ class GeneralizedRCNN(nn.Module):
             the model
     """
 
-    def __init__(self, backbone, rpn, roi_heads, transform):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        rpn: nn.Module,
+        roi_heads: nn.Module,
+        transform: nn.Module,
+    ) -> None:
         super(GeneralizedRCNN, self).__init__()
         self.transform = transform
         self.backbone = backbone
@@ -32,15 +38,22 @@ def __init__(self, backbone, rpn, roi_heads, transform):
         self._has_warned = False
 
     @torch.jit.unused
-    def eager_outputs(self, losses, detections):
-        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
+    def eager_outputs(
+        self,
+        losses: Dict[str, Tensor],
+        detections: List[Dict[str, Tensor]],
+    ) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]:
+
         if self.training:
             return losses
 
         return detections
 
-    def forward(self, images, targets=None):
-        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+    def forward(
+        self,
+        images: List[Tensor],
+        targets: Optional[List[Dict[str, Tensor]]] = None,
+    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
         """
         Args:
             images (list[Tensor]): images to be processed
diff --git a/torchvision/models/detection/image_list.py b/torchvision/models/detection/image_list.py
index a389b3c3ce1..19a126f65bf 100644
--- a/torchvision/models/detection/image_list.py
+++ b/torchvision/models/detection/image_list.py
@@ -11,7 +11,7 @@ class ImageList(object):
     and storing in a field the original sizes of each image
     """
 
-    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]):
+    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
         """
         Args:
             tensors (tensor)
diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py
index fd9a980b97d..4b397f501ac 100644
--- a/torchvision/models/detection/keypoint_rcnn.py
+++ b/torchvision/models/detection/keypoint_rcnn.py
@@ -1,9 +1,11 @@
 import torch
 from torch import nn
+from typing import Any, List, Optional, Tuple
 
 from torchvision.ops import MultiScaleRoIAlign
 
 from ._utils import overwrite_eps
+from .anchor_utils import AnchorGenerator
 from ..._internally_replaced_utils import load_state_dict_from_url
 
 from .faster_rcnn import FasterRCNN
@@ -151,27 +153,46 @@ class KeypointRCNN(FasterRCNN):
         >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
         >>> predictions = model(x)
     """
-    def __init__(self, backbone, num_classes=None,
-                 # transform parameters
-                 min_size=None, max_size=1333,
-                 image_mean=None, image_std=None,
-                 # RPN parameters
-                 rpn_anchor_generator=None, rpn_head=None,
-                 rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
-                 rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
-                 rpn_nms_thresh=0.7,
-                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
-                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
-                 rpn_score_thresh=0.0,
-                 # Box parameters
-                 box_roi_pool=None, box_head=None, box_predictor=None,
-                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
-                 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
-                 box_batch_size_per_image=512, box_positive_fraction=0.25,
-                 bbox_reg_weights=None,
-                 # keypoint parameters
-                 keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None,
-                 num_keypoints=17):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        num_classes: Optional[int] = None,
+        # transform parameters
+        min_size: Optional[Tuple[int, ...]] = None,
+        max_size: int = 1333,
+        image_mean: Optional[List[float]] = None,
+        image_std: Optional[List[float]] = None,
+        # RPN parameters
+        rpn_anchor_generator: Optional[AnchorGenerator] = None,
+        rpn_head: Optional[nn.Module] = None,
+        rpn_pre_nms_top_n_train: int = 2000,
+        rpn_pre_nms_top_n_test: int = 1000,
+        rpn_post_nms_top_n_train: int = 2000,
+        rpn_post_nms_top_n_test: int = 1000,
+        rpn_nms_thresh: float = 0.7,
+        rpn_fg_iou_thresh: float = 0.7,
+        rpn_bg_iou_thresh: float = 0.3,
+        rpn_batch_size_per_image: int = 256,
+        rpn_positive_fraction: float = 0.5,
+        rpn_score_thresh: float = 0.0,
+        # Box parameters
+        box_roi_pool: Optional[MultiScaleRoIAlign] = None,
+        box_head: Optional[nn.Module] = None,
+        box_predictor: Optional[nn.Module] = None,
+        box_score_thresh: float = 0.05,
+        box_nms_thresh: float = 0.5,
+        box_detections_per_img: int = 100,
+        box_fg_iou_thresh: float = 0.5,
+        box_bg_iou_thresh: float = 0.5,
+        box_batch_size_per_image: int = 512,
+        box_positive_fraction: float = 0.25,
+        bbox_reg_weights: Optional[Tuple[float, float, float, float]] = None,
+        # keypoint parameters
+        keypoint_roi_pool: Optional[MultiScaleRoIAlign] = None,
+        keypoint_head: Optional[nn.Module] = None,
+        keypoint_predictor: Optional[nn.Module] = None,
+        num_keypoints: int = 17,
+    ) -> None:
         assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
 
         if min_size is None:
@@ -223,7 +244,11 @@ def __init__(self, backbone, num_classes=None,
 
 
 class KeypointRCNNHeads(nn.Sequential):
-    def __init__(self, in_channels, layers):
+    def __init__(
+        self,
+        in_channels: int,
+        layers: List[int],
+    ) -> None:
         d = []
         next_feature = in_channels
         for out_channels in layers:
@@ -238,7 +263,11 @@ def __init__(self, in_channels, layers):
 
 
 class KeypointRCNNPredictor(nn.Module):
-    def __init__(self, in_channels, num_keypoints):
+    def __init__(
+        self,
+        in_channels: int,
+        num_keypoints: int,
+    ) -> None:
         super(KeypointRCNNPredictor, self).__init__()
         input_features = in_channels
         deconv_kernel = 4
@@ -256,7 +285,10 @@ def __init__(self, in_channels, num_keypoints):
         self.up_scale = 2
         self.out_channels = num_keypoints
 
-    def forward(self, x):
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
         x = self.kps_score_lowres(x)
         return torch.nn.functional.interpolate(
             x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
@@ -272,9 +304,15 @@ def forward(self, x):
 }
 
 
-def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
-                              num_classes=2, num_keypoints=17,
-                              pretrained_backbone=True, trainable_backbone_layers=None, **kwargs):
+def keypointrcnn_resnet50_fpn(
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 2,
+    num_keypoints: int = 17,
+    pretrained_backbone: bool = True,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> KeypointRCNN:
     """
     Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
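With the constructor and `forward` now carrying inline annotations, call sites can be validated by a static checker and exercised directly. A minimal smoke test of the typed API, assuming a torchvision build that includes this patch (it mirrors the class docstring example; `pretrained=False` avoids a weight download):

```python
import torch
from torchvision.models.detection import keypointrcnn_resnet50_fpn

# Keyword arguments are now checked against the inline annotations,
# e.g. num_classes: int and trainable_backbone_layers: Optional[int].
model = keypointrcnn_resnet50_fpn(
    pretrained=False,
    pretrained_backbone=False,
    num_classes=2,
    num_keypoints=17,
)
model.eval()

# forward() accepts List[Tensor]; in eval mode it returns List[Dict[str, Tensor]].
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
with torch.no_grad():
    predictions = model(images)
print(predictions[0].keys())  # boxes, labels, scores, keypoints, keypoints_scores
```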
diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py
index a98eac24dd3..34d9364c9c0 100644
--- a/torchvision/models/detection/rpn.py
+++ b/torchvision/models/detection/rpn.py
@@ -35,7 +35,11 @@ class RPNHead(nn.Module):
         num_anchors (int): number of anchors to be predicted
     """
 
-    def __init__(self, in_channels, num_anchors):
+    def __init__(
+        self,
+        in_channels: int,
+        num_anchors: int,
+    ) -> None:
         super(RPNHead, self).__init__()
         self.conv = nn.Conv2d(
             in_channels, in_channels, kernel_size=3, stride=1, padding=1
@@ -49,8 +53,10 @@ def __init__(self, in_channels, num_anchors):
             torch.nn.init.normal_(layer.weight, std=0.01)
             torch.nn.init.constant_(layer.bias, 0)
 
-    def forward(self, x):
-        # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    def forward(
+        self,
+        x: List[Tensor],
+    ) -> Tuple[List[Tensor], List[Tensor]]:
         logits = []
         bbox_reg = []
         for feature in x:
@@ -60,16 +66,26 @@ def __init__(self, in_channels, num_anchors):
         return logits, bbox_reg
 
 
-def permute_and_flatten(layer, N, A, C, H, W):
-    # type: (Tensor, int, int, int, int, int) -> Tensor
+def permute_and_flatten(
+    layer: Tensor,
+    N: int,
+    A: int,
+    C: int,
+    H: int,
+    W: int,
+) -> Tensor:
+
     layer = layer.view(N, -1, C, H, W)
     layer = layer.permute(0, 3, 4, 1, 2)
     layer = layer.reshape(N, -1, C)
     return layer
 
 
-def concat_box_prediction_layers(box_cls, box_regression):
-    # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+def concat_box_prediction_layers(
+    box_cls: List[Tensor],
+    box_regression: List[Tensor],
+) -> Tuple[Tensor, Tensor]:
+
     box_cls_flattened = []
     box_regression_flattened = []
     # for each feature level, permute the outputs to make them be in the
@@ -122,25 +138,31 @@ class RegionProposalNetwork(torch.nn.Module):
         post_nms_top_n (Dict[int]): number of proposals to keep after applying NMS. It should contain two fields:
             training and testing, to allow for different values depending on training or evaluation
-        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals.
+        score_thresh (float): only return proposals with an objectness score greater than score_thresh
 
""" __annotations__ = { 'box_coder': det_utils.BoxCoder, 'proposal_matcher': det_utils.Matcher, 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler, - 'pre_nms_top_n': Dict[str, int], - 'post_nms_top_n': Dict[str, int], } - def __init__(self, - anchor_generator, - head, - # - fg_iou_thresh, bg_iou_thresh, - batch_size_per_image, positive_fraction, - # - pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0): + def __init__( + self, + anchor_generator: AnchorGenerator, + head: nn.Module, + # + fg_iou_thresh: float, + bg_iou_thresh: float, + batch_size_per_image: int, + positive_fraction: float, + # + pre_nms_top_n: Dict[str, int], + post_nms_top_n: Dict[str, int], + nms_thresh: float, + score_thresh: float = 0.0, + ) -> None: super(RegionProposalNetwork, self).__init__() self.anchor_generator = anchor_generator self.head = head @@ -175,8 +197,12 @@ def post_nms_top_n(self): return self._post_nms_top_n['training'] return self._post_nms_top_n['testing'] - def assign_targets_to_anchors(self, anchors, targets): - # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]] + def assign_targets_to_anchors( + self, + anchors: List[Tensor], + targets: List[Dict[str, Tensor]], + ) -> Tuple[List[Tensor], List[Tensor]]: + labels = [] matched_gt_boxes = [] for anchors_per_image, targets_per_image in zip(anchors, targets): @@ -211,8 +237,12 @@ def assign_targets_to_anchors(self, anchors, targets): matched_gt_boxes.append(matched_gt_boxes_per_image) return labels, matched_gt_boxes - def _get_top_n_idx(self, objectness, num_anchors_per_level): - # type: (Tensor, List[int]) -> Tensor + def _get_top_n_idx( + self, + objectness: Tensor, + num_anchors_per_level: List[int], + ) -> Tensor: + r = [] offset = 0 for ob in objectness.split(num_anchors_per_level, 1): @@ -226,8 +256,14 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level): offset += num_anchors return torch.cat(r, dim=1) - def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level): - # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]] + def filter_proposals( + self, + proposals: Tensor, + objectness: Tensor, + image_shapes: List[Tuple[int, int]], + num_anchors_per_level: List[int], + ) -> Tuple[List[Tensor], List[Tensor]]: + num_images = proposals.shape[0] device = proposals.device # do not backprop throught objectness @@ -278,8 +314,14 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_ final_scores.append(scores) return final_boxes, final_scores - def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets): - # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + def compute_loss( + self, + objectness: Tensor, + pred_bbox_deltas: Tensor, + labels: List[Tensor], + regression_targets: List[Tensor], + ) -> Tuple[Tensor, Tensor]: + """ Args: objectness (Tensor) @@ -316,12 +358,12 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) return objectness_loss, box_loss - def forward(self, - images, # type: ImageList - features, # type: Dict[str, Tensor] - targets=None # type: Optional[List[Dict[str, Tensor]]] - ): - # type: (...) 
-> Tuple[List[Tensor], Dict[str, Tensor]] + def forward( + self, + images: ImageList, + features: Dict[str, Tensor], + targets: Optional[List[Dict[str, Tensor]]] = None, + ) -> Tuple[List[Tensor], Dict[str, Tensor]]: """ Args: images (ImageList): images for which we want to compute the predictions diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 08d48c68020..fa2293f7f63 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -4,7 +4,7 @@ from collections import OrderedDict from functools import partial from torch import nn, Tensor -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional from . import _utils as det_utils from .ssd import SSD, SSDScoringHead
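Unlike the removed `# type:` comments, the inline annotations introduced above are also visible at runtime. A quick introspection sketch, again assuming the patched modules are importable:

```python
from typing import get_type_hints

from torchvision.models.detection.rpn import RPNHead, permute_and_flatten

# Type comments are invisible to the interpreter; inline annotations are not.
print(get_type_hints(RPNHead.forward))
# e.g. {'x': List[torch.Tensor], 'return': Tuple[List[torch.Tensor], List[torch.Tensor]]}
print(get_type_hints(permute_and_flatten))
# e.g. {'layer': torch.Tensor, 'N': int, ..., 'return': torch.Tensor}
```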