Skip to content

Commit 64cd24e

Browse files
Vincent Moensdatumbox
authored andcommitted
[fbsync] Add typing annotations to detection/rpn (#4619)
Summary: * Annotate rpn * Small fix * Small fix and ignore Reviewed By: NicolasHug Differential Revision: D31758315 fbshipit-source-id: 4d16ddc96f26f4c01676cf2f9f6ab3208917ad8a Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 16ff526 commit 64cd24e

File tree

2 files changed

+51
-52
lines changed

2 files changed

+51
-52
lines changed

mypy.ini

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@ ignore_errors = True
2929

3030
ignore_errors = True
3131

32-
[mypy-torchvision.models.detection.rpn]
33-
34-
ignore_errors = True
35-
3632
[mypy-torchvision.models.detection.roi_heads]
3733

3834
ignore_errors = True

torchvision/models/detection/rpn.py

Lines changed: 51 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, Dict, Tuple
1+
from typing import List, Optional, Dict, Tuple, cast
22

33
import torch
44
import torchvision
@@ -14,14 +14,14 @@
1414

1515

1616
@torch.jit.unused
17-
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
18-
# type: (Tensor, int) -> Tuple[int, int]
17+
def _onnx_get_num_anchors_and_pre_nms_top_n(ob: Tensor, orig_pre_nms_top_n: int) -> Tuple[int, int]:
1918
from torch.onnx import operators
2019

2120
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
2221
pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0))
2322

24-
return num_anchors, pre_nms_top_n
23+
# for mypy we cast at runtime
24+
return cast(int, num_anchors), cast(int, pre_nms_top_n)
2525

2626

