1
+ from torch import nn
2
+ from typing import Any , Optional
1
3
from .._utils import IntermediateLayerGetter
2
4
from ..._internally_replaced_utils import load_state_dict_from_url
3
5
from .. import mobilenetv3
22
24
}
23
25
24
26
25
- def _segm_model (name , backbone_name , num_classes , aux , pretrained_backbone = True ):
27
+ def _segm_model (
28
+ name : str ,
29
+ backbone_name : str ,
30
+ num_classes : int ,
31
+ aux : Optional [bool ],
32
+ pretrained_backbone : bool = True
33
+ ) -> nn .Module :
26
34
if 'resnet' in backbone_name :
27
35
backbone = resnet .__dict__ [backbone_name ](
28
36
pretrained = pretrained_backbone ,
@@ -66,7 +74,15 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
66
74
return model
67
75
68
76
69
- def _load_model (arch_type , backbone , pretrained , progress , num_classes , aux_loss , ** kwargs ):
77
+ def _load_model (
78
+ arch_type : str ,
79
+ backbone : str ,
80
+ pretrained : bool ,
81
+ progress : bool ,
82
+ num_classes : int ,
83
+ aux_loss : Optional [bool ],
84
+ ** kwargs : Any
85
+ ) -> nn .Module :
70
86
if pretrained :
71
87
aux_loss = True
72
88
kwargs ["pretrained_backbone" ] = False
@@ -76,7 +92,7 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss
76
92
return model
77
93
78
94
79
- def _load_weights (model , arch_type , backbone , progress ) :
95
+ def _load_weights (model : nn . Module , arch_type : str , backbone : str , progress : bool ) -> None :
80
96
arch = arch_type + '_' + backbone + '_coco'
81
97
model_url = model_urls .get (arch , None )
82
98
if model_url is None :
@@ -86,7 +102,7 @@ def _load_weights(model, arch_type, backbone, progress):
86
102
model .load_state_dict (state_dict )
87
103
88
104
89
- def _segm_lraspp_mobilenetv3 (backbone_name , num_classes , pretrained_backbone = True ):
105
+ def _segm_lraspp_mobilenetv3 (backbone_name : str , num_classes : int , pretrained_backbone : bool = True ) -> LRASPP :
90
106
backbone = mobilenetv3 .__dict__ [backbone_name ](pretrained = pretrained_backbone , dilated = True ).features
91
107
92
108
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -103,8 +119,13 @@ def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=Tru
103
119
return model
104
120
105
121
106
- def fcn_resnet50 (pretrained = False , progress = True ,
107
- num_classes = 21 , aux_loss = None , ** kwargs ):
122
+ def fcn_resnet50 (
123
+ pretrained : bool = False ,
124
+ progress : bool = True ,
125
+ num_classes : int = 21 ,
126
+ aux_loss : Optional [bool ] = None ,
127
+ ** kwargs : Any
128
+ ) -> nn .Module :
108
129
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
109
130
110
131
Args:
@@ -117,8 +138,13 @@ def fcn_resnet50(pretrained=False, progress=True,
117
138
return _load_model ('fcn' , 'resnet50' , pretrained , progress , num_classes , aux_loss , ** kwargs )
118
139
119
140
120
- def fcn_resnet101 (pretrained = False , progress = True ,
121
- num_classes = 21 , aux_loss = None , ** kwargs ):
141
+ def fcn_resnet101 (
142
+ pretrained : bool = False ,
143
+ progress : bool = True ,
144
+ num_classes : int = 21 ,
145
+ aux_loss : Optional [bool ] = None ,
146
+ ** kwargs : Any
147
+ ) -> nn .Module :
122
148
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
123
149
124
150
Args:
@@ -131,8 +157,13 @@ def fcn_resnet101(pretrained=False, progress=True,
131
157
return _load_model ('fcn' , 'resnet101' , pretrained , progress , num_classes , aux_loss , ** kwargs )
132
158
133
159
134
- def deeplabv3_resnet50 (pretrained = False , progress = True ,
135
- num_classes = 21 , aux_loss = None , ** kwargs ):
160
+ def deeplabv3_resnet50 (
161
+ pretrained : bool = False ,
162
+ progress : bool = True ,
163
+ num_classes : int = 21 ,
164
+ aux_loss : Optional [bool ] = None ,
165
+ ** kwargs : Any
166
+ ) -> nn .Module :
136
167
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
137
168
138
169
Args:
@@ -145,8 +176,13 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
145
176
return _load_model ('deeplabv3' , 'resnet50' , pretrained , progress , num_classes , aux_loss , ** kwargs )
146
177
147
178
148
- def deeplabv3_resnet101 (pretrained = False , progress = True ,
149
- num_classes = 21 , aux_loss = None , ** kwargs ):
179
+ def deeplabv3_resnet101 (
180
+ pretrained : bool = False ,
181
+ progress : bool = True ,
182
+ num_classes : int = 21 ,
183
+ aux_loss : Optional [bool ] = None ,
184
+ ** kwargs : Any
185
+ ) -> nn .Module :
150
186
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.
151
187
152
188
Args:
@@ -159,8 +195,13 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
159
195
return _load_model ('deeplabv3' , 'resnet101' , pretrained , progress , num_classes , aux_loss , ** kwargs )
160
196
161
197
162
- def deeplabv3_mobilenet_v3_large (pretrained = False , progress = True ,
163
- num_classes = 21 , aux_loss = None , ** kwargs ):
198
+ def deeplabv3_mobilenet_v3_large (
199
+ pretrained : bool = False ,
200
+ progress : bool = True ,
201
+ num_classes : int = 21 ,
202
+ aux_loss : Optional [bool ] = None ,
203
+ ** kwargs : Any
204
+ ) -> nn .Module :
164
205
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
165
206
166
207
Args:
@@ -173,7 +214,12 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
173
214
return _load_model ('deeplabv3' , 'mobilenet_v3_large' , pretrained , progress , num_classes , aux_loss , ** kwargs )
174
215
175
216
176
- def lraspp_mobilenet_v3_large (pretrained = False , progress = True , num_classes = 21 , ** kwargs ):
217
+ def lraspp_mobilenet_v3_large (
218
+ pretrained : bool = False ,
219
+ progress : bool = True ,
220
+ num_classes : int = 21 ,
221
+ ** kwargs : Any
222
+ ) -> nn .Module :
177
223
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
178
224
179
225
Args:
0 commit comments