diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 14f9521886c..9ba090ad09b 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -1,6 +1,8 @@ import torch +from torch import Tensor import torch.nn as nn from .utils import load_state_dict_from_url +from typing import Callable, Any, List __all__ = [ @@ -16,8 +18,7 @@ } -def channel_shuffle(x, groups): - # type: (torch.Tensor, int) -> torch.Tensor +def channel_shuffle(x: Tensor, groups: int) -> Tensor: batchsize, num_channels, height, width = x.data.size() channels_per_group = num_channels // groups @@ -34,7 +35,12 @@ def channel_shuffle(x, groups): class InvertedResidual(nn.Module): - def __init__(self, inp, oup, stride): + def __init__( + self, + inp: int, + oup: int, + stride: int + ) -> None: super(InvertedResidual, self).__init__() if not (1 <= stride <= 3): @@ -68,10 +74,17 @@ def __init__(self, inp, oup, stride): ) @staticmethod - def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + def depthwise_conv( + i: int, + o: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = False + ) -> nn.Conv2d: return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: if self.stride == 1: x1, x2 = x.chunk(2, dim=1) out = torch.cat((x1, self.branch2(x2)), dim=1) @@ -84,7 +97,13 @@ def forward(self, x): class ShuffleNetV2(nn.Module): - def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual): + def __init__( + self, + stages_repeats: List[int], + stages_out_channels: List[int], + num_classes: int = 1000, + inverted_residual: Callable[..., nn.Module] = InvertedResidual + ) -> None: super(ShuffleNetV2, self).__init__() if len(stages_repeats) != 3: @@ -104,6 +123,10 @@ def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, invert self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + # Static annotations for mypy + self.stage2: nn.Sequential + self.stage3: nn.Sequential + self.stage4: nn.Sequential stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] for name, repeats, output_channels in zip( stage_names, stages_repeats, self._stage_out_channels[1:]): @@ -122,7 +145,7 @@ def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, invert self.fc = nn.Linear(output_channels, num_classes) - def _forward_impl(self, x): + def _forward_impl(self, x: Tensor) -> Tensor: # See note [TorchScript super()] x = self.conv1(x) x = self.maxpool(x) @@ -134,11 +157,11 @@ def _forward_impl(self, x): x = self.fc(x) return x - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): +def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2: model = ShuffleNetV2(*args, **kwargs) if pretrained: @@ -152,7 +175,7 @@ def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): return model -def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): +def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 0.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" @@ -166,7 +189,7 @@ def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) -def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): +def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 1.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" @@ -180,7 +203,7 @@ def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) -def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs): +def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 1.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" @@ -194,7 +217,7 @@ def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs): [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) -def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs): +def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 2.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"