2727
class RPNHead(nn.Module):
@@ -33,18 +33,17 @@ class RPNHead(nn.Module):
3333
num_anchors (int): number of anchors to be predicted
3434
"""
3535

36-
def __init__(self, in_channels, num_anchors):
36+
def __init__(self, in_channels: int, num_anchors: int) -> None:
3737
super(RPNHead, self).__init__()
3838
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
3939
self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
4040
self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
4141

4242
for layer in self.children():
43-
torch.nn.init.normal_(layer.weight, std=0.01)
44-
torch.nn.init.constant_(layer.bias, 0)
43+
torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
44+
torch.nn.init.constant_(layer.bias, 0) # type: ignore[arg-type]
4545

46-
def forward(self, x):
47-
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
46+
def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
4847
logits = []
4948
bbox_reg = []
5049
for feature in x:
@@ -54,16 +53,14 @@ def forward(self, x):
5453
return logits, bbox_reg
5554

5655

57-
def permute_and_flatten(layer, N, A, C, H, W):
58-
# type: (Tensor, int, int, int, int, int) -> Tensor
56+
def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
5957
layer = layer.view(N, -1, C, H, W)
6058
layer = layer.permute(0, 3, 4, 1, 2)
6159
layer = layer.reshape(N, -1, C)
6260
return layer
6361

6462

65-
def concat_box_prediction_layers(box_cls, box_regression):
66-
# type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
63+
def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
6764
box_cls_flattened = []
6865
box_regression_flattened = []
6966
# for each feature level, permute the outputs to make them be in the
@@ -104,10 +101,10 @@ class RegionProposalNetwork(torch.nn.Module):
104101
for computing the loss
105102
positive_fraction (float): proportion of positive anchors in a mini-batch during training
106103
of the RPN
107-
pre_nms_top_n (Dict[int]): number of proposals to keep before applying NMS. It should
104+
pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
108105
contain two fields: training and testing, to allow for different values depending
109106
on training or evaluation
110-
post_nms_top_n (Dict[int]): number of proposals to keep after applying NMS. It should
107+
post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
111108
contain two fields: training and testing, to allow for different values depending
112109
on training or evaluation
113110
nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
@@ -118,25 +115,23 @@ class RegionProposalNetwork(torch.nn.Module):
118115
"box_coder": det_utils.BoxCoder,
119116
"proposal_matcher": det_utils.Matcher,
120117
"fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
121-
"pre_nms_top_n": Dict[str, int],
122-
"post_nms_top_n": Dict[str, int],
123118
}
124119

125120
def __init__(
126121
self,
127-
anchor_generator,
128-
head,
129-
#
130-
fg_iou_thresh,
131-
bg_iou_thresh,
132-
batch_size_per_image,
133-
positive_fraction,
134-
#
135-
pre_nms_top_n,
136-
post_nms_top_n,
137-
nms_thresh,
138-
score_thresh=0.0,
139-
):
122+
anchor_generator: AnchorGenerator,
123+
head: nn.Module,
124+
# Faster-RCNN Training
125+
fg_iou_thresh: float,
126+
bg_iou_thresh: float,
127+
batch_size_per_image: int,
128+
positive_fraction: float,
129+
# Faster-RCNN Inference
130+
pre_nms_top_n: Dict[str, int],
131+
post_nms_top_n: Dict[str, int],
132+
nms_thresh: float,
133+
score_thresh: float = 0.0,
134+
) -> None:
140135
super(RegionProposalNetwork, self).__init__()
141136
self.anchor_generator = anchor_generator
142137
self.head = head
@@ -159,18 +154,20 @@ def __init__(
159154
self.score_thresh = score_thresh
160155
self.min_size = 1e-3
161156

162-
def pre_nms_top_n(self):
157+
def pre_nms_top_n(self) -> int:
163158
if self.training:
164159
return self._pre_nms_top_n["training"]
165160
return self._pre_nms_top_n["testing"]
166161

167-
def post_nms_top_n(self):
162+
def post_nms_top_n(self) -> int:
168163
if self.training:
169164
return self._post_nms_top_n["training"]
170165
return self._post_nms_top_n["testing"]
171166

172-
def assign_targets_to_anchors(self, anchors, targets):
173-
# type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
167+
def assign_targets_to_anchors(
168+
self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
169+
) -> Tuple[List[Tensor], List[Tensor]]:
170+
174171
labels = []
175172
matched_gt_boxes = []
176173
for anchors_per_image, targets_per_image in zip(anchors, targets):
@@ -205,8 +202,7 @@ def assign_targets_to_anchors(self, anchors, targets):
205202
matched_gt_boxes.append(matched_gt_boxes_per_image)
206203
return labels, matched_gt_boxes
207204

208-
def _get_top_n_idx(self, objectness, num_anchors_per_level):
209-
# type: (Tensor, List[int]) -> Tensor
205+
def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
210206
r = []
211207
offset = 0
212208
for ob in objectness.split(num_anchors_per_level, 1):
@@ -220,8 +216,14 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
220216
offset += num_anchors
221217
return torch.cat(r, dim=1)
222218

223-
def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
224-
# type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
219+
def filter_proposals(
220+
self,
221+
proposals: Tensor,
222+
objectness: Tensor,
223+
image_shapes: List[Tuple[int, int]],
224+
num_anchors_per_level: List[int],
225+
) -> Tuple[List[Tensor], List[Tensor]]:
226+
225227
num_images = proposals.shape[0]
226228
device = proposals.device
227229
# do not backprop through objectness
@@ -271,8 +273,9 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
271273
final_scores.append(scores)
272274
return final_boxes, final_scores
273275

274-
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
275-
# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
276+
def compute_loss(
277+
self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
278+
) -> Tuple[Tensor, Tensor]:
276279
"""
277280
Args:
278281
objectness (Tensor)
@@ -312,25 +315,25 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
312315

313316
def forward(
314317
self,
315-
images, # type: ImageList
316-
features, # type: Dict[str, Tensor]
317-
targets=None, # type: Optional[List[Dict[str, Tensor]]]
318-
):
319-
# type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
318+
images: ImageList,
319+
features: Dict[str, Tensor],
320+
targets: Optional[List[Dict[str, Tensor]]] = None,
321+
) -> Tuple[List[Tensor], Dict[str, Tensor]]:
322+
320323
"""
321324
Args:
322325
images (ImageList): images for which we want to compute the predictions
323-
features (OrderedDict[Tensor]): features computed from the images that are
326+
features (Dict[str, Tensor]): features computed from the images that are
324327
used for computing the predictions. Each tensor in the list
325328
correspond to different feature levels
326-
targets (List[Dict[Tensor]]): ground-truth boxes present in the image (optional).
329+
targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
327330
If provided, each element in the dict should contain a field `boxes`,
328331
with the locations of the ground-truth boxes.
329332
330333
Returns:
331334
boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
332335
image.
333-
losses (Dict[Tensor]): the losses for the model during training. During
336+
losses (Dict[str, Tensor]): the losses for the model during training. During
334337
testing, it is an empty dict.
335338
"""
336339
# RPN uses all feature maps that are available

0 commit comments

Comments
 (0)