
Commit 185be3a

Added typing annotations to models/segmentation (#4227)
* style: Added typing annotations to segmentation/_utils
* style: Added typing annotations to segmentation/segmentation
* style: Added typing annotations to remaining segmentation models
* style: Fixed typing of DeepLab
* style: Fixed typing
* fix: Fixed typing annotations & default values
* Fixing python_type_check
1 parent 7947fc8 commit 185be3a

File tree

5 files changed: +97 -31 lines changed

torchvision/models/segmentation/_utils.py
torchvision/models/segmentation/deeplabv3.py
torchvision/models/segmentation/fcn.py
torchvision/models/segmentation/lraspp.py
torchvision/models/segmentation/segmentation.py

torchvision/models/segmentation/_utils.py

Lines changed: 9 additions & 3 deletions
@@ -1,19 +1,25 @@
 from collections import OrderedDict
+from typing import Optional, Dict
 
-from torch import nn
+from torch import nn, Tensor
 from torch.nn import functional as F
 
 
 class _SimpleSegmentationModel(nn.Module):
     __constants__ = ['aux_classifier']
 
-    def __init__(self, backbone, classifier, aux_classifier=None):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        classifier: nn.Module,
+        aux_classifier: Optional[nn.Module] = None
+    ) -> None:
         super(_SimpleSegmentationModel, self).__init__()
         self.backbone = backbone
         self.classifier = classifier
         self.aux_classifier = aux_classifier
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
         input_shape = x.shape[-2:]
         # contract: features is a dict of tensors
         features = self.backbone(x)
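
A quick usage sketch (not part of the commit) of the new contract: _SimpleSegmentationModel subclasses such as FCN now take typed nn.Module components and return Dict[str, Tensor]. The ToyBackbone below is hypothetical and only illustrates the "dict of tensors" backbone contract.

from collections import OrderedDict
from typing import Dict

import torch
from torch import nn, Tensor
from torchvision.models.segmentation.fcn import FCN, FCNHead


class ToyBackbone(nn.Module):
    # Hypothetical stand-in backbone that honours the "dict of tensors" contract
    # assumed by _SimpleSegmentationModel.forward.
    def __init__(self) -> None:
        super().__init__()
        self.body = nn.Conv2d(3, 64, 3, padding=1)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        return OrderedDict([("out", self.body(x))])


model = FCN(backbone=ToyBackbone(), classifier=FCNHead(64, 21)).eval()
with torch.no_grad():
    out: Dict[str, Tensor] = model(torch.rand(1, 3, 64, 64))
print(out["out"].shape)  # torch.Size([1, 21, 64, 64]), upsampled to the input size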

torchvision/models/segmentation/deeplabv3.py

Lines changed: 10 additions & 9 deletions
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
+from typing import List
 
 from ._utils import _SimpleSegmentationModel
 
@@ -27,7 +28,7 @@ class DeepLabV3(_SimpleSegmentationModel):
 
 
 class DeepLabHead(nn.Sequential):
-    def __init__(self, in_channels, num_classes):
+    def __init__(self, in_channels: int, num_classes: int) -> None:
         super(DeepLabHead, self).__init__(
             ASPP(in_channels, [12, 24, 36]),
             nn.Conv2d(256, 256, 3, padding=1, bias=False),
@@ -38,7 +39,7 @@ def __init__(self, in_channels, num_classes):
 
 
 class ASPPConv(nn.Sequential):
-    def __init__(self, in_channels, out_channels, dilation):
+    def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
         modules = [
             nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
             nn.BatchNorm2d(out_channels),
@@ -48,22 +49,22 @@ def __init__(self, in_channels, out_channels, dilation):
 
 
 class ASPPPooling(nn.Sequential):
-    def __init__(self, in_channels, out_channels):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
         super(ASPPPooling, self).__init__(
             nn.AdaptiveAvgPool2d(1),
             nn.Conv2d(in_channels, out_channels, 1, bias=False),
             nn.BatchNorm2d(out_channels),
             nn.ReLU())
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         size = x.shape[-2:]
         for mod in self:
             x = mod(x)
         return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
 
 
 class ASPP(nn.Module):
-    def __init__(self, in_channels, atrous_rates, out_channels=256):
+    def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
         super(ASPP, self).__init__()
         modules = []
         modules.append(nn.Sequential(
@@ -85,9 +86,9 @@ def __init__(self, in_channels, atrous_rates, out_channels=256):
             nn.ReLU(),
             nn.Dropout(0.5))
 
-    def forward(self, x):
-        res = []
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        _res = []
         for conv in self.convs:
-            res.append(conv(x))
-        res = torch.cat(res, dim=1)
+            _res.append(conv(x))
+        res = torch.cat(_res, dim=1)
         return self.project(res)
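
For reference, a minimal sketch (not from the diff) exercising the newly annotated DeepLabHead and ASPP constructors; atrous_rates is now typed as List[int]. The 2048-channel feature map is an assumed ResNet-50-style input; eval mode is used so the 1x1 pooled BatchNorm branch accepts a batch of one.

import torch
from torchvision.models.segmentation.deeplabv3 import ASPP, DeepLabHead

# Assumed input: a ResNet-50-style 'out' feature map.
features = torch.rand(1, 2048, 32, 32)
aspp = ASPP(in_channels=2048, atrous_rates=[12, 24, 36]).eval()
head = DeepLabHead(in_channels=2048, num_classes=21).eval()

with torch.no_grad():
    print(aspp(features).shape)  # torch.Size([1, 256, 32, 32])
    print(head(features).shape)  # torch.Size([1, 21, 32, 32])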

torchvision/models/segmentation/fcn.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ class FCN(_SimpleSegmentationModel):
 
 
 class FCNHead(nn.Sequential):
-    def __init__(self, in_channels, channels):
+    def __init__(self, in_channels: int, channels: int) -> None:
         inter_channels = in_channels // 4
         layers = [
             nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
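
Likewise, a small sketch (not part of the diff) for the typed FCNHead: in_channels and channels are now declared as int, which lets a static checker flag misuse, while runtime behaviour is unchanged (inter_channels = in_channels // 4).

import torch
from torchvision.models.segmentation.fcn import FCNHead

head = FCNHead(in_channels=2048, channels=21).eval()
with torch.no_grad():
    # Spatial size is preserved; channels shrink to the requested 21 classes.
    print(head(torch.rand(1, 2048, 32, 32)).shape)  # torch.Size([1, 21, 32, 32])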

torchvision/models/segmentation/lraspp.py

Lines changed: 16 additions & 3 deletions
@@ -24,12 +24,19 @@ class LRASPP(nn.Module):
         inter_channels (int, optional): the number of channels for intermediate computations.
     """
 
-    def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        low_channels: int,
+        high_channels: int,
+        num_classes: int,
+        inter_channels: int = 128
+    ) -> None:
         super().__init__()
         self.backbone = backbone
         self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)
 
-    def forward(self, input):
+    def forward(self, input: Tensor) -> Dict[str, Tensor]:
         features = self.backbone(input)
         out = self.classifier(features)
         out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False)
@@ -42,7 +49,13 @@ def forward(self, input):
 
 class LRASPPHead(nn.Module):
 
-    def __init__(self, low_channels, high_channels, num_classes, inter_channels):
+    def __init__(
+        self,
+        low_channels: int,
+        high_channels: int,
+        num_classes: int,
+        inter_channels: int
+    ) -> None:
         super().__init__()
         self.cbr = nn.Sequential(
             nn.Conv2d(high_channels, inter_channels, 1, bias=False),
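
A usage sketch (not from the commit) for the annotated LRASPP forward, which now declares a Dict[str, Tensor] return keyed by 'out'; passing pretrained_backbone=False here is only an assumption to avoid a weight download in the example.

import torch
from typing import Dict
from torch import Tensor
from torchvision.models.segmentation import lraspp_mobilenet_v3_large

model = lraspp_mobilenet_v3_large(
    pretrained=False, num_classes=21, pretrained_backbone=False
).eval()
with torch.no_grad():
    result: Dict[str, Tensor] = model(torch.rand(1, 3, 224, 224))
print(result["out"].shape)  # torch.Size([1, 21, 224, 224])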

torchvision/models/segmentation/segmentation.py

Lines changed: 61 additions & 15 deletions
@@ -1,3 +1,5 @@
+from torch import nn
+from typing import Any, Optional
 from .._utils import IntermediateLayerGetter
 from ..._internally_replaced_utils import load_state_dict_from_url
 from .. import mobilenetv3
@@ -22,7 +24,13 @@
 }
 
 
-def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True):
+def _segm_model(
+    name: str,
+    backbone_name: str,
+    num_classes: int,
+    aux: Optional[bool],
+    pretrained_backbone: bool = True
+) -> nn.Module:
     if 'resnet' in backbone_name:
         backbone = resnet.__dict__[backbone_name](
             pretrained=pretrained_backbone,
@@ -66,7 +74,15 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True):
     return model
 
 
-def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
+def _load_model(
+    arch_type: str,
+    backbone: str,
+    pretrained: bool,
+    progress: bool,
+    num_classes: int,
+    aux_loss: Optional[bool],
+    **kwargs: Any
+) -> nn.Module:
     if pretrained:
         aux_loss = True
         kwargs["pretrained_backbone"] = False
@@ -76,7 +92,7 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss):
     return model
 
 
-def _load_weights(model, arch_type, backbone, progress):
+def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None:
     arch = arch_type + '_' + backbone + '_coco'
     model_url = model_urls.get(arch, None)
     if model_url is None:
@@ -86,7 +102,7 @@ def _load_weights(model, arch_type, backbone, progress):
     model.load_state_dict(state_dict)
 
 
-def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
+def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP:
     backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features
 
     # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -103,8 +119,13 @@ def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
     return model
 
 
-def fcn_resnet50(pretrained=False, progress=True,
-                 num_classes=21, aux_loss=None, **kwargs):
+def fcn_resnet50(
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 21,
+    aux_loss: Optional[bool] = None,
+    **kwargs: Any
+) -> nn.Module:
     """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
 
     Args:
@@ -117,8 +138,13 @@ def fcn_resnet50(pretrained=False, progress=True,
     return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
 
 
-def fcn_resnet101(pretrained=False, progress=True,
-                  num_classes=21, aux_loss=None, **kwargs):
+def fcn_resnet101(
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 21,
+    aux_loss: Optional[bool] = None,
+    **kwargs: Any
+) -> nn.Module:
     """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
 
     Args:
@@ -131,8 +157,13 @@ def fcn_resnet101(pretrained=False, progress=True,
     return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
 
 
-def deeplabv3_resnet50(pretrained=False, progress=True,
-                       num_classes=21, aux_loss=None, **kwargs):
+def deeplabv3_resnet50(
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 21,
+    aux_loss: Optional[bool] = None,
+    **kwargs: Any
+) -> nn.Module:
     """Constructs a DeepLabV3 model with a ResNet-50 backbone.
 
     Args:
@@ -145,8 +176,13 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
     return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
 
 
-def deeplabv3_resnet101(pretrained=False, progress=True,
-                        num_classes=21, aux_loss=None, **kwargs):
+def deeplabv3_resnet101(
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 21,
+    aux_loss: Optional[bool] = None,
+    **kwargs: Any
+) -> nn.Module:
     """Constructs a DeepLabV3 model with a ResNet-101 backbone.
 
     Args:
@@ -159,8 +195,13 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
     return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
 
 
-def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
-                                 num_classes=21, aux_loss=None, **kwargs):
+def deeplabv3_mobilenet_v3_large(
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 21,
+    aux_loss: Optional[bool] = None,
+    **kwargs: Any
+) -> nn.Module:
     """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
 
     Args:
@@ -173,7 +214,12 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
     return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)
 
 
-def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs):
+def lraspp_mobilenet_v3_large(
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 21,
+    **kwargs: Any
+) -> nn.Module:
     """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
 
     Args:
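
Finally, a usage sketch (not part of the commit) for the typed builders: aux_loss is Optional[bool] and each constructor is annotated to return nn.Module. The pretrained_backbone=False keyword is only an assumption here to skip the backbone weight download.

import torch
from torch import nn
from torchvision.models.segmentation import deeplabv3_resnet50

model: nn.Module = deeplabv3_resnet50(
    pretrained=False, num_classes=21, aux_loss=True, pretrained_backbone=False
)
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 256, 256))
print(sorted(out.keys()))  # ['aux', 'out'] because aux_loss=True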
