
Commit 5bb81c8

Hans Gaiser (hgaiser) authored and fmassa committed
RetinaNet object detection (take 2) (#2784)
* Add rough implementation of RetinaNet.
* Move AnchorGenerator to a separate file.
* Move box similarity to Matcher.
* Expose extra blocks in FPN.
* Expose retinanet in __init__.py.
* Use P6 and P7 in FPN for retinanet.
* Use parameters from retinanet for anchor generation.
* General fixes for retinanet model.
* Implement loss for retinanet heads.
* Output reshaped outputs from retinanet heads.
* Add postprocessing of detections.
* Small fixes.
* Remove unused argument.
* Remove python2 invocation of super.
* Add postprocessing for additional outputs.
* Add missing import of ImageList.
* Remove redundant import.
* Simplify class correction.
* Fix pylint warnings.
* Remove the label adjustment for background class.
* Set default score threshold to 0.05.
* Add weight initialization for regression layer.
* Allow training on images with no annotations.
* Use smooth_l1_loss with beta value.
* Add more typehints for TorchScript conversions.
* Fix linting issues.
* Fix type hints in postprocess_detections.
* Fix type annotations for TorchScript.
* Fix inconsistency with matched_idxs.
* Add retinanet model test.
* Add missing JIT annotations.
* Remove redundant model construction; make tests pass.
* Fix bugs during training on newer PyTorch and unused params in DDP (needs cleanup, and support for images with no annotations needs to be added back).
* Clean up resnet_fpn_backbone.
* Use L1 loss for regression (gives 1 mAP improvement over smooth L1).
* Disable support for images with no annotations (need to fix distributed first).
* Fix retinanet tests (need to deduplicate those box checks).
* Fix lint.
* Add pretrained model.
* Add training info for retinanet.

Co-authored-by: Hans Gaiser <[email protected]>
Co-authored-by: Hans Gaiser <[email protected]>
Co-authored-by: Hans Gaiser <[email protected]>
1 parent 42e7f1f · commit 5bb81c8

File tree

13 files changed: +884 −169 lines

docs/source/models.rst

Lines changed: 8 additions & 0 deletions
@@ -350,6 +350,7 @@ the instances set of COCO train2017 and evaluated on COCO val2017.
 Network                          box AP  mask AP  keypoint AP
 ================================ ======= ======== ===========
 Faster R-CNN ResNet-50 FPN       37.0    -        -
+RetinaNet ResNet-50 FPN          36.4    -        -
 Mask R-CNN ResNet-50 FPN         37.9    34.6     -
 ================================ ======= ======== ===========

@@ -405,6 +406,7 @@ precision-recall.
 Network                        train time (s / it) test time (s / it) memory (GB)
 ============================== =================== ================== ===========
 Faster R-CNN ResNet-50 FPN     0.2288              0.0590             5.2
+RetinaNet ResNet-50 FPN        0.2514              0.0939             4.1
 Mask R-CNN ResNet-50 FPN       0.2728              0.0903             5.4
 Keypoint R-CNN ResNet-50 FPN   0.3789              0.1242             6.8
 ============================== =================== ================== ===========

@@ -416,6 +418,12 @@ Faster R-CNN
 .. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn


+RetinaNet
+------------
+
+.. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn
+
+
 Mask R-CNN
 ----------
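
For readers of the updated docs, here is a minimal usage sketch of the newly documented entry point. It assumes the standard torchvision detection inference API (a list of CHW float tensors in, a list of per-image dicts with boxes, scores, and labels out), which is not itself shown in this diff:

```python
import torch
import torchvision

# Load the RetinaNet ResNet-50 FPN model documented above.
# pretrained=True downloads the COCO-trained weights added in this commit.
model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
model.eval()

# Inference takes a list of 3xHxW float tensors with values in [0, 1].
images = [torch.rand(3, 480, 640)]
with torch.no_grad():
    predictions = model(images)

# Assumed output format (standard torchvision detection API):
# one dict per image with 'boxes' (Nx4), 'labels' (N,) and 'scores' (N,).
print(predictions[0]["boxes"].shape, predictions[0]["scores"].shape)
```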

references/detection/README.md

Lines changed: 7 additions & 0 deletions
@@ -27,6 +27,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --lr-steps 16 22 --aspect-ratio-group-factor 3
 ```

+### RetinaNet
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
+    --dataset coco --model retinanet_resnet50_fpn --epochs 26\
+    --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
+```
+

 ### Mask R-CNN
 ```
Binary file not shown.
torchvision/models/detection/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 from .faster_rcnn import *
 from .mask_rcnn import *
 from .keypoint_rcnn import *
+from .retinanet import *
torchvision/models/detection/anchor_utils.py (new file)

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn

from torch.jit.annotations import List, Optional, Dict
from .image_list import ImageList


class AnchorGenerator(nn.Module):
    """
    Module that generates anchors for a set of feature maps and
    image sizes.

    The module supports computing anchors at multiple sizes and aspect ratios
    per feature map. This module assumes aspect ratio = height / width for
    each anchor.

    sizes and aspect_ratios should have the same number of elements, and they
    should correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
    and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
    per spatial location for feature map i.

    Arguments:
        sizes (Tuple[Tuple[int]]):
        aspect_ratios (Tuple[Tuple[float]]):
    """

    __annotations__ = {
        "cell_anchors": Optional[List[torch.Tensor]],
        "_cache": Dict[str, List[torch.Tensor]]
    }

    def __init__(
        self,
        sizes=((128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    ):
        super(AnchorGenerator, self).__init__()

        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)

        assert len(sizes) == len(aspect_ratios)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        self.cell_anchors = None
        self._cache = {}

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
    def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
        # type: (List[int], List[float], int, Device) -> Tensor  # noqa: F821
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype, device):
        # type: (int, Device) -> None  # noqa: F821
        if self.cell_anchors is not None:
            cell_anchors = self.cell_anchors
            assert cell_anchors is not None
            # suppose that all anchors have the same device
            # which is a valid assumption in the current state of the codebase
            if cell_anchors[0].device == device:
                return

        cell_anchors = [
            self.generate_anchors(
                sizes,
                aspect_ratios,
                dtype,
                device
            )
            for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
        ]
        self.cell_anchors = cell_anchors

    def num_anchors_per_location(self):
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
    def grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        anchors = []
        cell_anchors = self.cell_anchors
        assert cell_anchors is not None
        assert len(grid_sizes) == len(strides) == len(cell_anchors)

        for size, stride, base_anchors in zip(
            grid_sizes, strides, cell_anchors
        ):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # For output anchor, compute [x_center, y_center, x_center, y_center]
            shifts_x = torch.arange(
                0, grid_width, dtype=torch.float32, device=device
            ) * stride_width
            shifts_y = torch.arange(
                0, grid_height, dtype=torch.float32, device=device
            ) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
            anchors.append(
                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
            )

        return anchors

    def cached_grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        key = str(grid_sizes) + str(strides)
        if key in self._cache:
            return self._cache[key]
        anchors = self.grid_anchors(grid_sizes, strides)
        self._cache[key] = anchors
        return anchors

    def forward(self, image_list, feature_maps):
        # type: (ImageList, List[Tensor]) -> List[Tensor]
        grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
                    torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
        anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
        for i, (image_height, image_width) in enumerate(image_list.image_sizes):
            anchors_in_image = []
            for anchors_per_feature_map in anchors_over_all_feature_maps:
                anchors_in_image.append(anchors_per_feature_map)
            anchors.append(anchors_in_image)
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        # Clear the cache so it does not leak memory across calls.
        self._cache.clear()
        return anchors
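
To make the behaviour of the relocated module concrete, here is a minimal usage sketch. The sizes, aspect ratios, and feature-map shapes are illustrative values, not taken from the commit:

```python
import torch
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList

# One (sizes, aspect_ratios) pair per feature map; here two feature maps.
anchor_gen = AnchorGenerator(
    sizes=((32,), (64,)),
    aspect_ratios=((0.5, 1.0, 2.0),) * 2,
)

# Two batched 128x128 images plus their per-image sizes before padding.
images = ImageList(torch.rand(2, 3, 128, 128), [(128, 128), (128, 128)])
# Feature maps at strides 8 and 16 relative to the input.
features = [torch.rand(2, 256, 16, 16), torch.rand(2, 256, 8, 8)]

anchors = anchor_gen(images, features)
# One tensor of anchors per image, concatenated over all feature maps:
# (16*16 + 8*8) locations * 3 anchors each = 960 anchors of shape (960, 4).
print(len(anchors), anchors[0].shape)
```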

torchvision/models/detection/backbone_utils.py

Lines changed: 23 additions & 11 deletions
@@ -25,13 +25,17 @@ class BackboneWithFPN(nn.Module):
     Attributes:
         out_channels (int): the number of channels in the FPN
     """
-    def __init__(self, backbone, return_layers, in_channels_list, out_channels):
+    def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
         super(BackboneWithFPN, self).__init__()
+
+        if extra_blocks is None:
+            extra_blocks = LastLevelMaxPool()
+
         self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
         self.fpn = FeaturePyramidNetwork(
             in_channels_list=in_channels_list,
             out_channels=out_channels,
-            extra_blocks=LastLevelMaxPool(),
+            extra_blocks=extra_blocks,
         )
         self.out_channels = out_channels

@@ -41,7 +45,14 @@ def forward(self, x):
         return x


-def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3):
+def resnet_fpn_backbone(
+    backbone_name,
+    pretrained,
+    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers=3,
+    returned_layers=None,
+    extra_blocks=None
+):
     """
     Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.

@@ -82,14 +93,15 @@ def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.Frozen
         if all([not name.startswith(layer) for layer in layers_to_train]):
             parameter.requires_grad_(False)

-    return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
+    if extra_blocks is None:
+        extra_blocks = LastLevelMaxPool()
+
+    if returned_layers is None:
+        returned_layers = [1, 2, 3, 4]
+    assert min(returned_layers) > 0 and max(returned_layers) < 5
+    return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)}

     in_channels_stage2 = backbone.inplanes // 8
-    in_channels_list = [
-        in_channels_stage2,
-        in_channels_stage2 * 2,
-        in_channels_stage2 * 4,
-        in_channels_stage2 * 8,
-    ]
+    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
     out_channels = 256
-    return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels)
+    return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
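
The new returned_layers and extra_blocks arguments are what let RetinaNet reuse this helper. Below is a hedged sketch of how they might be combined; the LastLevelP6P7 block, its import path, and its (in_channels, out_channels) signature are assumptions inferred from the commit message ("Use P6 and P7 in FPN for retinanet") and are not shown in this diff:

```python
import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
# Assumed location of the extra block that appends P6/P7 on top of the FPN.
from torchvision.ops.feature_pyramid_network import LastLevelP6P7

# Default behaviour is unchanged: layers 1-4 plus a max-pool extra level.
default_backbone = resnet_fpn_backbone("resnet50", pretrained=False)

# RetinaNet-style backbone: skip layer1 (P2) and add P6/P7 above P5.
retinanet_backbone = resnet_fpn_backbone(
    "resnet50",
    pretrained=False,
    returned_layers=[2, 3, 4],
    extra_blocks=LastLevelP6P7(256, 256),
)

features = retinanet_backbone(torch.rand(1, 3, 256, 256))
# Expect five pyramid levels (P3-P7), each with 256 channels.
print([(name, feat.shape) for name, feat in features.items()])
```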

torchvision/models/detection/faster_rcnn.py

Lines changed: 2 additions & 1 deletion
@@ -9,8 +9,9 @@

 from ..utils import load_state_dict_from_url

+from .anchor_utils import AnchorGenerator
 from .generalized_rcnn import GeneralizedRCNN
-from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
+from .rpn import RPNHead, RegionProposalNetwork
 from .roi_heads import RoIHeads
 from .transform import GeneralizedRCNNTransform
 from .backbone_utils import resnet_fpn_backbone

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ class KeypointRCNN(FasterRCNN):
         >>> import torch
         >>> import torchvision
         >>> from torchvision.models.detection import KeypointRCNN
-        >>> from torchvision.models.detection.rpn import AnchorGenerator
+        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
         >>>
         >>> # load a pre-trained model for classification and return
         >>> # only the features

torchvision/models/detection/mask_rcnn.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ class MaskRCNN(FasterRCNN):
         >>> import torch
         >>> import torchvision
         >>> from torchvision.models.detection import MaskRCNN
-        >>> from torchvision.models.detection.rpn import AnchorGenerator
+        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
         >>>
         >>> # load a pre-trained model for classification and return
         >>> # only the features
