
Added annotation typing to densenet #2860


Merged: 9 commits, Nov 6, 2020
4 changes: 0 additions & 4 deletions mypy.ini
@@ -16,10 +16,6 @@ ignore_errors = True

ignore_errors = True

[mypy-torchvision.models.densenet.*]

ignore_errors = True

[mypy-torchvision.models.quantization.*]

ignore_errors = True
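Removing this override means mypy no longer skips torchvision.models.densenet, so the annotations added below are actually checked. As a hypothetical illustration (not part of this PR), a call that contradicts those annotations would now be reported by mypy:

from torchvision.models.densenet import DenseNet

# Hypothetical snippet: growth_rate is annotated as int in this change,
# so mypy reports an error for the string argument here.
model = DenseNet(growth_rate="32")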
97 changes: 64 additions & 33 deletions torchvision/models/densenet.py
@@ -6,7 +6,7 @@
from collections import OrderedDict
from .utils import load_state_dict_from_url
from torch import Tensor
from torch.jit.annotations import List
from typing import Any, List, Tuple


__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
@@ -20,56 +20,64 @@


class _DenseLayer(nn.Module):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
def __init__(
self,
num_input_features: int,
growth_rate: int,
bn_size: int,
drop_rate: float,
memory_efficient: bool = False
) -> None:
super(_DenseLayer, self).__init__()
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
self.add_module('relu1', nn.ReLU(inplace=True)),
self.norm1: nn.BatchNorm2d
self.add_module('norm1', nn.BatchNorm2d(num_input_features))
self.relu1: nn.ReLU
self.add_module('relu1', nn.ReLU(inplace=True))
self.conv1: nn.Conv2d
Review comment on lines +32 to +36:
note for the future: we should just use

self.norm1 = nn.BatchNorm2d(...)

instead of the add_module alternative, as the names are not dynamic here.
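For context, a minimal sketch (illustrative only, not part of the diff) of the two styles the comment contrasts. With add_module, the submodule is registered under a string name, so mypy cannot infer the attribute type and the explicit declarations used in this PR (self.norm1: nn.BatchNorm2d, and so on) are needed; with a plain assignment, the type is inferred directly. The class names below are made up for the example.

import torch.nn as nn

class WithAddModule(nn.Module):
    def __init__(self, num_features: int) -> None:
        super().__init__()
        # Registered by name: mypy needs the explicit attribute annotation.
        self.norm1: nn.BatchNorm2d
        self.add_module('norm1', nn.BatchNorm2d(num_features))

class WithAssignment(nn.Module):
    def __init__(self, num_features: int) -> None:
        super().__init__()
        # Plain assignment: the attribute type is inferred by mypy.
        self.norm1 = nn.BatchNorm2d(num_features)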

self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
growth_rate, kernel_size=1, stride=1,
bias=False)),
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
self.add_module('relu2', nn.ReLU(inplace=True)),
bias=False))
self.norm2: nn.BatchNorm2d
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
self.relu2: nn.ReLU
self.add_module('relu2', nn.ReLU(inplace=True))
self.conv2: nn.Conv2d
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1,
bias=False)),
bias=False))
self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient

def bn_function(self, inputs):
# type: (List[Tensor]) -> Tensor
def bn_function(self, inputs: List[Tensor]) -> Tensor:
concated_features = torch.cat(inputs, 1)
bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484
return bottleneck_output

# todo: rewrite when torchscript supports any
def any_requires_grad(self, input):
# type: (List[Tensor]) -> bool
def any_requires_grad(self, input: List[Tensor]) -> bool:
for tensor in input:
if tensor.requires_grad:
return True
return False

@torch.jit.unused # noqa: T484
def call_checkpoint_bottleneck(self, input):
# type: (List[Tensor]) -> Tensor
def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
def closure(*inputs):
return self.bn_function(inputs)

return cp.checkpoint(closure, *input)

@torch.jit._overload_method # noqa: F811
def forward(self, input):
# type: (List[Tensor]) -> (Tensor)
def forward(self, input: List[Tensor]) -> Tensor:
pass

@torch.jit._overload_method # noqa: F811
def forward(self, input):
# type: (Tensor) -> (Tensor)
@torch.jit._overload_method # type: ignore[no-redef] # noqa: F811
def forward(self, input: Tensor) -> Tensor:
pass

# torchscript does not yet support *args, so we overload method
# allowing it to take either a List[Tensor] or single Tensor
def forward(self, input): # noqa: F811
def forward(self, input: Tensor) -> Tensor: # type: ignore[no-redef] # noqa: F811
if isinstance(input, Tensor):
prev_features = [input]
else:
@@ -93,7 +101,15 @@ def forward(self, input): # noqa: F811
class _DenseBlock(nn.ModuleDict):
_version = 2

def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
def __init__(
self,
num_layers: int,
num_input_features: int,
bn_size: int,
growth_rate: int,
drop_rate: float,
memory_efficient: bool = False
) -> None:
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(
@@ -105,7 +121,7 @@ def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_ra
)
self.add_module('denselayer%d' % (i + 1), layer)

def forward(self, init_features):
def forward(self, init_features: Tensor) -> Tensor: # type: ignore[override]
features = [init_features]
for name, layer in self.items():
new_features = layer(features)
@@ -114,7 +130,7 @@ def forward(self, init_features):


class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
def __init__(self, num_input_features: int, num_output_features: int) -> None:
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
@@ -139,8 +155,16 @@ class DenseNet(nn.Module):
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
"""

def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):
def __init__(
self,
growth_rate: int = 32,
block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
num_init_features: int = 64,
bn_size: int = 4,
drop_rate: float = 0,
num_classes: int = 1000,
memory_efficient: bool = False
) -> None:

super(DenseNet, self).__init__()

@@ -188,7 +212,7 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
features = self.features(x)
out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1))
@@ -197,7 +221,7 @@ def forward(self, x):
return out


def _load_state_dict(model, model_url, progress):
def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
@@ -215,15 +239,22 @@ def _load_state_dict(model, model_url, progress):
model.load_state_dict(state_dict)


def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
**kwargs):
def _densenet(
arch: str,
growth_rate: int,
block_config: Tuple[int, int, int, int],
num_init_features: int,
pretrained: bool,
progress: bool,
**kwargs: Any
) -> DenseNet:
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
if pretrained:
_load_state_dict(model, model_urls[arch], progress)
return model


def densenet121(pretrained=False, progress=True, **kwargs):
def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

@@ -237,7 +268,7 @@ def densenet121(pretrained=False, progress=True, **kwargs):
**kwargs)


def densenet161(pretrained=False, progress=True, **kwargs):
def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-161 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

@@ -251,7 +282,7 @@ def densenet161(pretrained=False, progress=True, **kwargs):
**kwargs)


def densenet169(pretrained=False, progress=True, **kwargs):
def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-169 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

@@ -265,7 +296,7 @@ def densenet169(pretrained=False, progress=True, **kwargs):
**kwargs)


def densenet201(pretrained=False, progress=True, **kwargs):
def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

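As a closing usage sketch (not part of the diff, but using the public torchvision API), the annotated constructors and forward methods behave as before; the defaults below follow the signatures shown above:

import torch
from torchvision.models import densenet121

# pretrained and progress are typed as bool; extra options such as
# memory_efficient pass through **kwargs: Any to the DenseNet constructor.
model = densenet121(pretrained=False, memory_efficient=True)
out = model(torch.randn(1, 3, 224, 224))  # forward(x: Tensor) -> Tensor
print(out.shape)  # torch.Size([1, 1000]) with the default num_classes=1000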