Skip to content

Commit edb2a1a

Browse files
authored
Merge pull request pytorch#3 from o295/main
Fixing python lint, docstrings and add typing annotations
2 parents 5d75049 + d4c08d3 commit edb2a1a

File tree

1 file changed

+51
-17
lines changed
  • torchvision/models/detection

1 file changed

+51
-17
lines changed

torchvision/models/detection/fcos.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
import warnings
33
from collections import OrderedDict
4+
from functools import partial
45
from typing import Dict, List, Tuple, Optional
56

67
import torch
@@ -26,6 +27,7 @@
2627
class FCOSHead(nn.Module):
2728
"""
2829
A regression and classification head for use in FCOS.
30+
2931
Args:
3032
in_channels (int): number of channels of the input feature
3133
num_anchors (int): number of anchors to be predicted
@@ -117,6 +119,7 @@ def forward(self, x):
117119
class FCOSClassificationHead(nn.Module):
118120
"""
119121
A classification head for use in FCOS.
122+
120123
Args:
121124
in_channels (int): number of channels of the input feature
122125
num_anchors (int): number of anchors to be predicted
@@ -131,7 +134,7 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_pro
131134
self.num_anchors = num_anchors
132135

133136
if norm_layer is None:
134-
norm_layer = lambda channels: nn.GroupNorm(32, channels)
137+
norm_layer = partial(nn.GroupNorm, 32)
135138

136139
conv = []
137140
for _ in range(num_convs):
@@ -149,8 +152,7 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_pro
149152
torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
150153
torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
151154

152-
def forward(self, x):
153-
# type: (List[Tensor]) -> Tensor
155+
def forward(self, x: List[Tensor]) -> Tensor:
154156
all_cls_logits = []
155157

156158
for features in x:
@@ -171,6 +173,7 @@ def forward(self, x):
171173
class FCOSRegressionHead(nn.Module):
172174
"""
173175
A regression head for use in FCOS.
176+
174177
Args:
175178
in_channels (int): number of channels of the input feature
176179
num_anchors (int): number of anchors to be predicted
@@ -181,7 +184,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None):
181184
super().__init__()
182185

183186
if norm_layer is None:
184-
norm_layer = lambda channels: nn.GroupNorm(32, channels)
187+
norm_layer = partial(nn.GroupNorm, 32)
185188

186189
conv = []
187190
for _ in range(num_convs):
@@ -201,8 +204,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None):
201204
torch.nn.init.normal_(layer.weight, std=0.01)
202205
torch.nn.init.zeros_(layer.bias)
203206

204-
def forward(self, x):
205-
# type: (List[Tensor]) -> Tensor
207+
def forward(self, x: List[Tensor]) -> Tensor:
206208
all_bbox_regression = []
207209
all_bbox_ctrness = []
208210

@@ -230,23 +232,29 @@ def forward(self, x):
230232
class FCOS(nn.Module):
231233
"""
232234
Implements FCOS.
235+
233236
The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
234237
image, and should be in 0-1 range. Different images can have different sizes.
238+
235239
The behavior of the model changes depending if it is in training or evaluation mode.
240+
236241
During training, the model expects both the input tensors, as well as a targets (list of dictionary),
237242
containing:
238243
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
239244
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
240245
- labels (Int64Tensor[N]): the class label for each ground-truth box
246+
241247
The model returns a Dict[Tensor] during training, containing the classification and regression
242248
losses.
249+
243250
During inference, the model requires only the input tensors, and returns the post-processed
244251
predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
245252
follows:
246253
- boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
247254
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
248255
- labels (Int64Tensor[N]): the predicted labels for each image
249256
- scores (Tensor[N]): the scores for each prediction
257+
250258
Args:
251259
backbone (nn.Module): the network used to compute the features for the model.
252260
It should contain an out_channels attribute, which indicates the number of output
@@ -272,7 +280,9 @@ class FCOS(nn.Module):
272280
nms_thresh (float): NMS threshold used for postprocessing the detections.
273281
detections_per_img (int): Number of best detections to keep after NMS.
274282
topk_candidates (int): Number of best detections to keep before NMS.
283+
275284
Example:
285+
276286
>>> import torch
277287
>>> import torchvision
278288
>>> from torchvision.models.detection import FCOS
@@ -364,15 +374,23 @@ def __init__(
364374
self._has_warned = False
365375

366376
@torch.jit.unused
367-
def eager_outputs(self, losses, detections):
368-
# type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
377+
def eager_outputs(
378+
self,
379+
losses: Dict[str, Tensor],
380+
detections: List[Dict[str, Tensor]]
381+
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
369382
if self.training:
370383
return losses
371384

372385
return detections
373386

374-
def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level):
375-
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[int]) -> Dict[str, Tensor]
387+
def compute_loss(
388+
self,
389+
targets: List[Dict[str, Tensor]],
390+
head_outputs: Dict[str, Tensor],
391+
anchors: List[Tensor],
392+
num_anchors_per_level: List[int],
393+
) -> Dict[str, Tensor]:
376394
matched_idxs = []
377395
for anchors_per_image, targets_per_image in zip(anchors, targets):
378396
if targets_per_image["boxes"].numel() == 0:
@@ -417,8 +435,12 @@ def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level):
417435

