@@ -672,17 +672,22 @@ def forward(self, images, targets=None):
672
672
return self .eager_outputs (losses , detections )
673
673
674
674
675
+ _COMMON_META = {
676
+ "task" : "image_object_detection" ,
677
+ "architecture" : "RetinaNet" ,
678
+ "categories" : _COCO_CATEGORIES ,
679
+ "interpolation" : InterpolationMode .BILINEAR ,
680
+ }
681
+
682
+
675
683
class RetinaNet_ResNet50_FPN_Weights (WeightsEnum ):
676
684
COCO_V1 = Weights (
677
685
url = "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth" ,
678
686
transforms = ObjectDetection ,
679
687
meta = {
680
- "task" : "image_object_detection" ,
681
- "architecture" : "RetinaNet" ,
688
+ ** _COMMON_META ,
682
689
"publication_year" : 2017 ,
683
690
"num_params" : 34014999 ,
684
- "categories" : _COCO_CATEGORIES ,
685
- "interpolation" : InterpolationMode .BILINEAR ,
686
691
"recipe" : "https://github.com/pytorch/vision/tree/main/references/detection#retinanet" ,
687
692
"map" : 36.4 ,
688
693
},
@@ -691,7 +696,18 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
691
696
692
697
693
698
class RetinaNet_ResNet50_FPN_V2_Weights (WeightsEnum ):
694
- pass
699
+ COCO_V1 = Weights (
700
+ url = "https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth" ,
701
+ transforms = ObjectDetection ,
702
+ meta = {
703
+ ** _COMMON_META ,
704
+ "publication_year" : 2019 ,
705
+ "num_params" : 38198935 ,
706
+ "recipe" : "https://github.com/pytorch/vision/pull/5756" ,
707
+ "map" : 41.5 ,
708
+ },
709
+ )
710
+ DEFAULT = COCO_V1
695
711
696
712
697
713
@handle_legacy_interface (
0 commit comments