diff --git a/mypy.ini b/mypy.ini
index dbcaab8770c..a2733d3ae3b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -29,10 +29,6 @@ ignore_errors = True
 
 ignore_errors = True
 
-[mypy-torchvision.models.detection.rpn]
-
-ignore_errors = True
-
 [mypy-torchvision.models.detection.roi_heads]
 
 ignore_errors = True
diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py
index c58e1a37af0..4994eeba35b 100644
--- a/torchvision/models/detection/rpn.py
+++ b/torchvision/models/detection/rpn.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Dict, Tuple
+from typing import List, Optional, Dict, Tuple, cast
 
 import torch
 import torchvision
@@ -14,14 +14,14 @@
 
 
 @torch.jit.unused
-def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
-    # type: (Tensor, int) -> Tuple[int, int]
+def _onnx_get_num_anchors_and_pre_nms_top_n(ob: Tensor, orig_pre_nms_top_n: int) -> Tuple[int, int]:
     from torch.onnx import operators
 
     num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
     pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0))
 
-    return num_anchors, pre_nms_top_n
+    # for mypy we cast at runtime
+    return cast(int, num_anchors), cast(int, pre_nms_top_n)
 
 
 class RPNHead(nn.Module):
@@ -33,18 +33,17 @@ class RPNHead(nn.Module):
         num_anchors (int): number of anchors to be predicted
     """
 
-    def __init__(self, in_channels, num_anchors):
+    def __init__(self, in_channels: int, num_anchors: int) -> None:
         super(RPNHead, self).__init__()
         self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
         self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
 
         for layer in self.children():
-            torch.nn.init.normal_(layer.weight, std=0.01)
-            torch.nn.init.constant_(layer.bias, 0)
+            torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
+            torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]
 
-    def forward(self, x):
-        # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
         logits = []
         bbox_reg = []
         for feature in x:
@@ -54,16 +53,14 @@ def forward(self, x):
         return logits, bbox_reg
 
 
-def permute_and_flatten(layer, N, A, C, H, W):
-    # type: (Tensor, int, int, int, int, int) -> Tensor
+def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
     layer = layer.view(N, -1, C, H, W)
     layer = layer.permute(0, 3, 4, 1, 2)
     layer = layer.reshape(N, -1, C)
     return layer
 
 
-def concat_box_prediction_layers(box_cls, box_regression):
-    # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
     box_cls_flattened = []
     box_regression_flattened = []
     # for each feature level, permute the outputs to make them be in the
@@ -104,10 +101,10 @@ class RegionProposalNetwork(torch.nn.Module):
            for computing the loss
         positive_fraction (float): proportion of positive anchors in a mini-batch during
             training of the RPN
-        pre_nms_top_n (Dict[int]): number of proposals to keep before applying NMS. It should
+        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
             contain two fields: training and testing, to allow for different values depending
             on training or evaluation
-        post_nms_top_n (Dict[int]): number of proposals to keep after applying NMS. It should
+        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
             contain two fields: training and testing, to allow for different values depending
             on training or evaluation
         nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
@@ -118,25 +115,23 @@ class RegionProposalNetwork(torch.nn.Module):
         "box_coder": det_utils.BoxCoder,
         "proposal_matcher": det_utils.Matcher,
         "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
-        "pre_nms_top_n": Dict[str, int],
-        "post_nms_top_n": Dict[str, int],
     }
 
     def __init__(
         self,
-        anchor_generator,
-        head,
-        #
-        fg_iou_thresh,
-        bg_iou_thresh,
-        batch_size_per_image,
-        positive_fraction,
-        #
-        pre_nms_top_n,
-        post_nms_top_n,
-        nms_thresh,
-        score_thresh=0.0,
-    ):
+        anchor_generator: AnchorGenerator,
+        head: nn.Module,
+        # Faster-RCNN Training
+        fg_iou_thresh: float,
+        bg_iou_thresh: float,
+        batch_size_per_image: int,
+        positive_fraction: float,
+        # Faster-RCNN Inference
+        pre_nms_top_n: Dict[str, int],
+        post_nms_top_n: Dict[str, int],
+        nms_thresh: float,
+        score_thresh: float = 0.0,
+    ) -> None:
         super(RegionProposalNetwork, self).__init__()
         self.anchor_generator = anchor_generator
         self.head = head
@@ -159,18 +154,20 @@ def __init__(
         self.score_thresh = score_thresh
         self.min_size = 1e-3
 
-    def pre_nms_top_n(self):
+    def pre_nms_top_n(self) -> int:
         if self.training:
             return self._pre_nms_top_n["training"]
         return self._pre_nms_top_n["testing"]
 
-    def post_nms_top_n(self):
+    def post_nms_top_n(self) -> int:
         if self.training:
             return self._post_nms_top_n["training"]
         return self._post_nms_top_n["testing"]
 
-    def assign_targets_to_anchors(self, anchors, targets):
-        # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
+    def assign_targets_to_anchors(
+        self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
+    ) -> Tuple[List[Tensor], List[Tensor]]:
+
         labels = []
         matched_gt_boxes = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
@@ -205,8 +202,7 @@ def assign_targets_to_anchors(self, anchors, targets):
             matched_gt_boxes.append(matched_gt_boxes_per_image)
         return labels, matched_gt_boxes
 
-    def _get_top_n_idx(self, objectness, num_anchors_per_level):
-        # type: (Tensor, List[int]) -> Tensor
+    def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
         r = []
         offset = 0
         for ob in objectness.split(num_anchors_per_level, 1):
@@ -220,8 +216,14 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
             offset += num_anchors
         return torch.cat(r, dim=1)
 
-    def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
-        # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
+    def filter_proposals(
+        self,
+        proposals: Tensor,
+        objectness: Tensor,
+        image_shapes: List[Tuple[int, int]],
+        num_anchors_per_level: List[int],
+    ) -> Tuple[List[Tensor], List[Tensor]]:
+
         num_images = proposals.shape[0]
         device = proposals.device
         # do not backprop through objectness
@@ -271,8 +273,9 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
             final_scores.append(scores)
         return final_boxes, final_scores
 
-    def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
-        # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    def compute_loss(
+        self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
+    ) -> Tuple[Tensor, Tensor]:
         """
         Args:
             objectness (Tensor)
@@ -312,25 +315,25 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
 
     def forward(
         self,
-        images,  # type: ImageList
-        features,  # type: Dict[str, Tensor]
-        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
-    ):
-        # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
+        images: ImageList,
+        features: Dict[str, Tensor],
+        targets: Optional[List[Dict[str, Tensor]]] = None,
+    ) -> Tuple[List[Tensor], Dict[str, Tensor]]:
+
        """
         Args:
             images (ImageList): images for which we want to compute the predictions
-            features (OrderedDict[Tensor]): features computed from the images that are
+            features (Dict[str, Tensor]): features computed from the images that are
                 used for computing the predictions. Each tensor in the list correspond
                 to different feature levels
-            targets (List[Dict[Tensor]]): ground-truth boxes present in the image (optional).
+            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                 If provided, each element in the dict should contain a field `boxes`,
                 with the locations of the ground-truth boxes.
 
         Returns:
             boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per image.
-            losses (Dict[Tensor]): the losses for the model during training. During
+            losses (Dict[str, Tensor]): the losses for the model during training. During
                 testing, it is an empty dict.
         """
         # RPN uses all feature maps that are available
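
As a sanity check on the new `__init__` signature, here is a minimal construction sketch. The numbers mirror the `fasterrcnn_resnet50_fpn` defaults, and `in_channels=256` assumes an FPN backbone, so treat every value as illustrative rather than as part of this change:

```python
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork

# Anchors for a 3-level feature pyramid (illustrative sizes/ratios).
anchor_sizes = ((32,), (64,), (128,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

# in_channels=256 assumes an FPN backbone; num_anchors comes from the generator.
head = RPNHead(256, anchor_generator.num_anchors_per_location()[0])

rpn = RegionProposalNetwork(
    anchor_generator,
    head,
    # Faster-RCNN Training
    fg_iou_thresh=0.7,
    bg_iou_thresh=0.3,
    batch_size_per_image=256,
    positive_fraction=0.5,
    # Faster-RCNN Inference: Dict[str, int] with "training"/"testing" keys,
    # as documented in the updated docstring.
    pre_nms_top_n={"training": 2000, "testing": 1000},
    post_nms_top_n={"training": 2000, "testing": 1000},
    nms_thresh=0.7,
)

# pre_nms_top_n()/post_nms_top_n() pick the value matching the current mode.
rpn.train()
assert rpn.pre_nms_top_n() == 2000
rpn.eval()
assert rpn.pre_nms_top_n() == 1000
```

With the `[mypy-torchvision.models.detection.rpn]` override removed from mypy.ini, this module is now type-checked along with the rest of the package, so a mismatched argument such as `pre_nms_top_n={"training": "2000"}` would be flagged by mypy at the call site.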