418436
return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs, self.box_coder)
419437

420-
def postprocess_detections(self, head_outputs, anchors, image_shapes):
421-
# type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
438+
def postprocess_detections(
439+
self,
440+
head_outputs: Dict[str, List[Tensor]],
441+
anchors: List[List[Tensor]],
442+
image_shapes: List[Tuple[int, int]]
443+
) -> List[Dict[str, Tensor]]:
422444
class_logits = head_outputs["cls_logits"]
423445
box_regression = head_outputs["bbox_regression"]
424446
box_ctrness = head_outputs["bbox_ctrness"]
@@ -484,12 +506,16 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
484506

485507
return detections
486508

487-
def forward(self, images, targets=None):
488-
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
509+
def forward(
510+
self,
511+
images: List[Tensor],
512+
targets: Optional[List[Dict[str, Tensor]]] = None,
513+
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
489514
"""
490515
Args:
491516
images (list[Tensor]): images to be processed
492517
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
518+
493519
Returns:
494520
result (list[BoxList] or dict[Tensor]): the output from the model.
495521
During training, it returns a dict[Tensor] which contains the losses.
@@ -570,14 +596,15 @@ def forward(self, images, targets=None):
570596

571597
if torch.jit.is_scripting():
572598
if not self._has_warned:
573-
warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
599+
warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting")
574600
self._has_warned = True
575601
return losses, detections
576602
return self.eager_outputs(losses, detections)
577603

578604

579605
model_urls = {
580-
"fcos_resnet50_fpn_coco": "",
606+
"fcos_resnet50_fpn_coco":
607+
"https://github.com/o295/checkpoints/releases/download/coco/fcos_resnet50_fpn_coco-46080c1a.pth",
581608
}
582609

583610

@@ -587,16 +614,20 @@ def fcos_resnet50_fpn(
587614
"""
588615
Constructs a FCOS model with a ResNet-50-FPN backbone.
589616
Reference: `"FCOS: Fully Convolutional One-Stage Object Detection" <https://arxiv.org/abs/1904.01355>`_.
617+
590618
The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
591619
image, and should be in ``0-1`` range. Different images can have different sizes.
620+
592621
The behavior of the model changes depending if it is in training or evaluation mode.
622+
593623
During training, the model expects both the input tensors, as well as a targets (list of dictionary),
594624
containing:
595625
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
596626
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
597627
- labels (``Int64Tensor[N]``): the class label for each ground-truth box
598628
The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
599629
losses.
630+
600631
During inference, the model requires only the input tensors, and returns the post-processed
601632
predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
602633
follows, where ``N`` is the number of detections:
@@ -605,11 +636,14 @@ def fcos_resnet50_fpn(
605636
- labels (``Int64Tensor[N]``): the predicted labels for each detection
606637
- scores (``Tensor[N]``): the scores of each detection
607638
For more details on the output, you may refer to :ref:`instance_seg_output`.
608-
Example::
639+
640+
Example:
641+
609642
>>> model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True)
610643
>>> model.eval()
611644
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
612645
>>> predictions = model(x)
646+
613647
Args:
614648
pretrained (bool): If True, returns a model pre-trained on COCO train2017
615649
progress (bool): If True, displays a progress bar of the download to stderr

0 commit comments

Comments
 (0)