Commit 2ce6b18

Added annotation typing to resnet (#2863)
* style: Added annotation typing for resnet
* fix: Fixed annotation to pass classes
* fix: Fixed annotation typing
* fix: Fixed annotation typing
* fix: Fixed annotation typing for resnet
* refactor: Removed un-necessary import
* fix: Fixed constructor typing
* style: Added black formatting on _resnet
1 parent 65591f1 commit 2ce6b18

File tree

1 file changed (+66, -30 lines)

torchvision/models/resnet.py

Lines changed: 66 additions & 30 deletions
@@ -1,6 +1,8 @@
 import torch
+from torch import Tensor
 import torch.nn as nn
 from .utils import load_state_dict_from_url
+from typing import Type, Any, Callable, Union, List, Optional


 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
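
The two new imports carry the whole change: `Tensor` for the forward signatures and the `typing` constructs for the rest. A rough sketch of what `Callable[..., nn.Module]` admits (names here are illustrative, not part of the commit):

    from typing import Callable, Optional
    import torch.nn as nn

    # Any callable that produces a module satisfies the annotation,
    # including the module class itself:
    norm_factory: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d
    norm_factory = None  # also accepted, since the annotation is Optional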
@@ -21,22 +23,31 @@
 }


-def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
     """3x3 convolution with padding"""
     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                      padding=dilation, groups=groups, bias=False, dilation=dilation)


-def conv1x1(in_planes, out_planes, stride=1):
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
     """1x1 convolution"""
     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


 class BasicBlock(nn.Module):
-    expansion = 1
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
-                 base_width=64, dilation=1, norm_layer=None):
+    expansion: int = 1
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
         super(BasicBlock, self).__init__()
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
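
Since `norm_layer` is now `Optional[Callable[..., nn.Module]]`, a checker can validate custom normalization factories. A minimal sketch, assuming (as in the block body) that the factory is invoked with the channel count:

    from functools import partial
    import torch.nn as nn

    # GroupNorm with 32 groups, closed over so it only needs the channel count:
    group_norm = partial(nn.GroupNorm, 32)
    block = BasicBlock(inplanes=64, planes=64, norm_layer=group_norm)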
@@ -53,7 +64,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
         self.downsample = downsample
         self.stride = stride

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         identity = x

         out = self.conv1(x)
@@ -79,10 +90,19 @@ class Bottleneck(nn.Module):
     # This variant is also known as ResNet V1.5 and improves accuracy according to
     # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

-    expansion = 4
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
-                 base_width=64, dilation=1, norm_layer=None):
+    expansion: int = 4
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
         super(Bottleneck, self).__init__()
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
@@ -98,7 +118,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
         self.downsample = downsample
         self.stride = stride

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         identity = x

         out = self.conv1(x)
@@ -123,9 +143,17 @@ def forward(self, x):

 class ResNet(nn.Module):

-    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
-                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
-                 norm_layer=None):
+    def __init__(
+        self,
+        block: Type[Union[BasicBlock, Bottleneck]],
+        layers: List[int],
+        num_classes: int = 1000,
+        zero_init_residual: bool = False,
+        groups: int = 1,
+        width_per_group: int = 64,
+        replace_stride_with_dilation: Optional[List[bool]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
         super(ResNet, self).__init__()
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
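
`Type[Union[BasicBlock, Bottleneck]]` says the argument is one of the two block classes themselves, not an instance; `ResNet` instantiates the class internally. A sketch of a call that satisfies the new signature, using the standard ResNet-50 layer counts:

    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=10)
    # An instance would now be flagged by a checker:
    # ResNet(Bottleneck(64, 64), [3, 4, 6, 3])  # error: expected Type[...]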
@@ -170,11 +198,12 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
         if zero_init_residual:
             for m in self.modules():
                 if isinstance(m, Bottleneck):
-                    nn.init.constant_(m.bn3.weight, 0)
+                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                 elif isinstance(m, BasicBlock):
-                    nn.init.constant_(m.bn2.weight, 0)
+                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

-    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
+                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
         norm_layer = self._norm_layer
         downsample = None
         previous_dilation = self.dilation
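
The `# type: ignore[arg-type]` comments are scoped suppressions: presumably the static type of the batch-norm `weight` attribute is wider than the `Tensor` that `nn.init.constant_` expects, and only that one error code is silenced. The pattern in isolation (illustrative, not from the commit):

    import torch.nn as nn

    bn = nn.BatchNorm2d(512)
    # Silences only mypy's arg-type complaint; other errors still surface.
    nn.init.constant_(bn.weight, 0)  # type: ignore[arg-type]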
@@ -198,7 +227,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):

         return nn.Sequential(*layers)

-    def _forward_impl(self, x):
+    def _forward_impl(self, x: Tensor) -> Tensor:
         # See note [TorchScript super()]
         x = self.conv1(x)
         x = self.bn1(x)
@@ -216,11 +245,18 @@ def _forward_impl(self, x):

         return x

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self._forward_impl(x)


-def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+def _resnet(
+    arch: str,
+    block: Type[Union[BasicBlock, Bottleneck]],
+    layers: List[int],
+    pretrained: bool,
+    progress: bool,
+    **kwargs: Any
+) -> ResNet:
     model = ResNet(block, layers, **kwargs)
     if pretrained:
         state_dict = load_state_dict_from_url(model_urls[arch],
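
With the factory annotated, each public constructor below is a thin typed wrapper around it. As a sketch, the ResNet-18 configuration routed through `_resnet` directly:

    # BasicBlock with [2, 2, 2, 2] stages is the ResNet-18 configuration.
    model = _resnet('resnet18', BasicBlock, [2, 2, 2, 2],
                    pretrained=False, progress=True)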
@@ -229,7 +265,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
     return model


-def resnet18(pretrained=False, progress=True, **kwargs):
+def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
     r"""ResNet-18 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

@@ -241,7 +277,7 @@ def resnet18(pretrained=False, progress=True, **kwargs):
                    **kwargs)


-def resnet34(pretrained=False, progress=True, **kwargs):
+def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
     r"""ResNet-34 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

@@ -253,7 +289,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
                    **kwargs)


-def resnet50(pretrained=False, progress=True, **kwargs):
+def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
     r"""ResNet-50 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

@@ -265,7 +301,7 @@ def resnet50(pretrained=False, progress=True, **kwargs):
                    **kwargs)


-def resnet101(pretrained=False, progress=True, **kwargs):
+def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
     r"""ResNet-101 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

@@ -277,7 +313,7 @@ def resnet101(pretrained=False, progress=True, **kwargs):
                    **kwargs)


-def resnet152(pretrained=False, progress=True, **kwargs):
+def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
     r"""ResNet-152 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

@@ -289,7 +325,7 @@ def resnet152(pretrained=False, progress=True, **kwargs):
                    **kwargs)


-def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
     r"""ResNeXt-50 32x4d model from
     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

@@ -303,7 +339,7 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
                    pretrained, progress, **kwargs)


-def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
     r"""ResNeXt-101 32x8d model from
     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

@@ -317,7 +353,7 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
                    pretrained, progress, **kwargs)


-def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
     r"""Wide ResNet-50-2 model from
     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

@@ -335,7 +371,7 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
                    pretrained, progress, **kwargs)


-def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
     r"""Wide ResNet-101-2 model from
     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

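
Taken together, every public entry point now advertises `-> ResNet` and `forward` is typed `Tensor -> Tensor`, so downstream code gets these types without stubs. A short usage sketch (the output size assumes the default 1000-class head):

    import torch
    from torchvision.models import resnet18

    model = resnet18(pretrained=False)        # annotated to return ResNet
    out = model(torch.randn(1, 3, 224, 224))  # forward: Tensor -> Tensor, shape [1, 1000]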
