3
3
from typing import List , Tuple
4
4
5
5
import torch
6
- from torch import Tensor
6
+ from torch import Tensor , nn
7
7
from torchvision .ops .misc import FrozenBatchNorm2d
8
8
9
9
@@ -12,18 +12,16 @@ class BalancedPositiveNegativeSampler(object):
12
12
This class samples batches, ensuring that they contain a fixed proportion of positives
13
13
"""
14
14
15
- def __init__ (self , batch_size_per_image , positive_fraction ):
16
- # type: (int, float) -> None
15
+ def __init__ (self , batch_size_per_image : int , positive_fraction : float ) -> None :
17
16
"""
18
17
Args:
19
18
batch_size_per_image (int): number of elements to be selected per image
20
- positive_fraction (float): percentace of positive elements per batch
19
+ positive_fraction (float): percentage of positive elements per batch
21
20
"""
22
21
self .batch_size_per_image = batch_size_per_image
23
22
self .positive_fraction = positive_fraction
24
23
25
- def __call__ (self , matched_idxs ):
26
- # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
24
+ def __call__ (self , matched_idxs : List [Tensor ]) -> Tuple [List [Tensor ], List [Tensor ]]:
27
25
"""
28
26
Args:
29
27
matched idxs: list of tensors containing -1, 0 or positive values.
@@ -73,8 +71,7 @@ def __call__(self, matched_idxs):
73
71
74
72
75
73
@torch .jit ._script_if_tracing
76
- def encode_boxes (reference_boxes , proposals , weights ):
77
- # type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
74
+ def encode_boxes (reference_boxes : Tensor , proposals : Tensor , weights : Tensor ) -> Tensor :
78
75
"""
79
76
Encode a set of proposals with respect to some
80
77
reference boxes
@@ -127,8 +124,9 @@ class BoxCoder(object):
127
124
the representation used for training the regressors.
128
125
"""
129
126
130
- def __init__ (self , weights , bbox_xform_clip = math .log (1000.0 / 16 )):
131
- # type: (Tuple[float, float, float, float], float) -> None
127
+ def __init__ (
128
+ self , weights : Tuple [float , float , float , float ], bbox_xform_clip : float = math .log (1000.0 / 16 )
129
+ ) -> None :
132
130
"""
133
131
Args:
134
132
weights (4-element tuple)
@@ -137,15 +135,14 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
137
135
self .weights = weights
138
136
self .bbox_xform_clip = bbox_xform_clip
139
137
140
- def encode (self , reference_boxes , proposals ):
141
- # type: (List[Tensor], List[Tensor]) -> List[Tensor]
138
+ def encode (self , reference_boxes : List [Tensor ], proposals : List [Tensor ]) -> List [Tensor ]:
142
139
boxes_per_image = [len (b ) for b in reference_boxes ]
143
140
reference_boxes = torch .cat (reference_boxes , dim = 0 )
144
141
proposals = torch .cat (proposals , dim = 0 )
145
142
targets = self .encode_single (reference_boxes , proposals )
146
143
return targets .split (boxes_per_image , 0 )
147
144
148
- def encode_single (self , reference_boxes , proposals ) :
145
+ def encode_single (self , reference_boxes : Tensor , proposals : Tensor ) -> Tensor :
149
146
"""
150
147
Encode a set of proposals with respect to some
151
148
reference boxes
@@ -161,8 +158,7 @@ def encode_single(self, reference_boxes, proposals):
161
158
162
159
return targets
163
160
164
- def decode (self , rel_codes , boxes ):
165
- # type: (Tensor, List[Tensor]) -> Tensor
161
+ def decode (self , rel_codes : Tensor , boxes : List [Tensor ]) -> Tensor :
166
162
assert isinstance (boxes , (list , tuple ))
167
163
assert isinstance (rel_codes , torch .Tensor )
168
164
boxes_per_image = [b .size (0 ) for b in boxes ]
@@ -177,7 +173,7 @@ def decode(self, rel_codes, boxes):
177
173
pred_boxes = pred_boxes .reshape (box_sum , - 1 , 4 )
178
174
return pred_boxes
179
175
180
- def decode_single (self , rel_codes , boxes ) :
176
+ def decode_single (self , rel_codes : Tensor , boxes : Tensor ) -> Tensor :
181
177
"""
182
178
From a set of original boxes and encoded relative box offsets,
183
179
get the decoded boxes.
@@ -244,8 +240,7 @@ class Matcher(object):
244
240
"BETWEEN_THRESHOLDS" : int ,
245
241
}
246
242
247
- def __init__ (self , high_threshold , low_threshold , allow_low_quality_matches = False ):
248
- # type: (float, float, bool) -> None
243
+ def __init__ (self , high_threshold : float , low_threshold : float , allow_low_quality_matches : bool = False ) -> None :
249
244
"""
250
245
Args:
251
246
high_threshold (float): quality values greater than or equal to
@@ -266,7 +261,7 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=Fals
266
261
self .low_threshold = low_threshold
267
262
self .allow_low_quality_matches = allow_low_quality_matches
268
263
269
- def __call__ (self , match_quality_matrix ) :
264
+ def __call__ (self , match_quality_matrix : Tensor ) -> Tensor :
270
265
"""
271
266
Args:
272
267
match_quality_matrix (Tensor[float]): an MxN tensor, containing the
@@ -290,7 +285,7 @@ def __call__(self, match_quality_matrix):
290
285
if self .allow_low_quality_matches :
291
286
all_matches = matches .clone ()
292
287
else :
293
- all_matches = None
288
+ all_matches = None # type: ignore[assignment]
294
289
295
290
# Assign candidate matches with low quality to negative (unassigned) values
296
291
below_low_threshold = matched_vals < self .low_threshold
@@ -304,7 +299,7 @@ def __call__(self, match_quality_matrix):
304
299
305
300
return matches
306
301
307
- def set_low_quality_matches_ (self , matches , all_matches , match_quality_matrix ) :
302
+ def set_low_quality_matches_ (self , matches : Tensor , all_matches : Tensor , match_quality_matrix : Tensor ) -> None :
308
303
"""
309
304
Produce additional matches for predictions that have only low-quality matches.
310
305
Specifically, for each ground-truth find the set of predictions that have
@@ -335,10 +330,10 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
335
330
336
331
337
332
class SSDMatcher (Matcher ):
338
- def __init__ (self , threshold ) :
333
+ def __init__ (self , threshold : float ) -> None :
339
334
super ().__init__ (threshold , threshold , allow_low_quality_matches = False )
340
335
341
- def __call__ (self , match_quality_matrix ) :
336
+ def __call__ (self , match_quality_matrix : Tensor ) -> Tensor :
342
337
matches = super ().__call__ (match_quality_matrix )
343
338
344
339
# For each gt, find the prediction with which it has the highest quality
@@ -350,7 +345,7 @@ def __call__(self, match_quality_matrix):
350
345
return matches
351
346
352
347
353
- def overwrite_eps (model , eps ) :
348
+ def overwrite_eps (model : nn . Module , eps : float ) -> None :
354
349
"""
355
350
This method overwrites the default eps values of all the
356
351
FrozenBatchNorm2d layers of the model with the provided value.
@@ -368,7 +363,7 @@ def overwrite_eps(model, eps):
368
363
module .eps = eps
369
364
370
365
371
- def retrieve_out_channels (model , size ) :
366
+ def retrieve_out_channels (model : nn . Module , size : Tuple [ int , int ]) -> List [ int ] :
372
367
"""
373
368
This method retrieves the number of output channels of a specific model.
374
369
0 commit comments