Skip to content

Commit c790216

Browse files
oke-adityakhushi-411datumbox
authored
Add typing Annotations to detection/utils (#4583)
* Start annotating utils * checking * Add annotations at _utils.py * Remove unnecessary comments. * re-checked typings * Update typing * Ignore small error * Use optional tensor * Ignore for JIT Co-authored-by: Khushi Agrawal <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 9ad3326 commit c790216

File tree

2 files changed

+69
-26
lines changed

2 files changed

+69
-26
lines changed

mypy.ini

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,55 @@ ignore_errors = True
1717

1818
ignore_errors=True
1919

20-
[mypy-torchvision.models.detection.*]
20+
[mypy-torchvision.models.detection.anchor_utils]
21+
22+
ignore_errors = True
23+
24+
[mypy-torchvision.models.detection.backbone_utils]
25+
26+
ignore_errors = True
27+
28+
[mypy-torchvision.models.detection.image_list]
29+
30+
ignore_errors = True
31+
32+
[mypy-torchvision.models.detection.transform]
33+
34+
ignore_errors = True
35+
36+
[mypy-torchvision.models.detection.rpn]
37+
38+
ignore_errors = True
39+
40+
[mypy-torchvision.models.detection.roi_heads]
41+
42+
ignore_errors = True
43+
44+
[mypy-torchvision.models.detection.generalized_rcnn]
45+
46+
ignore_errors = True
47+
48+
[mypy-torchvision.models.detection.faster_rcnn]
49+
50+
ignore_errors = True
51+
52+
[mypy-torchvision.models.detection.mask_rcnn]
53+
54+
ignore_errors = True
55+
56+
[mypy-torchvision.models.detection.keypoint_rcnn]
57+
58+
ignore_errors = True
59+
60+
[mypy-torchvision.models.detection.retinanet]
61+
62+
ignore_errors = True
63+
64+
[mypy-torchvision.models.detection.ssd]
65+
66+
ignore_errors = True
67+
68+
[mypy-torchvision.models.detection.ssdlite]
2169

2270
ignore_errors = True
2371

torchvision/models/detection/_utils.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import List, Tuple
44

55
import torch
6-
from torch import Tensor
6+
from torch import Tensor, nn
77
from torchvision.ops.misc import FrozenBatchNorm2d
88

99

@@ -12,18 +12,16 @@ class BalancedPositiveNegativeSampler(object):
1212
This class samples batches, ensuring that they contain a fixed proportion of positives
1313
"""
1414

15-
def __init__(self, batch_size_per_image, positive_fraction):
16-
# type: (int, float) -> None
15+
def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
1716
"""
1817
Args:
1918
batch_size_per_image (int): number of elements to be selected per image
20-
positive_fraction (float): percentace of positive elements per batch
19+
positive_fraction (float): percentage of positive elements per batch
2120
"""
2221
self.batch_size_per_image = batch_size_per_image
2322
self.positive_fraction = positive_fraction
2423

25-
def __call__(self, matched_idxs):
26-
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
24+
def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
2725
"""
2826
Args:
2927
matched idxs: list of tensors containing -1, 0 or positive values.
@@ -73,8 +71,7 @@ def __call__(self, matched_idxs):
7371

7472

7573
@torch.jit._script_if_tracing
76-
def encode_boxes(reference_boxes, proposals, weights):
77-
# type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
74+
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
7875
"""
7976
Encode a set of proposals with respect to some
8077
reference boxes
@@ -127,8 +124,9 @@ class BoxCoder(object):
127124
the representation used for training the regressors.
128125
"""
129126

130-
def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
131-
# type: (Tuple[float, float, float, float], float) -> None
127+
def __init__(
128+
self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
129+
) -> None:
132130
"""
133131
Args:
134132
weights (4-element tuple)
@@ -137,15 +135,14 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
137135
self.weights = weights
138136
self.bbox_xform_clip = bbox_xform_clip
139137

140-
def encode(self, reference_boxes, proposals):
141-
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
138+
def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
142139
boxes_per_image = [len(b) for b in reference_boxes]
143140
reference_boxes = torch.cat(reference_boxes, dim=0)
144141
proposals = torch.cat(proposals, dim=0)
145142
targets = self.encode_single(reference_boxes, proposals)
146143
return targets.split(boxes_per_image, 0)
147144

148-
def encode_single(self, reference_boxes, proposals):
145+
def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
149146
"""
150147
Encode a set of proposals with respect to some
151148
reference boxes
@@ -161,8 +158,7 @@ def encode_single(self, reference_boxes, proposals):
161158

162159
return targets
163160

164-
def decode(self, rel_codes, boxes):
165-
# type: (Tensor, List[Tensor]) -> Tensor
161+
def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
166162
assert isinstance(boxes, (list, tuple))
167163
assert isinstance(rel_codes, torch.Tensor)
168164
boxes_per_image = [b.size(0) for b in boxes]
@@ -177,7 +173,7 @@ def decode(self, rel_codes, boxes):
177173
pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
178174
return pred_boxes
179175

180-
def decode_single(self, rel_codes, boxes):
176+
def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
181177
"""
182178
From a set of original boxes and encoded relative box offsets,
183179
get the decoded boxes.
@@ -244,8 +240,7 @@ class Matcher(object):
244240
"BETWEEN_THRESHOLDS": int,
245241
}
246242

247-
def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
248-
# type: (float, float, bool) -> None
243+
def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
249244
"""
250245
Args:
251246
high_threshold (float): quality values greater than or equal to
@@ -266,7 +261,7 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=Fals
266261
self.low_threshold = low_threshold
267262
self.allow_low_quality_matches = allow_low_quality_matches
268263

269-
def __call__(self, match_quality_matrix):
264+
def __call__(self, match_quality_matrix: Tensor) -> Tensor:
270265
"""
271266
Args:
272267
match_quality_matrix (Tensor[float]): an MxN tensor, containing the
@@ -290,7 +285,7 @@ def __call__(self, match_quality_matrix):
290285
if self.allow_low_quality_matches:
291286
all_matches = matches.clone()
292287
else:
293-
all_matches = None
288+
all_matches = None # type: ignore[assignment]
294289

295290
# Assign candidate matches with low quality to negative (unassigned) values
296291
below_low_threshold = matched_vals < self.low_threshold
@@ -304,7 +299,7 @@ def __call__(self, match_quality_matrix):
304299

305300
return matches
306301

307-
def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
302+
def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
308303
"""
309304
Produce additional matches for predictions that have only low-quality matches.
310305
Specifically, for each ground-truth find the set of predictions that have
@@ -335,10 +330,10 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
335330

336331

337332
class SSDMatcher(Matcher):
338-
def __init__(self, threshold):
333+
def __init__(self, threshold: float) -> None:
339334
super().__init__(threshold, threshold, allow_low_quality_matches=False)
340335

341-
def __call__(self, match_quality_matrix):
336+
def __call__(self, match_quality_matrix: Tensor) -> Tensor:
342337
matches = super().__call__(match_quality_matrix)
343338

344339
# For each gt, find the prediction with which it has the highest quality
@@ -350,7 +345,7 @@ def __call__(self, match_quality_matrix):
350345
return matches
351346

352347

353-
def overwrite_eps(model, eps):
348+
def overwrite_eps(model: nn.Module, eps: float) -> None:
354349
"""
355350
This method overwrites the default eps values of all the
356351
FrozenBatchNorm2d layers of the model with the provided value.
@@ -368,7 +363,7 @@ def overwrite_eps(model, eps):
368363
module.eps = eps
369364

370365

371-
def retrieve_out_channels(model, size):
366+
def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
372367
"""
373368
This method retrieves the number of output channels of a specific model.
374369

0 commit comments

Comments
 (0)