@@ -25,7 +25,7 @@ class GeneralizedRCNN(nn.Module):
25
25
the model
26
26
"""
27
27
28
- def __init__ (self , backbone , rpn , roi_heads , transform ) :
28
+ def __init__ (self , backbone : nn . Module , rpn : nn . Module , roi_heads : nn . Module , transform : nn . Module ) -> None :
29
29
super ().__init__ ()
30
30
_log_api_usage_once (self )
31
31
self .transform = transform
@@ -36,19 +36,26 @@ def __init__(self, backbone, rpn, roi_heads, transform):
36
36
self ._has_warned = False
37
37
38
38
@torch .jit .unused
39
- def eager_outputs (self , losses , detections ):
40
- # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
39
+ def eager_outputs (
40
+ self ,
41
+ losses : Dict [str , Tensor ],
42
+ detections : List [Dict [str , Tensor ]],
43
+ ) -> Union [Dict [str , Tensor ], List [Dict [str , Tensor ]]]:
44
+
41
45
if self .training :
42
46
return losses
43
47
44
48
return detections
45
49
46
- def forward (self , images , targets = None ):
47
- # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
50
+ def forward (
51
+ self ,
52
+ images : List [Tensor ],
53
+ targets : Optional [List [Dict [str , Tensor ]]] = None ,
54
+ ) -> Union [Tuple [Dict [str , Tensor ], List [Dict [str , Tensor ]]], Dict [str , Tensor ], List [Dict [str , Tensor ]]]:
48
55
"""
49
56
Args:
50
57
images (list[Tensor]): images to be processed
51
- targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
58
+ targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
52
59
53
60
Returns:
54
61
result (list[BoxList] or dict[Tensor]): the output from the model.
@@ -97,7 +104,7 @@ def forward(self, images, targets=None):
97
104
features = OrderedDict ([("0" , features )])
98
105
proposals , proposal_losses = self .rpn (images , features , targets )
99
106
detections , detector_losses = self .roi_heads (features , proposals , images .image_sizes , targets )
100
- detections = self .transform .postprocess (detections , images .image_sizes , original_image_sizes )
107
+ detections = self .transform .postprocess (detections , images .image_sizes , original_image_sizes ) # type: ignore[operator]
101
108
102
109
losses = {}
103
110
losses .update (detector_losses )
0 commit comments