
Commit 1a64a9a

datumbox, zhiqwang, and jdsgomes authored and committed
[fbsync] add FCOS (#4961)
Summary:

* add fcos
* update fcos
* add giou_loss
* add BoxLinearCoder for FCOS
* add full code for FCOS
* add giou loss
* add fcos
* add __all__
* Fixing lint
* Fixing lint in giou_loss.py
* Add typing annotation to fcos
* Add trained checkpoints
* Use partial to replace lambda
* Minor fixes to docstrings
* Apply ufmt format
* Fixing docstrings
* Fixing jit scripting
* Minor fixes to docstrings
* Fixing jit scripting
* Ignore mypy in fcos
* Fixing trained checkpoints
* Fixing unit-test of jit script
* Fixing docstrings
* Add test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl
* Fixing test_detection_model_trainable_backbone_layers
* Update test_fcos_resnet50_fpn_expect.pkl
* rename stride to box size
* remove TODO and fix some typos
* merge some code for better
* improve the comments
* remove decode and encode of BoxLinearCoder
* remove some unnecessary hints
* use the default value in detectron2
* update doc
* Add unittest for BoxLinearCoder
* Add types in FCOS
* Add docstring for BoxLinearCoder
* Minor fix for the docstring
* update doc
* Update fcos_resnet50_fpn_coco pretrained weights url
* Update torchvision/models/detection/fcos.py
* Update torchvision/models/detection/fcos.py
* Update torchvision/models/detection/fcos.py
* Update torchvision/models/detection/fcos.py
* Add FCOS model documentation
* Fix typo in FCOS documentation
* Add fcos to the prototype builder
* Capitalize COCO_V1
* Fix params of fcos
* fix bug for partial
* Fixing docs indentation
* Fixing docs format in giou_loss
* Adopt Reference for GIoU Loss
* Rename giou_loss to generalized_box_iou_loss
* remove overwrite_eps
* Update AP test values
* Minor fixes for the docs
* Minor fixes for the docs
* Update torchvision/models/detection/fcos.py
* Update torchvision/prototype/models/detection/fcos.py

Reviewed By: jdsgomes, prabhat00155

Differential Revision: D33739385

fbshipit-source-id: 7dab616adfd0c34fe21f0153c1da51f97ef43b95

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Zhiqiang Wang <[email protected]>
Co-authored-by: Zhiqiang Wang <[email protected]>
Co-authored-by: zhiqiang <[email protected]>
Co-authored-by: Joao Gomes <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Joao Gomes <[email protected]>
Parent: 165a270 · Commit: 1a64a9a

File tree: 13 files changed, +979 −0 lines
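The summary above introduces a new `generalized_box_iou_loss` op; its file is among the 13 changed but its hunk is not reproduced on this page. For orientation, here is a minimal sketch of the GIoU loss it refers to, assuming boxes in `(x1, y1, x2, y2)` format. This is an illustration only, not torchvision's actual implementation:

```python
import torch

# GIoU = IoU - |C \ (A ∪ B)| / |C|, where C is the smallest box enclosing
# A and B; the loss is 1 - GIoU. Boxes have shape [N, 4] in xyxy form.
def giou_loss_sketch(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    eps = 1e-7

    # intersection of each aligned pair of boxes
    x1 = torch.max(boxes1[:, 0], boxes2[:, 0])
    y1 = torch.max(boxes1[:, 1], boxes2[:, 1])
    x2 = torch.min(boxes1[:, 2], boxes2[:, 2])
    y2 = torch.min(boxes1[:, 3], boxes2[:, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)

    # union
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    union = area1 + area2 - inter
    iou = inter / (union + eps)

    # smallest enclosing box C
    cx1 = torch.min(boxes1[:, 0], boxes2[:, 0])
    cy1 = torch.min(boxes1[:, 1], boxes2[:, 1])
    cx2 = torch.max(boxes1[:, 2], boxes2[:, 2])
    cy2 = torch.max(boxes1[:, 3], boxes2[:, 3])
    enclose = (cx2 - cx1) * (cy2 - cy1)

    giou = iou - (enclose - union) / (enclose + eps)
    return 1.0 - giou  # per-pair loss; reduce (mean/sum) as needed
```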

docs/source/models.rst

Lines changed: 12 additions & 0 deletions

```diff
@@ -597,6 +597,7 @@ The models subpackage contains definitions for the following model
 architectures for detection:

 - `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_
+- `FCOS <https://arxiv.org/abs/1904.01355>`_
 - `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_
 - `RetinaNet <https://arxiv.org/abs/1708.02002>`_
 - `SSD <https://arxiv.org/abs/1512.02325>`_
@@ -642,6 +643,7 @@ Network box AP mask AP keypoint AP
 Faster R-CNN ResNet-50 FPN              37.0    -   -
 Faster R-CNN MobileNetV3-Large FPN      32.8    -   -
 Faster R-CNN MobileNetV3-Large 320 FPN  22.8    -   -
+FCOS ResNet-50 FPN                      39.2    -   -
 RetinaNet ResNet-50 FPN                 36.4    -   -
 SSD300 VGG16                            25.1    -   -
 SSDlite320 MobileNetV3-Large            21.3    -   -
@@ -702,6 +704,7 @@ Network train time (s / it) test time (s / it)
 Faster R-CNN ResNet-50 FPN              0.2288  0.0590  5.2
 Faster R-CNN MobileNetV3-Large FPN      0.1020  0.0415  1.0
 Faster R-CNN MobileNetV3-Large 320 FPN  0.0978  0.0376  0.6
+FCOS ResNet-50 FPN                      0.1450  0.0539  3.3
 RetinaNet ResNet-50 FPN                 0.2514  0.0939  4.1
 SSD300 VGG16                            0.2093  0.0744  1.5
 SSDlite320 MobileNetV3-Large            0.1773  0.0906  1.5
@@ -721,6 +724,15 @@ Faster R-CNN
     torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn
     torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn

+FCOS
+----
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    torchvision.models.detection.fcos_resnet50_fpn
+

 RetinaNet
 ---------
```
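For readers of the new docs entry, a short usage sketch of the builder it documents (`pretrained=True` reflects the weights API at the time of this commit):

```python
import torch
import torchvision

# Load FCOS with the released COCO checkpoint and run inference on a dummy image.
model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True)
model.eval()

images = [torch.rand(3, 224, 224)]  # a list of CHW tensors; sizes may vary
with torch.no_grad():
    predictions = model(images)

# One dict per input image, with 'boxes', 'labels' and 'scores' tensors.
print(predictions[0]["boxes"].shape)
```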

mypy.ini

Lines changed: 4 additions & 0 deletions

```diff
@@ -70,6 +70,10 @@ ignore_errors = True

 ignore_errors = True

+[mypy-torchvision.models.detection.fcos]
+
+ignore_errors = True
+
 [mypy-torchvision.ops.*]

 ignore_errors = True
```

references/detection/README.md

Lines changed: 7 additions & 0 deletions

````diff
@@ -41,6 +41,13 @@ torchrun --nproc_per_node=8 train.py\
     --lr-steps 16 22 --aspect-ratio-group-factor 3
 ```

+### FCOS ResNet-50 FPN
+```
+torchrun --nproc_per_node=8 train.py\
+    --dataset coco --model fcos_resnet50_fpn --epochs 26\
+    --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp
+```
+
 ### RetinaNet
 ```
 torchrun --nproc_per_node=8 train.py\
````
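As a companion to the recipe above, a sketch of the model/optimizer setup such a run builds. The `num_classes=91` COCO convention and the momentum/weight-decay values are reference-script defaults assumed here, not part of this diff:

```python
import torch
from torchvision.models.detection import fcos_resnet50_fpn

# 91 classes is the COCO convention used by the detection reference scripts (assumed).
model = fcos_resnet50_fpn(num_classes=91, pretrained=False)

params = [p for p in model.parameters() if p.requires_grad]
# --lr 0.01 matches the command above; momentum/weight decay are assumed defaults.
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=1e-4)
```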
test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl: binary file not shown.

test/test_models.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -218,6 +218,7 @@ def _check_input_backprop(model, inputs):
     "retinanet_resnet50_fpn": lambda x: x[1],
     "ssd300_vgg16": lambda x: x[1],
     "ssdlite320_mobilenet_v3_large": lambda x: x[1],
+    "fcos_resnet50_fpn": lambda x: x[1],
 }


@@ -274,6 +275,13 @@ def _check_input_backprop(model, inputs):
         "max_size": 224,
         "input_shape": (3, 224, 224),
     },
+    "fcos_resnet50_fpn": {
+        "num_classes": 2,
+        "score_thresh": 0.05,
+        "min_size": 224,
+        "max_size": 224,
+        "input_shape": (3, 224, 224),
+    },
     "maskrcnn_resnet50_fpn": {
         "num_classes": 10,
         "min_size": 224,
@@ -325,6 +333,10 @@ def _check_input_backprop(model, inputs):
         "max_trainable": 6,
         "n_trn_params_per_layer": [96, 99, 138, 200, 239, 257, 266],
     },
+    "fcos_resnet50_fpn": {
+        "max_trainable": 5,
+        "n_trn_params_per_layer": [54, 64, 83, 96, 106, 107],
+    },
 }

```
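The new `n_trn_params_per_layer` entry encodes how many parameter tensors require gradients as `trainable_backbone_layers` grows from 0 to 5. A sketch of the check the fixture drives (an illustration of the idea, not the test itself):

```python
import torchvision

# For each trainable_backbone_layers setting, count trainable parameter tensors.
expected = [54, 64, 83, 96, 106, 107]
for layers, n_expected in enumerate(expected):
    model = torchvision.models.detection.fcos_resnet50_fpn(
        pretrained=False, trainable_backbone_layers=layers
    )
    n_trainable = len([p for p in model.parameters() if p.requires_grad])
    assert n_trainable == n_expected
```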

test/test_models_detection_utils.py

Lines changed: 13 additions & 0 deletions

```diff
@@ -22,6 +22,19 @@ def test_balanced_positive_negative_sampler(self):
         assert neg[0].sum() == 3
         assert neg[0][0:6].sum() == 3

+    def test_box_linear_coder(self):
+        box_coder = _utils.BoxLinearCoder(normalize_by_size=True)
+        # Generate a random 10x4 boxes tensor, with coordinates < 50.
+        boxes = torch.rand(10, 4) * 50
+        boxes.clamp_(min=1.0)  # tiny boxes cause numerical instability in box regression
+        boxes[:, 2:] += boxes[:, :2]
+
+        proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float()
+
+        rel_codes = box_coder.encode_single(boxes, proposals)
+        pred_boxes = box_coder.decode_single(rel_codes, boxes)
+        assert torch.allclose(proposals, pred_boxes)
+
     @pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)])
     def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params):
         # we know how many initial layers and parameters of the network should
```

torchvision/models/detection/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -4,3 +4,4 @@
 from .retinanet import *
 from .ssd import *
 from .ssdlite import *
+from .fcos import *
```

torchvision/models/detection/_utils.py

Lines changed: 77 additions & 0 deletions

```diff
@@ -217,6 +217,83 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
         return pred_boxes


+class BoxLinearCoder:
+    """
+    The linear box-to-box transform defined in FCOS. The transformation is parameterized
+    by the distance from the center of (square) src box to 4 edges of the target box.
+    """
+
+    def __init__(self, normalize_by_size: bool = True) -> None:
+        """
+        Args:
+            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
+        """
+        self.normalize_by_size = normalize_by_size
+
+    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+
+        Returns:
+            Tensor: the encoded relative box offsets that can be used to
+            decode the boxes.
+        """
+        # get the center of reference_boxes
+        reference_boxes_ctr_x = 0.5 * (reference_boxes[:, 0] + reference_boxes[:, 2])
+        reference_boxes_ctr_y = 0.5 * (reference_boxes[:, 1] + reference_boxes[:, 3])
+
+        # get box regression transformation deltas
+        target_l = reference_boxes_ctr_x - proposals[:, 0]
+        target_t = reference_boxes_ctr_y - proposals[:, 1]
+        target_r = proposals[:, 2] - reference_boxes_ctr_x
+        target_b = proposals[:, 3] - reference_boxes_ctr_y
+
+        targets = torch.stack((target_l, target_t, target_r, target_b), dim=1)
+        if self.normalize_by_size:
+            reference_boxes_w = reference_boxes[:, 2] - reference_boxes[:, 0]
+            reference_boxes_h = reference_boxes[:, 3] - reference_boxes[:, 1]
+            reference_boxes_size = torch.stack(
+                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=1
+            )
+            targets = targets / reference_boxes_size
+
+        return targets
+
+    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+
+        Returns:
+            Tensor: the predicted boxes with the encoded relative box offsets.
+        """
+
+        boxes = boxes.to(rel_codes.dtype)
+
+        ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2])
+        ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3])
+        if self.normalize_by_size:
+            boxes_w = boxes[:, 2] - boxes[:, 0]
+            boxes_h = boxes[:, 3] - boxes[:, 1]
+            boxes_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=1)
+            rel_codes = rel_codes * boxes_size
+
+        pred_boxes1 = ctr_x - rel_codes[:, 0]
+        pred_boxes2 = ctr_y - rel_codes[:, 1]
+        pred_boxes3 = ctr_x + rel_codes[:, 2]
+        pred_boxes4 = ctr_y + rel_codes[:, 3]
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1)
+        return pred_boxes
+
+
 class Matcher:
     """
     This class assigns to each predicted "element" (e.g., a box) a ground-truth
```
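A worked round-trip example of the new coder (a sketch; `BoxLinearCoder` lives in the private `_utils` module, so the import path may change between releases):

```python
import torch
from torchvision.models.detection._utils import BoxLinearCoder

coder = BoxLinearCoder(normalize_by_size=True)

reference = torch.tensor([[10.0, 10.0, 30.0, 30.0]])  # 20x20 box, center (20, 20)
target = torch.tensor([[0.0, 0.0, 101.0, 101.0]])     # box to encode

# Distances from the reference center to the target's edges are (20, 20, 81, 81);
# normalize_by_size divides by (w, h, w, h) = (20, 20, 20, 20) -> (1.0, 1.0, 4.05, 4.05).
codes = coder.encode_single(reference, target)

# Decoding against the same reference box recovers the target exactly.
recovered = coder.decode_single(codes, reference)
assert torch.allclose(recovered, target)
```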

0 commit comments

Comments
 (0)