Add typing annotations to detection/rpn #4619

Merged · 8 commits · Oct 15, 2021
4 changes: 0 additions & 4 deletions mypy.ini
@@ -29,10 +29,6 @@ ignore_errors = True
 ignore_errors = True
 
-[mypy-torchvision.models.detection.rpn]
-
-ignore_errors = True
-
 [mypy-torchvision.models.detection.roi_heads]
 
 ignore_errors = True
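Note: dropping this override means mypy now type-checks torchvision.models.detection.rpn instead of silencing it. A minimal sketch of the kind of mismatch that ignore_errors = True used to suppress (the function here is hypothetical, not from the module):

# With "ignore_errors = True" set for a module, mypy parses it but reports
# nothing; once the override is removed, a mismatch like this surfaces:
def anchors_per_cell() -> int:
    return "three"  # error: Incompatible return value type (got "str", expected "int")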
99 changes: 51 additions & 48 deletions torchvision/models/detection/rpn.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Dict, Tuple
+from typing import List, Optional, Dict, Tuple, cast
 
 import torch
 import torchvision
@@ -14,14 +14,14 @@
 
 
 @torch.jit.unused
-def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
-    # type: (Tensor, int) -> Tuple[int, int]
+def _onnx_get_num_anchors_and_pre_nms_top_n(ob: Tensor, orig_pre_nms_top_n: int) -> Tuple[int, int]:
     from torch.onnx import operators
 
     num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
     pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0))
 
-    return num_anchors, pre_nms_top_n
+    # for mypy we cast at runtime
+    return cast(int, num_anchors), cast(int, pre_nms_top_n)
 
 
 class RPNHead(nn.Module):
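Note: typing.cast only informs the type checker and performs no conversion, so the values returned here are still the tensors the ONNX tracing path expects. A minimal sketch of that behavior:

from typing import cast

import torch

t = torch.tensor([1000])
n = cast(int, t)  # no runtime effect: cast returns its argument unchanged
assert n is t     # still a Tensor; only mypy treats it as int from here on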
@@ -33,18 +33,17 @@ class RPNHead(nn.Module):
         num_anchors (int): number of anchors to be predicted
     """
 
-    def __init__(self, in_channels, num_anchors):
+    def __init__(self, in_channels: int, num_anchors: int) -> None:
         super(RPNHead, self).__init__()
         self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
         self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
 
         for layer in self.children():
-            torch.nn.init.normal_(layer.weight, std=0.01)
-            torch.nn.init.constant_(layer.bias, 0)
+            torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
+            torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]
 
-    def forward(self, x):
-        # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
         logits = []
         bbox_reg = []
         for feature in x:
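Note: the # type: ignore[arg-type] comments are needed because self.children() yields values typed as plain nn.Module, and attribute access on Module is typed as Union[Tensor, Module], which nn.init.normal_ will not accept. A hedged alternative sketch that narrows the type with isinstance instead of ignoring the error:

from torch import nn

def init_conv_children(module: nn.Module) -> None:
    # isinstance narrows each child from Module to Conv2d, so mypy sees
    # weight as a Parameter; the explicit check handles bias being Optional.
    for layer in module.children():
        if isinstance(layer, nn.Conv2d):
            nn.init.normal_(layer.weight, std=0.01)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)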
@@ -54,16 +53,14 @@ def forward(self, x):
         return logits, bbox_reg
 
 
-def permute_and_flatten(layer, N, A, C, H, W):
-    # type: (Tensor, int, int, int, int, int) -> Tensor
+def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
     layer = layer.view(N, -1, C, H, W)
     layer = layer.permute(0, 3, 4, 1, 2)
     layer = layer.reshape(N, -1, C)
     return layer
 
 
-def concat_box_prediction_layers(box_cls, box_regression):
-    # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
     box_cls_flattened = []
     box_regression_flattened = []
     # for each feature level, permute the outputs to make them be in the
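Note: for readers following the shapes, permute_and_flatten turns a per-level head output of shape (N, A*C, H, W) into (N, H*W*A, C). A small self-check with illustrative dimensions:

import torch

N, A, C, H, W = 2, 3, 4, 5, 6
layer = torch.rand(N, A * C, H, W)
out = layer.view(N, -1, C, H, W).permute(0, 3, 4, 1, 2).reshape(N, -1, C)
assert out.shape == (N, H * W * A, C)  # anchors vary fastest, then W, then H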
@@ -104,10 +101,10 @@ class RegionProposalNetwork(torch.nn.Module):
             for computing the loss
         positive_fraction (float): proportion of positive anchors in a mini-batch during training
             of the RPN
-        pre_nms_top_n (Dict[int]): number of proposals to keep before applying NMS. It should
+        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
             contain two fields: training and testing, to allow for different values depending
             on training or evaluation
-        post_nms_top_n (Dict[int]): number of proposals to keep after applying NMS. It should
+        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
             contain two fields: training and testing, to allow for different values depending
             on training or evaluation
         nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
@@ -118,25 +115,23 @@ class RegionProposalNetwork(torch.nn.Module):
         "box_coder": det_utils.BoxCoder,
         "proposal_matcher": det_utils.Matcher,
         "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
-        "pre_nms_top_n": Dict[str, int],
-        "post_nms_top_n": Dict[str, int],
     }
 
     def __init__(
         self,
-        anchor_generator,
-        head,
-        #
-        fg_iou_thresh,
-        bg_iou_thresh,
-        batch_size_per_image,
-        positive_fraction,
-        #
-        pre_nms_top_n,
-        post_nms_top_n,
-        nms_thresh,
-        score_thresh=0.0,
-    ):
+        anchor_generator: AnchorGenerator,
+        head: nn.Module,
+        # Faster-RCNN Training
+        fg_iou_thresh: float,
+        bg_iou_thresh: float,
+        batch_size_per_image: int,
+        positive_fraction: float,
+        # Faster-RCNN Inference
+        pre_nms_top_n: Dict[str, int],
+        post_nms_top_n: Dict[str, int],
+        nms_thresh: float,
+        score_thresh: float = 0.0,
+    ) -> None:
         super(RegionProposalNetwork, self).__init__()
         self.anchor_generator = anchor_generator
         self.head = head
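Note: with the new annotations, pre_nms_top_n and post_nms_top_n are explicitly per-mode dicts. A construction sketch; the values mirror torchvision's FasterRCNN defaults but are illustrative here:

from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork

anchor_generator = AnchorGenerator(sizes=((32, 64, 128),), aspect_ratios=((0.5, 1.0, 2.0),))
head = RPNHead(in_channels=256, num_anchors=anchor_generator.num_anchors_per_location()[0])
rpn = RegionProposalNetwork(
    anchor_generator,
    head,
    fg_iou_thresh=0.7,
    bg_iou_thresh=0.3,
    batch_size_per_image=256,
    positive_fraction=0.5,
    pre_nms_top_n={"training": 2000, "testing": 1000},
    post_nms_top_n={"training": 2000, "testing": 1000},
    nms_thresh=0.7,
)
rpn.train()
assert rpn.pre_nms_top_n() == 2000  # the getters below dispatch on self.training
rpn.eval()
assert rpn.pre_nms_top_n() == 1000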
@@ -159,18 +154,20 @@ def __init__(
         self.score_thresh = score_thresh
         self.min_size = 1e-3
 
-    def pre_nms_top_n(self):
+    def pre_nms_top_n(self) -> int:
         if self.training:
             return self._pre_nms_top_n["training"]
         return self._pre_nms_top_n["testing"]
 
-    def post_nms_top_n(self):
+    def post_nms_top_n(self) -> int:
         if self.training:
             return self._post_nms_top_n["training"]
         return self._post_nms_top_n["testing"]
 
-    def assign_targets_to_anchors(self, anchors, targets):
-        # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
+    def assign_targets_to_anchors(
+        self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
+    ) -> Tuple[List[Tensor], List[Tensor]]:
+
         labels = []
         matched_gt_boxes = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
@@ -205,8 +202,7 @@ def assign_targets_to_anchors(self, anchors, targets):
             matched_gt_boxes.append(matched_gt_boxes_per_image)
         return labels, matched_gt_boxes
 
-    def _get_top_n_idx(self, objectness, num_anchors_per_level):
-        # type: (Tensor, List[int]) -> Tensor
+    def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
         r = []
         offset = 0
         for ob in objectness.split(num_anchors_per_level, 1):
Expand All @@ -220,8 +216,14 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
offset += num_anchors
return torch.cat(r, dim=1)

def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
# type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
def filter_proposals(
self,
proposals: Tensor,
objectness: Tensor,
image_shapes: List[Tuple[int, int]],
num_anchors_per_level: List[int],
) -> Tuple[List[Tensor], List[Tensor]]:

num_images = proposals.shape[0]
device = proposals.device
# do not backprop through objectness
@@ -271,8 +273,9 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
             final_scores.append(scores)
         return final_boxes, final_scores
 
-    def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
-        # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    def compute_loss(
+        self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
+    ) -> Tuple[Tensor, Tensor]:
         """
         Args:
             objectness (Tensor)
@@ -312,25 +315,25 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
 
     def forward(
         self,
-        images,  # type: ImageList
-        features,  # type: Dict[str, Tensor]
-        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
-    ):
-        # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
+        images: ImageList,
+        features: Dict[str, Tensor],
+        targets: Optional[List[Dict[str, Tensor]]] = None,
+    ) -> Tuple[List[Tensor], Dict[str, Tensor]]:
+
         """
         Args:
             images (ImageList): images for which we want to compute the predictions
-            features (OrderedDict[Tensor]): features computed from the images that are
+            features (Dict[str, Tensor]): features computed from the images that are
                 used for computing the predictions. Each tensor in the list
                 correspond to different feature levels
-            targets (List[Dict[Tensor]]): ground-truth boxes present in the image (optional).
+            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                 If provided, each element in the dict should contain a field `boxes`,
                 with the locations of the ground-truth boxes.
 
         Returns:
             boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                 image.
-            losses (Dict[Tensor]): the losses for the model during training. During
+            losses (Dict[str, Tensor]): the losses for the model during training. During
                 testing, it is an empty dict.
         """
         # RPN uses all feature maps that are available
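Note: a minimal inference sketch exercising the annotated forward signature, continuing the construction example above (the rpn instance and feature shapes are illustrative; channels must match the head's in_channels):

import torch
from torchvision.models.detection.image_list import ImageList

images = ImageList(torch.rand(1, 3, 800, 800), [(800, 800)])
features = {"0": torch.rand(1, 256, 50, 50)}  # 256 channels match RPNHead's in_channels

rpn.eval()
boxes, losses = rpn(images, features)  # targets omitted, so losses == {}
assert boxes[0].shape[1] == 4          # one (num_proposals, 4) tensor per image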