
Added annotation typing to shufflenet #2864


Merged · 6 commits · Nov 6, 2020
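For context on the diff below: the PR replaces the old comment-style type hints in torchvision/models/shufflenetv2.py with inline annotations. A minimal sketch of the two styles, using a hypothetical `scale` function rather than anything from the file:

```python
from torch import Tensor

# Before: type information lives in a comment (the style being removed here)
def scale(x, factor):
    # type: (Tensor, float) -> Tensor
    return x * factor

# After: inline PEP 484 annotations (the style this PR adopts)
def scale_annotated(x: Tensor, factor: float) -> Tensor:
    return x * factor
```

The diff below applies the inline style to channel_shuffle, InvertedResidual, ShuffleNetV2, and the four model constructors.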
49 changes: 36 additions & 13 deletions torchvision/models/shufflenetv2.py
@@ -1,6 +1,8 @@
import torch
+from torch import Tensor
import torch.nn as nn
from .utils import load_state_dict_from_url
+from typing import Callable, Any, List


__all__ = [
@@ -16,8 +18,7 @@
}


-def channel_shuffle(x, groups):
-    # type: (torch.Tensor, int) -> torch.Tensor
+def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups

@@ -34,7 +35,12 @@ def channel_shuffle(x, groups):


class InvertedResidual(nn.Module):
-    def __init__(self, inp, oup, stride):
+    def __init__(
+        self,
+        inp: int,
+        oup: int,
+        stride: int
+    ) -> None:
        super(InvertedResidual, self).__init__()

        if not (1 <= stride <= 3):
@@ -68,10 +74,17 @@ def __init__(self, inp, oup, stride):
        )

    @staticmethod
-    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
+    def depthwise_conv(
+        i: int,
+        o: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        bias: bool = False
+    ) -> nn.Conv2d:
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
@@ -84,7 +97,13 @@ def forward(self, x):


class ShuffleNetV2(nn.Module):
-    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):
+    def __init__(
+        self,
+        stages_repeats: List[int],
+        stages_out_channels: List[int],
+        num_classes: int = 1000,
+        inverted_residual: Callable[..., nn.Module] = InvertedResidual
+    ) -> None:
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
@@ -104,6 +123,10 @@ def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

+        # Static annotations for mypy
+        self.stage2: nn.Sequential
+        self.stage3: nn.Sequential
+        self.stage4: nn.Sequential
Review comment on lines +127 to +129 (Member):
note: it would have been preferable to let stage{} be a nn.Module, as they could potentially be any type of nn.Module, not just Sequential
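A minimal sketch of what the suggested looser annotation could look like (illustrative only; the class below is a trimmed stand-in, not the merged code):

```python
import torch.nn as nn

class ShuffleNetV2Sketch(nn.Module):
    """Trimmed sketch showing only the attribute annotations in question."""

    def __init__(self) -> None:
        super().__init__()
        # Broader annotation along the lines of the review note: any nn.Module
        # would satisfy mypy here, not only nn.Sequential, so a custom
        # inverted_residual callable could build non-Sequential stages.
        self.stage2: nn.Module
        self.stage3: nn.Module
        self.stage4: nn.Module
```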

        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(
                stage_names, stages_repeats, self._stage_out_channels[1:]):
@@ -122,7 +145,7 @@ def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):

        self.fc = nn.Linear(output_channels, num_classes)

-    def _forward_impl(self, x):
+    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.maxpool(x)
@@ -134,11 +157,11 @@ def _forward_impl(self, x):
        x = self.fc(x)
        return x

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


-def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
+def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2:
    model = ShuffleNetV2(*args, **kwargs)

    if pretrained:
@@ -152,7 +175,7 @@ def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    return model


-def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
+def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -166,7 +189,7 @@ def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
                         [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)


-def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
+def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -180,7 +203,7 @@ def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
                         [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)


-def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
+def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -194,7 +217,7 @@ def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
                         [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)


-def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
+def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
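As a rough illustration of what the new annotations enable downstream, a small usage sketch that a type checker such as mypy could verify; the file name and setup here are assumptions, not part of the PR:

```python
# check_shufflenet_types.py - hypothetical file; assumes torchvision with this
# change and mypy are installed.
import torch
from torchvision.models import shufflenet_v2_x1_0
from torchvision.models.shufflenetv2 import ShuffleNetV2

# The constructor's return type is now declared, so this annotation type-checks.
model: ShuffleNetV2 = shufflenet_v2_x1_0(pretrained=False)

# forward(x: Tensor) -> Tensor, so the output annotation is also consistent.
out: torch.Tensor = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000]) with the default num_classes
```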