1
1
import torch
2
+ from torch import Tensor
2
3
import torch .nn as nn
3
4
from .utils import load_state_dict_from_url
5
+ from typing import Type , Any , Callable , Union , List , Optional
4
6
5
7
6
8
__all__ = ['ResNet' , 'resnet18' , 'resnet34' , 'resnet50' , 'resnet101' ,
21
23
}
22
24
23
25
24
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a bias-free 3x3 convolution whose padding tracks the dilation."""
    # padding == dilation keeps the effective receptive-field border symmetric,
    # so a stride-1 call preserves the spatial size of the input.
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        dilation=dilation,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
28
30
29
31
30
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 (pointwise) convolution."""
    # No padding is needed: a 1x1 kernel never reaches past the input border.
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
33
35
34
36
35
37
class BasicBlock (nn .Module ):
36
- expansion = 1
37
-
38
- def __init__ (self , inplanes , planes , stride = 1 , downsample = None , groups = 1 ,
39
- base_width = 64 , dilation = 1 , norm_layer = None ):
38
+ expansion : int = 1
39
+
40
+ def __init__ (
41
+ self ,
42
+ inplanes : int ,
43
+ planes : int ,
44
+ stride : int = 1 ,
45
+ downsample : Optional [nn .Module ] = None ,
46
+ groups : int = 1 ,
47
+ base_width : int = 64 ,
48
+ dilation : int = 1 ,
49
+ norm_layer : Optional [Callable [..., nn .Module ]] = None
50
+ ) -> None :
40
51
super (BasicBlock , self ).__init__ ()
41
52
if norm_layer is None :
42
53
norm_layer = nn .BatchNorm2d
@@ -53,7 +64,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
53
64
self .downsample = downsample
54
65
self .stride = stride
55
66
56
- def forward (self , x ) :
67
+ def forward (self , x : Tensor ) -> Tensor :
57
68
identity = x
58
69
59
70
out = self .conv1 (x )
@@ -79,10 +90,19 @@ class Bottleneck(nn.Module):
79
90
# This variant is also known as ResNet V1.5 and improves accuracy according to
80
91
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
81
92
82
- expansion = 4
83
-
84
- def __init__ (self , inplanes , planes , stride = 1 , downsample = None , groups = 1 ,
85
- base_width = 64 , dilation = 1 , norm_layer = None ):
93
+ expansion : int = 4
94
+
95
+ def __init__ (
96
+ self ,
97
+ inplanes : int ,
98
+ planes : int ,
99
+ stride : int = 1 ,
100
+ downsample : Optional [nn .Module ] = None ,
101
+ groups : int = 1 ,
102
+ base_width : int = 64 ,
103
+ dilation : int = 1 ,
104
+ norm_layer : Optional [Callable [..., nn .Module ]] = None
105
+ ) -> None :
86
106
super (Bottleneck , self ).__init__ ()
87
107
if norm_layer is None :
88
108
norm_layer = nn .BatchNorm2d
@@ -98,7 +118,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
98
118
self .downsample = downsample
99
119
self .stride = stride
100
120
101
- def forward (self , x ) :
121
+ def forward (self , x : Tensor ) -> Tensor :
102
122
identity = x
103
123
104
124
out = self .conv1 (x )
@@ -123,9 +143,17 @@ def forward(self, x):
123
143
124
144
class ResNet (nn .Module ):
125
145
126
- def __init__ (self , block , layers , num_classes = 1000 , zero_init_residual = False ,
127
- groups = 1 , width_per_group = 64 , replace_stride_with_dilation = None ,
128
- norm_layer = None ):
146
+ def __init__ (
147
+ self ,
148
+ block : Type [Union [BasicBlock , Bottleneck ]],
149
+ layers : List [int ],
150
+ num_classes : int = 1000 ,
151
+ zero_init_residual : bool = False ,
152
+ groups : int = 1 ,
153
+ width_per_group : int = 64 ,
154
+ replace_stride_with_dilation : Optional [List [bool ]] = None ,
155
+ norm_layer : Optional [Callable [..., nn .Module ]] = None
156
+ ) -> None :
129
157
super (ResNet , self ).__init__ ()
130
158
if norm_layer is None :
131
159
norm_layer = nn .BatchNorm2d
@@ -170,11 +198,12 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
170
198
if zero_init_residual :
171
199
for m in self .modules ():
172
200
if isinstance (m , Bottleneck ):
173
- nn .init .constant_ (m .bn3 .weight , 0 )
201
+ nn .init .constant_ (m .bn3 .weight , 0 ) # type: ignore[arg-type]
174
202
elif isinstance (m , BasicBlock ):
175
- nn .init .constant_ (m .bn2 .weight , 0 )
203
+ nn .init .constant_ (m .bn2 .weight , 0 ) # type: ignore[arg-type]
176
204
177
- def _make_layer (self , block , planes , blocks , stride = 1 , dilate = False ):
205
+ def _make_layer (self , block : Type [Union [BasicBlock , Bottleneck ]], planes : int , blocks : int ,
206
+ stride : int = 1 , dilate : bool = False ) -> nn .Sequential :
178
207
norm_layer = self ._norm_layer
179
208
downsample = None
180
209
previous_dilation = self .dilation
@@ -198,7 +227,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
198
227
199
228
return nn .Sequential (* layers )
200
229
201
- def _forward_impl (self , x ) :
230
+ def _forward_impl (self , x : Tensor ) -> Tensor :
202
231
# See note [TorchScript super()]
203
232
x = self .conv1 (x )
204
233
x = self .bn1 (x )
@@ -216,11 +245,18 @@ def _forward_impl(self, x):
216
245
217
246
return x
218
247
219
- def forward (self , x ) :
248
+ def forward (self , x : Tensor ) -> Tensor :
220
249
return self ._forward_impl (x )
221
250
222
251
223
- def _resnet (arch , block , layers , pretrained , progress , ** kwargs ):
252
+ def _resnet (
253
+ arch : str ,
254
+ block : Type [Union [BasicBlock , Bottleneck ]],
255
+ layers : List [int ],
256
+ pretrained : bool ,
257
+ progress : bool ,
258
+ ** kwargs : Any
259
+ ) -> ResNet :
224
260
model = ResNet (block , layers , ** kwargs )
225
261
if pretrained :
226
262
state_dict = load_state_dict_from_url (model_urls [arch ],
@@ -229,7 +265,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
229
265
return model
230
266
231
267
232
- def resnet18 (pretrained = False , progress = True , ** kwargs ) :
268
+ def resnet18 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ResNet :
233
269
r"""ResNet-18 model from
234
270
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
235
271
@@ -241,7 +277,7 @@ def resnet18(pretrained=False, progress=True, **kwargs):
241
277
** kwargs )
242
278
243
279
244
- def resnet34 (pretrained = False , progress = True , ** kwargs ) :
280
+ def resnet34 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ResNet :
245
281
r"""ResNet-34 model from
246
282
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
247
283
@@ -253,7 +289,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
253
289
** kwargs )
254
290
255
291
256
- def resnet50 (pretrained = False , progress = True , ** kwargs ) :
292
+ def resnet50 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ResNet :
257
293
r"""ResNet-50 model from
258
294
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
259
295
@@ -265,7 +301,7 @@ def resnet50(pretrained=False, progress=True, **kwargs):
265
301
** kwargs )
266
302
267
303
268
- def resnet101 (pretrained = False , progress = True , ** kwargs ) :
304
+ def resnet101 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ResNet :
269
305
r"""ResNet-101 model from
270
306
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
271
307
@@ -277,7 +313,7 @@ def resnet101(pretrained=False, progress=True, **kwargs):
277
313
** kwargs )
278
314
279
315
280
- def resnet152 (pretrained = False , progress = True , ** kwargs ) :
316
+ def resnet152 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ResNet :
281
317
r"""ResNet-152 model from
282
318
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
283
319
@@ -289,7 +325,7 @@ def resnet152(pretrained=False, progress=True, **kwargs):
289
325
** kwargs )
290
326
291
327
292
- def resnext50_32x4d (pretrained = False , progress = True , ** kwargs ) :
328
+ def resnext50_32x4d (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ResNet :
293
329
r"""ResNeXt-50 32x4d model from
294
330
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
295
331
@@ -303,7 +339,7 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
303
339
pretrained , progress , ** kwargs )
304
340
305
341
306
- def resnext101_32x8d (pretrained = False , progress = True , ** kwargs ) :
342
+ def resnext101_32x8d (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ResNet :
307
343
r"""ResNeXt-101 32x8d model from
308
344
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
309
345
@@ -317,7 +353,7 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
317
353
pretrained , progress , ** kwargs )
318
354
319
355
320
- def wide_resnet50_2 (pretrained = False , progress = True , ** kwargs ) :
356
+ def wide_resnet50_2 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ResNet :
321
357
r"""Wide ResNet-50-2 model from
322
358
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
323
359
@@ -335,7 +371,7 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
335
371
pretrained , progress , ** kwargs )
336
372
337
373
338
- def wide_resnet101_2 (pretrained = False , progress = True , ** kwargs ) :
374
+ def wide_resnet101_2 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> ResNet :
339
375
r"""Wide ResNet-101-2 model from
340
376
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
341
377
0 commit comments