diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index d0977480bef..d8356d83748 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -1,7 +1,7 @@ import copy import math from functools import partial -from typing import Any, Callable, List, Optional, Sequence +from typing import Any, Callable, Optional, List, Sequence import torch from torch import nn, Tensor @@ -263,7 +263,15 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _efficientnet_conf(width_mult: float, depth_mult: float, **kwargs: Any) -> List[MBConvConfig]: +def _efficientnet( + arch: str, + width_mult: float, + depth_mult: float, + dropout: float, + pretrained: bool, + progress: bool, + **kwargs: Any, +) -> EfficientNet: bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult) inverted_residual_setting = [ bneck_conf(1, 3, 1, 32, 16, 1), @@ -274,17 +282,6 @@ def _efficientnet_conf(width_mult: float, depth_mult: float, **kwargs: Any) -> L bneck_conf(6, 5, 2, 112, 192, 4), bneck_conf(6, 3, 1, 192, 320, 1), ] - return inverted_residual_setting - - -def _efficientnet( - arch: str, - inverted_residual_setting: List[MBConvConfig], - dropout: float, - pretrained: bool, - progress: bool, - **kwargs: Any, -) -> EfficientNet: model = EfficientNet(inverted_residual_setting, dropout, **kwargs) if pretrained: if model_urls.get(arch, None) is None: @@ -303,8 +300,7 @@ def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.0, **kwargs) - return _efficientnet("efficientnet_b0", inverted_residual_setting, 0.2, pretrained, progress, **kwargs) + return _efficientnet("efficientnet_b0", 1.0, 1.0, 0.2, pretrained, progress, **kwargs) def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -316,8 +312,7 @@ def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.1, **kwargs) - return _efficientnet("efficientnet_b1", inverted_residual_setting, 0.2, pretrained, progress, **kwargs) + return _efficientnet("efficientnet_b1", 1.0, 1.1, 0.2, pretrained, progress, **kwargs) def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -329,8 +324,7 @@ def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - inverted_residual_setting = _efficientnet_conf(width_mult=1.1, depth_mult=1.2, **kwargs) - return _efficientnet("efficientnet_b2", inverted_residual_setting, 0.3, pretrained, progress, **kwargs) + return _efficientnet("efficientnet_b2", 1.1, 1.2, 0.3, pretrained, progress, **kwargs) def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -342,8 +336,7 @@ def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - inverted_residual_setting = _efficientnet_conf(width_mult=1.2, depth_mult=1.4, **kwargs) - return _efficientnet("efficientnet_b3", inverted_residual_setting, 0.3, pretrained, progress, **kwargs) + return _efficientnet("efficientnet_b3", 1.2, 1.4, 0.3, pretrained, progress, **kwargs) def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -355,8 +348,7 @@ def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - inverted_residual_setting = _efficientnet_conf(width_mult=1.4, depth_mult=1.8, **kwargs) - return _efficientnet("efficientnet_b4", inverted_residual_setting, 0.4, pretrained, progress, **kwargs) + return _efficientnet("efficientnet_b4", 1.4, 1.8, 0.4, pretrained, progress, **kwargs) def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -368,10 +360,10 @@ def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs) return _efficientnet( "efficientnet_b5", - inverted_residual_setting, + 1.6, + 2.2, 0.4, pretrained, progress, @@ -389,10 +381,10 @@ def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs) return _efficientnet( "efficientnet_b6", - inverted_residual_setting, + 1.8, + 2.6, 0.5, pretrained, progress, @@ -410,10 +402,10 @@ def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs) return _efficientnet( "efficientnet_b7", - inverted_residual_setting, + 2.0, + 3.1, 0.5, pretrained, progress, diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index e1ec493b96a..472e390c3ce 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -1,11 +1,11 @@ import warnings from functools import partial -from typing import Any, List, Optional +from typing import Any, Optional from torch import nn from torchvision.transforms.functional import InterpolationMode -from ...models.efficientnet import EfficientNet, MBConvConfig, _efficientnet_conf +from ...models.efficientnet import EfficientNet, MBConvConfig from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES @@ -33,7 +33,8 @@ def _efficientnet( - inverted_residual_setting: List[MBConvConfig], + width_mult: float, + depth_mult: float, dropout: float, weights: Optional[Weights], progress: bool, @@ -42,6 +43,17 @@ def _efficientnet( if weights is not None: kwargs["num_classes"] = len(weights.meta["categories"]) + bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult) + inverted_residual_setting = [ + bneck_conf(1, 3, 1, 32, 16, 1), + bneck_conf(6, 3, 2, 16, 24, 2), + bneck_conf(6, 5, 2, 24, 40, 2), + bneck_conf(6, 3, 2, 40, 80, 3), + bneck_conf(6, 5, 1, 80, 112, 3), + bneck_conf(6, 5, 2, 112, 192, 4), + bneck_conf(6, 3, 1, 192, 320, 1), + ] + model = EfficientNet(inverted_residual_setting, dropout, **kwargs) if weights is not None: @@ -172,8 +184,7 @@ def efficientnet_b0( warnings.warn("The argument pretrained is deprecated, please use weights instead.") weights = EfficientNetB0Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None weights = EfficientNetB0Weights.verify(weights) - inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.0, **kwargs) - return _efficientnet(inverted_residual_setting, dropout=0.2, weights=weights, progress=progress, **kwargs) + return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs) def efficientnet_b1( @@ -183,8 +194,7 @@ def efficientnet_b1( warnings.warn("The argument pretrained is deprecated, please use weights instead.") weights = EfficientNetB1Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None weights = EfficientNetB1Weights.verify(weights) - inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.1, **kwargs) - return _efficientnet(inverted_residual_setting, dropout=0.2, weights=weights, progress=progress, **kwargs) + return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs) def efficientnet_b2( @@ -194,8 +204,7 @@ def efficientnet_b2( warnings.warn("The argument pretrained is deprecated, please use weights instead.") weights = EfficientNetB2Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None weights = EfficientNetB2Weights.verify(weights) - inverted_residual_setting = _efficientnet_conf(width_mult=1.1, depth_mult=1.2, **kwargs) - return _efficientnet(inverted_residual_setting, dropout=0.3, weights=weights, progress=progress, **kwargs) + return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs) def efficientnet_b3( @@ -205,8 +214,7 @@ def efficientnet_b3( warnings.warn("The argument pretrained is deprecated, please use weights instead.") weights = EfficientNetB3Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None weights = EfficientNetB3Weights.verify(weights) - inverted_residual_setting = _efficientnet_conf(width_mult=1.2, depth_mult=1.4, **kwargs) - return _efficientnet(inverted_residual_setting, dropout=0.3, weights=weights, progress=progress, **kwargs) + return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs) def efficientnet_b4( @@ -216,8 +224,7 @@ def efficientnet_b4( warnings.warn("The argument pretrained is deprecated, please use weights instead.") weights = EfficientNetB4Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None weights = EfficientNetB4Weights.verify(weights) - inverted_residual_setting = _efficientnet_conf(width_mult=1.4, depth_mult=1.8, **kwargs) - return _efficientnet(inverted_residual_setting, dropout=0.4, weights=weights, progress=progress, **kwargs) + return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs) def efficientnet_b5( @@ -227,9 +234,9 @@ def efficientnet_b5( warnings.warn("The argument pretrained is deprecated, please use weights instead.") weights = EfficientNetB5Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None weights = EfficientNetB5Weights.verify(weights) - inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs) return _efficientnet( - inverted_residual_setting, + width_mult=1.6, + depth_mult=2.2, dropout=0.4, weights=weights, progress=progress, @@ -245,9 +252,9 @@ def efficientnet_b6( warnings.warn("The argument pretrained is deprecated, please use weights instead.") weights = EfficientNetB6Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None weights = EfficientNetB6Weights.verify(weights) - inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs) return _efficientnet( - inverted_residual_setting, + width_mult=1.8, + depth_mult=2.6, dropout=0.5, weights=weights, progress=progress, @@ -263,9 +270,9 @@ def efficientnet_b7( warnings.warn("The argument pretrained is deprecated, please use weights instead.") weights = EfficientNetB7Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None weights = EfficientNetB7Weights.verify(weights) - inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs) return _efficientnet( - inverted_residual_setting, + width_mult=2.0, + depth_mult=3.1, dropout=0.5, weights=weights, progress=progress,