Skip to content

Commit b9d7344

Browse files
committed
Use P6 and P7 in FPN for retinanet.
1 parent 5218afd commit b9d7344

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchvision/models/detection/retinanet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .rpn import AnchorGenerator
1010
from .transform import GeneralizedRCNNTransform
1111
from .backbone_utils import resnet_fpn_backbone
12+
from ...ops.feature_pyramid_network import LastLevelP6P7
1213

1314

1415
__all__ = [
@@ -362,7 +363,7 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
362363
if pretrained:
363364
# no need to download the backbone if pretrained is set
364365
pretrained_backbone = False
365-
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
366+
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, extra_blocks=LastLevelP6P7(256, 256))
366367
model = RetinaNet(backbone, num_classes, **kwargs)
367368
if pretrained:
368369
state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],

0 commit comments

Comments
 (0)