From 053a956fa1c22045b6905126e7f7c0f666dd1f03 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 15:50:04 +0100 Subject: [PATCH 1/4] Densenet121 added --- torchvision/prototype/models/__init__.py | 1 + torchvision/prototype/models/densenet.py | 79 ++++++++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 torchvision/prototype/models/densenet.py diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index c77d27e8009..33e5b8ae949 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -1,3 +1,4 @@ from .resnet import * +from .densenet import * from . import detection from . import quantization diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py new file mode 100644 index 00000000000..30f51e336aa --- /dev/null +++ b/torchvision/prototype/models/densenet.py @@ -0,0 +1,79 @@ +import re +import warnings +from functools import partial +from typing import Any, Optional, Tuple + +import torch.nn as nn + +from ...models.densenet import _DenseLayer, _DenseBlock, _Transition, DenseNet +from ..transforms.presets import ImageNetEval +from ._api import Weights, WeightEntry +from ._meta import _IMAGENET_CATEGORIES + + +__all__ = ["DenseNet", "DenseNet121Weights", "densenet121"] + +def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None: + # '.'s are no longer allowed in module names, but previous _DenseLayer + # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. + # They are also in the checkpoints in model_urls. This pattern is used + # to find such keys. + pattern = re.compile( + r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) + + state_dict = weights.state_dict(progress=progress) + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + model.load_state_dict(state_dict) + + +def _densenet( + growth_rate: int, + block_config: Tuple[int, int, int, int], + num_init_features: int, + weights: Weights, + progress: bool, + **kwargs: Any, +) -> DenseNet: + if weights is not None: + kwargs["num_classes"] = len(weights.meta["categories"]) + + model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) + + if weights is not None: + _load_state_dict(model=model, weights=weights, progress=progress) + + return model + + +_common_meta = { + "size": (224, 224), + "categories": _IMAGENET_CATEGORIES, +} + + +class DenseNet121Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/densenet121-a639ec97.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 74.434, + "acc@5": 91.972, + }, + ) + + +def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = DenseNet121Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = DenseNet121Weights.verify(weights) + + return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) From 59e52e630408e8d95e628052a294e5eeddf1aff7 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 16:06:32 +0100 Subject: [PATCH 2/4] All densenet prototypes added --- torchvision/prototype/models/densenet.py | 81 +++++++++++++++++++++++- 1 file changed, 79 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index 30f51e336aa..e56e5d3b15f 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -11,7 +11,18 @@ from ._meta import _IMAGENET_CATEGORIES -__all__ = ["DenseNet", "DenseNet121Weights", "densenet121"] +__all__ = [ + "DenseNet", + "DenseNet121Weights", + "DenseNet161Weights", + "DenseNet169Weights", + "DenseNet201Weights", + "densenet121", + "densenet161", + "densenet169", + "densenet201", +] + def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None: # '.'s are no longer allowed in module names, but previous _DenseLayer @@ -68,7 +79,46 @@ class DenseNet121Weights(Weights): "acc@5": 91.972, }, ) - + + +class DenseNet161Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/densenet161-8d451a50.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 77.138, + "acc@5": 93.560, + }, + ) + + +class DenseNet169Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 75.600, + "acc@5": 92.806, + }, + ) + + +class DenseNet201Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/densenet201-c1103571.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 76.896, + "acc@5": 93.370, + }, + ) + def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: if "pretrained" in kwargs: @@ -77,3 +127,30 @@ def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = T weights = DenseNet121Weights.verify(weights) return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) + + +def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = DenseNet161Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = DenseNet161Weights.verify(weights) + + return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) + + +def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = DenseNet169Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = DenseNet169Weights.verify(weights) + + return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) + + +def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = DenseNet201Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = DenseNet201Weights.verify(weights) + + return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) From dfbbd4a9a2f2035abaa9815ab8a8bde298491036 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 16:27:51 +0100 Subject: [PATCH 3/4] fixing flake8 errors --- torchvision/prototype/models/densenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index e56e5d3b15f..a6d9e4f4d59 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -5,7 +5,7 @@ import torch.nn as nn -from ...models.densenet import _DenseLayer, _DenseBlock, _Transition, DenseNet +from ...models.densenet import DenseNet from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES From ca29cc6c9276c468f57a0d2f51e6e13df892c355 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 17:04:49 +0100 Subject: [PATCH 4/4] fixing argument type --- torchvision/prototype/models/densenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index a6d9e4f4d59..5dcd80a0bac 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -47,7 +47,7 @@ def _densenet( growth_rate: int, block_config: Tuple[int, int, int, int], num_init_features: int, - weights: Weights, + weights: Optional[Weights], progress: bool, **kwargs: Any, ) -> DenseNet: